import datasets
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor


_SPLIT_DATA_PATH = '/cognitive_comp/common_data/big_corpus/ideaData/ideaData2.0'
# cache directory for the processed splits
_CACHE_SPLIT_DATA_PATH = '/cognitive_comp/common_data/ideadata2.0/'
# feats = datasets.Features({"text": datasets.Value('string')})


class CommonTextCorpusGenerate(object):
    """
    Process generic text corpora. The input format is JSON Lines, where each line looks like:
    {"text": "this is the content of a piece of text"}
    The data is converted into a memory-mapped format so that it can be loaded quickly later.
    """
    def __init__(self,
                 data_files=_SPLIT_DATA_PATH,
                 save_path=_CACHE_SPLIT_DATA_PATH,
                 train_test_validation='9900,70,30',
                 num_proc=1,
                 if_shuffle=False,
                 cache=False):
        self.data_files = Path(data_files)
        if save_path:
            self.save_path = Path(save_path)
        else:
            # file_check returns a str, so wrap it in Path to keep save_path consistent
            self.save_path = Path(self.file_check(
                Path(self.data_files.parent, self.data_files.name + '_FSDataset'), 'save'))
        self.num_proc = num_proc
        self.cache = cache
        self.split_idx = self.split_train_test_validation_index(train_test_validation)
        self.shuffle = if_shuffle
        if cache:
            self.cache_path = self.file_check(
                Path(self.save_path.parent, 'FSDataCache', self.data_files.name), 'cache')
        else:
            self.cache_path = None

    @staticmethod
    def file_check(path, path_type):
        print(path)
        if not path.exists():
            path.mkdir(parents=True)
            print(f"Since no {path_type} directory is specified, the program will automatically create it in the {path} directory.")
        return str(path)

    @staticmethod
    def split_train_test_validation_index(train_test_validation):
        split_idx_ = [int(i) for i in train_test_validation.split(',')]
        idx_dict = {
            'train_rate': split_idx_[0] / sum(split_idx_),
            'test_rate': split_idx_[1] / sum(split_idx_[1:])
        }
        return idx_dict

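    # Worked example, assuming the default ratio string '9900,70,30':
    #   train_rate = 9900 / (9900 + 70 + 30) = 0.99 -> 99% of the rows go to train
    #   test_rate  = 70 / (70 + 30) = 0.7            -> 70% of the remaining rows go to test,
    #                                                   the rest to validation
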
    def process(self, index, path):
        print('saving dataset shard {}'.format(index))

        ds = (datasets.load_dataset('json', data_files=str(path),
                                    cache_dir=self.cache_path,
                                    features=None))
        # local shuffle only
        # TODO: global shuffle
        if self.shuffle:
            ds = ds.shuffle()
        # TODO: add sentence splitting (not implemented yet)
        # split the dataset into train/test/validation here
        ds = ds['train'].train_test_split(train_size=self.split_idx['train_rate'])
        ds_ = ds['test'].train_test_split(train_size=self.split_idx['test_rate'])
        ds = datasets.DatasetDict({
            'train': ds['train'],
            'test': ds_['train'],
            'validation': ds_['test']
        })

        ds.save_to_disk(Path(self.save_path, path.name))
        return 'saving dataset shard {} done'.format(index)

    def generate_cache_arrow(self) -> None:
        '''
        Generate cache files supported by HF datasets to speed up later loading.
        '''
        data_dict_paths = self.data_files.rglob('*')
        p = ProcessPoolExecutor(max_workers=self.num_proc)
        res = list()

        for index, path in enumerate(data_dict_paths):
            # # for test
            # if index > 10:
            #     break
            res.append(p.submit(self.process, index, path))

        p.shutdown(wait=True)
        for future in res:
            print(future.result(), flush=True)


def load_dataset(num_proc=4, **kwargs):
    cache_dict_paths = Path(_CACHE_SPLIT_DATA_PATH).glob('*')
    ds = []
    res = []
    p = ProcessPoolExecutor(max_workers=num_proc)
    for path in cache_dict_paths:
        res.append(p.submit(datasets.load_from_disk,
                            str(path), **kwargs))

    p.shutdown(wait=True)
    for future in res:
        ds.append(future.result())
        # print(future.result())
    train = []
    test = []
    validation = []
    for ds_ in ds:
        train.append(ds_['train'])
        test.append(ds_['test'])
        validation.append(ds_['validation'])
    # ds = datasets.concatenate_datasets(ds)
    # print(ds)
    return datasets.DatasetDict({
        'train': datasets.concatenate_datasets(train),
        'test': datasets.concatenate_datasets(test),
        'validation': datasets.concatenate_datasets(validation)
    })


if __name__ == '__main__':
    dataset = CommonTextCorpusGenerate(_SPLIT_DATA_PATH, num_proc=16)
    dataset.generate_cache_arrow()
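
    # Minimal usage sketch (assumption: generate_cache_arrow() above has already
    # written the per-file DatasetDicts into _CACHE_SPLIT_DATA_PATH). load_dataset()
    # reads every cached shard back and concatenates the three splits:
    #
    #   ds = load_dataset(num_proc=8)
    #   print(ds)  # DatasetDict with 'train', 'test' and 'validation' splits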