Commit 1b3df97

update wudao_180g_t5_tokenized_512
1 parent e28b46f commit 1b3df97

File tree

1 file changed: +49 -12 lines

wudao_180g_t5_tokenized_512/load.py (+49 -12)
@@ -6,12 +6,14 @@
 _SPLIT_DATA_PATH = ''
 # Cache files
 _CACHE_TRAIN_DATA_PATH = '/cognitive_comp/common_data/wudao_180g_t5_tokenized_512/'
+_CACHE_TRAIN_DATA_PATH_TRAIN = '/cognitive_comp/common_data/wudao_180g_t5_tokenized_512_train/'
+_CACHE_TRAIN_DATA_PATH_TEST = '/cognitive_comp/common_data/wudao_180g_t5_tokenized_512_test/'
 
 
 feats = datasets.Features({"input_ids": datasets.Value('int32')})
 
 
-def load_dataset(num_proc=1, **kargs):
+def load_old_dataset(num_proc=1, **kargs):
     cache_dict_paths = glob.glob(os.path.join(_CACHE_TRAIN_DATA_PATH, '*'))
     ds = []
     res = []
@@ -25,27 +27,62 @@ def load_dataset(num_proc=1, **kargs):
         ds.append(future.result())
     return datasets.DatasetDict({"train": datasets.concatenate_datasets(ds)})
 
+def load_dataset(num_proc=1, **kargs):
+    '''
+    Load the cached data
+    '''
+    cache_dict_paths = glob.glob(os.path.join(_CACHE_TRAIN_DATA_PATH_TRAIN, '*'))
+    ds = []
+    res = []
+    p = ProcessPoolExecutor(max_workers=num_proc)
+    for path in cache_dict_paths:
+        res.append(p.submit(datasets.load_from_disk,
+                            path, **kargs))
+
+    p.shutdown(wait=True)
+    for future in res:
+        ds.append(future.result())
+    train_ds = datasets.concatenate_datasets(ds)
+    test_ds = datasets.load_from_disk(_CACHE_TRAIN_DATA_PATH_TEST)
+    return datasets.DatasetDict({
+        "train": train_ds,
+        "test": test_ds})
+
 
-def _generate_cache_arrow(index, path):
+def _generate_cache_arrow(index, ds):
     print('saving dataset shard {}'.format(index))
-    ds = (datasets.load_dataset('json', data_files=path,
-                                cache_dir='',
-                                features=feats)['train'])
-    ds.save_to_disk(os.path.join(_CACHE_TRAIN_DATA_PATH, os.path.basename(path)))
+    ds.save_to_disk(os.path.join(_CACHE_TRAIN_DATA_PATH_TRAIN, 'part_{}'.format(index)))
     return 'saving dataset shard {} done'.format(index)
 
 
-def generate_cache_arrow(num_proc=1) -> None:
+def generate_arrow_cache(num_proc=1) -> None:
     '''
-    Generate HF-compatible cache files to speed up subsequent loading
+    Read the wudao_180g_t5_tokenized_512 data, do a train/test split,
+    then shuffle with seed 42 and cache the result
     '''
-    data_dict_paths = glob.glob(_SPLIT_DATA_PATH)
+    ds = load_old_dataset(num_proc=num_proc)
+    ds = ds['train'].train_test_split(train_size=0.999, test_size=0.001, seed=42)
+    print(ds)
     p = ProcessPoolExecutor(max_workers=num_proc)
     res = []
-
-    for index, path in enumerate(data_dict_paths):
-        res.append(p.submit(_generate_cache_arrow, index, path))
+    train_shard_part = 800
+    for i in range(0, train_shard_part):
+        res.append(p.submit(_generate_cache_arrow, i,
+                            ds['train'].shard(train_shard_part, i)))
 
     p.shutdown(wait=True)
     for future in res:
         print(future.result(), flush=True)
+
+    ds['test'].save_to_disk(_CACHE_TRAIN_DATA_PATH_TEST)
+    print('done')
+
+
+if __name__ == '__main__':
+    ds = load_dataset(num_proc=100)
+    print(ds)
+    # generate_arrow_cache(num_proc=100)
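Note that the diff starts at line 6 of load.py, so the file's import block is not shown. A minimal sketch of the header the visible code appears to rely on, inferred only from the calls it makes (the actual import lines may differ):

import glob
import os
from concurrent.futures import ProcessPoolExecutor

import datasets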

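For context, a small self-contained sketch of the shard-and-reload pattern that the new _generate_cache_arrow / load_dataset pair implements, using toy in-memory data and a temporary directory instead of the real wudao paths (everything below is illustrative and not part of the commit):

import glob
import os
import tempfile

import datasets

with tempfile.TemporaryDirectory() as tmp:
    # Toy stand-in for the tokenized corpus.
    ds = datasets.Dataset.from_dict({'input_ids': list(range(100))})
    n_shards = 4  # the commit uses 800 shards for the real corpus
    # Mirror _generate_cache_arrow: save each shard as its own arrow directory.
    for i in range(n_shards):
        ds.shard(n_shards, i).save_to_disk(os.path.join(tmp, 'part_{}'.format(i)))
    # Mirror load_dataset: reload every part_* directory and concatenate them.
    parts = [datasets.load_from_disk(p) for p in sorted(glob.glob(os.path.join(tmp, '*')))]
    merged = datasets.concatenate_datasets(parts)
    print(len(merged))  # 100 rows, same as the toy dataset

The committed code additionally submits the save and load calls to a ProcessPoolExecutor, so the 800 shards are written and read in parallel rather than sequentially.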