
Commit aa6c1f9

add ideadata2.0
1 parent d50efce commit aa6c1f9

File tree

2 files changed: 133 additions, 0 deletions


ideadata2.0/__init__.py

+2
@@ -0,0 +1,2 @@
from .load import load_dataset
__all__ = ['load_dataset']
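
The __init__.py above simply re-exports load_dataset from load.py. A minimal consumption sketch, assuming the ideadata2.0 directory itself is put on sys.path (the directory name contains a dot, so it cannot be imported as a regular package name; the checkout path below is hypothetical):

import sys

# Hypothetical checkout location of the repository; adjust as needed.
sys.path.insert(0, '/path/to/repo/ideadata2.0')

from load import load_dataset  # the same function that __init__.py re-exports

ds = load_dataset(num_proc=4)  # requires the cache under _CACHE_SPLIT_DATA_PATH to exist
print(ds)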

ideadata2.0/load.py

+131
@@ -0,0 +1,131 @@
import datasets
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor


# raw jsonline corpus
_SPLIT_DATA_PATH = '/cognitive_comp/common_data/big_corpus/ideaData/ideaData2.0'
# cache directory for the processed datasets
_CACHE_SPLIT_DATA_PATH = '/cognitive_comp/common_data/ideadata2.0/'
# feats = datasets.Features({"text": datasets.Value('string')})


class CommonTextCorpusGenerate(object):
    """
    Processes a generic text corpus in jsonline format, where every line looks like
    {"text": "the content of one passage"}, and converts it into a memory-mapped
    dataset so that subsequent loading is fast.
    """
    def __init__(self,
                 data_files=_SPLIT_DATA_PATH,
                 save_path=_CACHE_SPLIT_DATA_PATH,
                 train_test_validation='9900,70,30',
                 num_proc=1,
                 if_shuffle=False,
                 cache=False):
        self.data_files = Path(data_files)
        if save_path:
            self.save_path = Path(save_path)
        else:
            self.save_path = self.file_check(
                Path(self.data_files.parent, self.data_files.name + '_FSDataset'), 'save')
        self.num_proc = num_proc
        self.cache = cache
        self.split_idx = self.split_train_test_validation_index(train_test_validation)
        self.shuffle = if_shuffle
        if cache:
            self.cache_path = self.file_check(
                Path(self.save_path.parent, 'FSDataCache', self.data_files.name), 'cache')
        else:
            self.cache_path = None

    @staticmethod
    def file_check(path, path_type):
        print(path)
        if not path.exists():
            path.mkdir(parents=True)
            print(f"Since no {path_type} directory is specified, the program will automatically create it in the {path} directory.")
        return str(path)

    @staticmethod
    def split_train_test_validation_index(train_test_validation):
        # e.g. '9900,70,30' -> 99% of each shard goes to train; the remaining
        # 1% is further split 70/30 into test and validation.
        split_idx_ = [int(i) for i in train_test_validation.split(',')]
        idx_dict = {
            'train_rate': split_idx_[0] / sum(split_idx_),
            'test_rate': split_idx_[1] / sum(split_idx_[1:])
        }
        return idx_dict

    def process(self, index, path):
        print('saving dataset shard {}'.format(index))

        ds = datasets.load_dataset('json', data_files=str(path),
                                   cache_dir=self.cache_path,
                                   features=None)
        # shard-local shuffle
        # TODO: global shuffle across shards
        if self.shuffle:
            ds = ds.shuffle()
        # sentence splitting is not implemented yet
        # split the shard into train/test/validation
        ds = ds['train'].train_test_split(train_size=self.split_idx['train_rate'])
        ds_ = ds['test'].train_test_split(train_size=self.split_idx['test_rate'])
        ds = datasets.DatasetDict({
            'train': ds['train'],
            'test': ds_['train'],
            'validation': ds_['test']
        })

        ds.save_to_disk(Path(self.save_path, path.name))
        return 'saving dataset shard {} done'.format(index)

    def generate_cache_arrow(self) -> None:
        '''
        Generate cache files in the format supported by HF datasets to speed up subsequent loading.
        '''
        # each path under the corpus directory is treated as one jsonline shard
        data_dict_paths = self.data_files.rglob('*')
        p = ProcessPoolExecutor(max_workers=self.num_proc)
        res = list()

        for index, path in enumerate(data_dict_paths):
            # # for test
            # if index > 10:
            #     break
            res.append(p.submit(self.process, index, path))

        p.shutdown(wait=True)
        for future in res:
            print(future.result(), flush=True)


def load_dataset(num_proc=4, **kwargs):
    cache_dict_paths = Path(_CACHE_SPLIT_DATA_PATH).glob('*')
    ds = []
    res = []
    p = ProcessPoolExecutor(max_workers=num_proc)
    for path in cache_dict_paths:
        res.append(p.submit(datasets.load_from_disk,
                            str(path), **kwargs))

    p.shutdown(wait=True)
    for future in res:
        ds.append(future.result())
        # print(future.result())
    train = []
    test = []
    validation = []
    for ds_ in ds:
        train.append(ds_['train'])
        test.append(ds_['test'])
        validation.append(ds_['validation'])
    # ds = datasets.concatenate_datasets(ds)
    # print(ds)
    return datasets.DatasetDict({
        'train': datasets.concatenate_datasets(train),
        'test': datasets.concatenate_datasets(test),
        'validation': datasets.concatenate_datasets(validation)
    })


if __name__ == '__main__':
    dataset = CommonTextCorpusGenerate(_SPLIT_DATA_PATH, num_proc=16)
    dataset.generate_cache_arrow()
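
Taken together, load.py defines a two-step pipeline: CommonTextCorpusGenerate.generate_cache_arrow() converts every jsonline shard under _SPLIT_DATA_PATH into an Arrow DatasetDict saved to disk, and load_dataset() later reads all those cached shards back in parallel and concatenates them into one train/test/validation DatasetDict. A minimal end-to-end sketch, assuming the module is importable as load (see the note after __init__.py) and that the hard-coded corpus and cache paths exist on the machine:

from load import CommonTextCorpusGenerate, load_dataset, _SPLIT_DATA_PATH

# Step 1 (run once): turn the raw jsonline shards into memory-mapped Arrow
# shards under _CACHE_SPLIT_DATA_PATH, using 16 worker processes.
generator = CommonTextCorpusGenerate(_SPLIT_DATA_PATH, num_proc=16)
generator.generate_cache_arrow()

# Step 2: load every cached shard in parallel and concatenate the splits.
ds = load_dataset(num_proc=4)
print(ds)              # DatasetDict with 'train', 'test' and 'validation'
print(ds['train'][0])  # e.g. {'text': '...'}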
