 _SPLIT_DATA_PATH = ''
 # Cache files
 _CACHE_TRAIN_DATA_PATH = '/cognitive_comp/common_data/wudao_180g_t5_tokenized_512/'
+_CACHE_TRAIN_DATA_PATH_TRAIN = '/cognitive_comp/common_data/wudao_180g_t5_tokenized_512_train/'
+_CACHE_TRAIN_DATA_PATH_TEST = '/cognitive_comp/common_data/wudao_180g_t5_tokenized_512_test/'
 
 
 feats = datasets.Features({"input_ids": datasets.Value('int32')})
 
 
-def load_dataset(num_proc=1, **kargs):
+def load_old_dataset(num_proc=1, **kargs):
     cache_dict_paths = glob.glob(os.path.join(_CACHE_TRAIN_DATA_PATH, '*'))
     ds = []
     res = []
@@ -25,27 +27,62 @@ def load_dataset(num_proc=1, **kargs):
         ds.append(future.result())
     return datasets.DatasetDict({"train": datasets.concatenate_datasets(ds)})
 
+def load_dataset(num_proc=1, **kargs):
+    '''
+    Load the cached data: read the train shards in parallel and the test split from disk.
+    '''
+    cache_dict_paths = glob.glob(os.path.join(_CACHE_TRAIN_DATA_PATH_TRAIN, '*'))
+    ds = []
+    res = []
+    p = ProcessPoolExecutor(max_workers=num_proc)
+    for path in cache_dict_paths:
+        res.append(p.submit(datasets.load_from_disk,
+                            path, **kargs))
+
+    p.shutdown(wait=True)
+    for future in res:
+        ds.append(future.result())
+    train_ds = datasets.concatenate_datasets(ds)
+    test_ds = datasets.load_from_disk(_CACHE_TRAIN_DATA_PATH_TEST)
+    return datasets.DatasetDict({
+        "train": train_ds,
+        "test": test_ds})
+
 
-def _generate_cache_arrow(index, path):
+def _generate_cache_arrow(index, ds):
     print('saving dataset shard {}'.format(index))
-    ds = (datasets.load_dataset('json', data_files=path,
-                                cache_dir='',
-                                features=feats)['train'])
-    ds.save_to_disk(os.path.join(_CACHE_TRAIN_DATA_PATH, os.path.basename(path)))
+    ds.save_to_disk(os.path.join(_CACHE_TRAIN_DATA_PATH_TRAIN, 'part_{}'.format(index)))
     return 'saving dataset shard {} done'.format(index)
 
 
-def generate_cache_arrow(num_proc=1) -> None:
+def generate_arrow_cache(num_proc=1) -> None:
     '''
-    Generate HF-compatible cache files to speed up later loading.
+    Read the wudao_180g_t5_tokenized_512 data, do a train/test split,
+    shuffle it with seed 42, and cache the result.
     '''
-    data_dict_paths = glob.glob(_SPLIT_DATA_PATH)
+    ds = load_old_dataset(num_proc=num_proc)
+    ds = ds['train'].train_test_split(train_size=0.999, test_size=0.001, seed=42)
+    print(ds)
     p = ProcessPoolExecutor(max_workers=num_proc)
     res = []
-
-    for index, path in enumerate(data_dict_paths):
-        res.append(p.submit(_generate_cache_arrow, index, path))
+    train_shard_part = 800
+    for i in range(0, train_shard_part):
+        res.append(p.submit(_generate_cache_arrow, i,
+                            ds['train'].shard(train_shard_part, i)))
 
     p.shutdown(wait=True)
     for future in res:
         print(future.result(), flush=True)
+
+    ds['test'].save_to_disk(_CACHE_TRAIN_DATA_PATH_TEST)
+    print('done')
+
+
+if __name__ == '__main__':
+    ds = load_dataset(num_proc=100)
+    print(ds)
+    # generate_arrow_cache(num_proc=100)
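
For readers unfamiliar with the two `datasets` operations the new `generate_arrow_cache` relies on, here is a small, self-contained sketch of `train_test_split(seed=...)` and `Dataset.shard(num_shards, index)` on toy data. The toy dataset and shard count below are illustrative only; the real script runs these on the tokenized WuDao cache with 800 shards and then writes each shard to disk.

```python
import datasets

# Toy stand-in for the tokenized WuDao corpus.
toy = datasets.Dataset.from_dict({"input_ids": list(range(1000))})

# Shuffled 99.9% / 0.1% split with a fixed seed, as in generate_arrow_cache.
split = toy.train_test_split(train_size=0.999, test_size=0.001, seed=42)
print(split)  # DatasetDict with 999 train rows and 1 test row

# The train set is then written out shard by shard; shard(n, i) selects the
# i-th of n roughly equal pieces.
num_shards = 4  # the real script uses train_shard_part = 800
for i in range(num_shards):
    part = split["train"].shard(num_shards, i)
    print(i, len(part))
    # In the real script: part.save_to_disk(os.path.join(_CACHE_TRAIN_DATA_PATH_TRAIN, 'part_{}'.format(i)))
```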