#!/usr/bin/env python3

-# we keep the whole test body inside the main guard because worker processes started by the distributed client would import this file again and raise errors
-if __name__ == "__main__":
-    exec(open("./dask/groupby-dask2.py").read())
+import os
+import gc
+import sys
+import timeit
+import pandas as pd
+import dask as dk
+import dask.dataframe as dd
+import logging
+from abc import ABC, abstractmethod
+from dask import distributed
+from typing import Any

+exec(open("./_helpers/helpers.py").read())
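+# the helpers file is expected to define the write_log, make_chk and
+# memory_usage utilities used below; exec'ing it keeps them in this namespace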
+
+logging.basicConfig(
+    level=logging.INFO,
+    format='{ %(name)s:%(lineno)d @ %(asctime)s } - %(message)s'
+)
+logger = logging.getLogger(__name__)
+
+# TODO: Case
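+# run metadata recorded in the results log with every timing row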
+ver = dk.__version__
+git = dk.__git_revision__
+task = "groupby"
+solution = "dask"
+fun = ".groupby"
+cache = "TRUE"
+
+def dask_client() -> distributed.Client:
+    # we use a process pool instead of a thread pool due to the GIL cost
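+    # NB: worker processes spawned here re-import this file (the removed
+    # wrapper made the same point), hence the __main__ guard at the bottom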
+    return distributed.Client(processes=True, silence_logs=logging.ERROR)
+
+def load_dataset(src_grp: str) -> dd.DataFrame:
+    logger.info("Loading dataset %s" % data_name)
+    x = dd.read_csv(
+        src_grp,
+        dtype={"id1": "category", "id2": "category", "id3": "category", "id4": "Int32",
+               "id5": "Int32", "id6": "Int32", "v1": "Int32", "v2": "Int32", "v3": "float64"},
+        engine="pyarrow"
+    )
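+    # persist materializes the partitions in cluster memory up front, so the
+    # per-query timings below exclude CSV parsing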
+    x = x.persist()
+    return x
+
+class Query(ABC):
+    @staticmethod
+    @abstractmethod
+    def query(x: dd.DataFrame) -> pd.DataFrame:
+        pass
+
+    @staticmethod
+    @abstractmethod
+    def check(ans: pd.DataFrame) -> Any:
+        pass
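+
+# each benchmark question below is a Query subclass: query() evaluates the
+# grouped aggregation eagerly (.compute() returns a pandas DataFrame, hence
+# the pd.DataFrame annotations) and check() reduces the answer to checksums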
55
+
56
+ class QueryOne (Query ):
57
+ @staticmethod
58
+ def query (x : dd .DataFrame ) -> dd .DataFrame :
59
+ ans = x .groupby ('id1' , dropna = False , observed = True ).agg ({'v1' :'sum' }).compute ()
60
+ ans .reset_index (inplace = True ) # #68
61
+ return ans
62
+
63
+ @staticmethod
64
+ def check (ans : dd .DataFrame ) -> Any :
65
+ return [ans .v1 .sum ()]
66
+
+
+class QueryTwo(Query):
+    @staticmethod
+    def query(x: dd.DataFrame) -> pd.DataFrame:
+        ans = x.groupby(['id1', 'id2'], dropna=False, observed=True).agg({'v1': 'sum'}).compute()
+        ans.reset_index(inplace=True)  # #68
+        return ans
+
+    @staticmethod
+    def check(ans: pd.DataFrame) -> Any:
+        return [ans.v1.sum()]
+
+class QueryThree(Query):
+    @staticmethod
+    def query(x: dd.DataFrame) -> pd.DataFrame:
+        ans = x.groupby('id3', dropna=False, observed=True).agg({'v1': 'sum', 'v3': 'mean'}).compute()
+        ans.reset_index(inplace=True)  # #68
+        return ans
+
+    @staticmethod
+    def check(ans: pd.DataFrame) -> Any:
+        return [ans.v1.sum(), ans.v3.sum()]
+
+class QueryFour(Query):
+    @staticmethod
+    def query(x: dd.DataFrame) -> pd.DataFrame:
+        ans = x.groupby('id4', dropna=False, observed=True).agg({'v1': 'mean', 'v2': 'mean', 'v3': 'mean'}).compute()
+        ans.reset_index(inplace=True)  # #68
+        return ans
+
+    @staticmethod
+    def check(ans: pd.DataFrame) -> Any:
+        return [ans.v1.sum(), ans.v2.sum(), ans.v3.sum()]
+
+class QueryFive(Query):
+    @staticmethod
+    def query(x: dd.DataFrame) -> pd.DataFrame:
+        ans = x.groupby('id6', dropna=False, observed=True).agg({'v1': 'sum', 'v2': 'sum', 'v3': 'sum'}).compute()
+        ans.reset_index(inplace=True)  # #68
+        return ans
+
+    @staticmethod
+    def check(ans: pd.DataFrame) -> Any:
+        return [ans.v1.sum(), ans.v2.sum(), ans.v3.sum()]
+
+class QuerySix(Query):
+    @staticmethod
+    def query(x: dd.DataFrame) -> pd.DataFrame:
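+        # median is not an associative reduction, so dask needs a full data
+        # shuffle here; 'p2p' selects the peer-to-peer shuffle implementation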
+        ans = x.groupby(['id4', 'id5'], dropna=False, observed=True).agg({'v3': ['median', 'std']}, shuffle='p2p').compute()
+        ans.reset_index(inplace=True)  # #68
+        return ans
+
+    @staticmethod
+    def check(ans: pd.DataFrame) -> Any:
+        return [ans['v3']['median'].sum(), ans['v3']['std'].sum()]
+
+class QuerySeven(Query):
+    @staticmethod
+    def query(x: dd.DataFrame) -> pd.DataFrame:
+        ans = x.groupby('id3', dropna=False, observed=True).agg({'v1': 'max', 'v2': 'min'}).assign(range_v1_v2=lambda x: x['v1'] - x['v2'])[['range_v1_v2']].compute()
+        ans.reset_index(inplace=True)  # #68
+        return ans
+
+    @staticmethod
+    def check(ans: pd.DataFrame) -> Any:
+        return [ans['range_v1_v2'].sum()]
+
+class QueryEight(Query):
+    @staticmethod
+    def query(x: dd.DataFrame) -> pd.DataFrame:
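+        # groupby().apply() requires an explicit meta so dask can infer the
+        # output schema without eagerly executing the lambda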
+        ans = x[~x['v3'].isna()][['id6', 'v3']].groupby('id6', dropna=False, observed=True).apply(lambda x: x.nlargest(2, columns='v3'), meta={'id6': 'Int64', 'v3': 'float64'})[['v3']].compute()
+        ans.reset_index(level='id6', inplace=True)
+        ans.reset_index(drop=True, inplace=True)  # drop because nlargest creates an extra index level
+        return ans
+
+    @staticmethod
+    def check(ans: pd.DataFrame) -> Any:
+        return [ans['v3'].sum()]
+
+class QueryNine(Query):
+    @staticmethod
+    def query(x: dd.DataFrame) -> pd.DataFrame:
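+        # for a simple linear regression the r^2 statistic equals the squared
+        # Pearson correlation between v1 and v2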
+        ans = x[['id2', 'id4', 'v1', 'v2']].groupby(['id2', 'id4'], dropna=False, observed=True).apply(lambda x: pd.Series({'r2': x.corr()['v1']['v2'] ** 2}), meta={'r2': 'float64'}).compute()
+        ans.reset_index(inplace=True)
+        return ans
+
+    @staticmethod
+    def check(ans: pd.DataFrame) -> Any:
+        return [ans['r2'].sum()]
+
+class QueryTen(Query):
+    @staticmethod
+    def query(x: dd.DataFrame) -> pd.DataFrame:
+        ans = (
+            x.groupby(
+                ['id1', 'id2', 'id3', 'id4', 'id5', 'id6'],
+                dropna=False,
+                observed=True,
+            )
+            .agg({'v3': 'sum', 'v1': 'size'}, split_out=x.npartitions)
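+            # split_out spreads the high-cardinality result over as many
+            # partitions as the input instead of collapsing it into one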
+            .rename(columns={"v1": "count"})
+            .compute()
+        )
+        ans.reset_index(inplace=True)
+        return ans
+
+    @staticmethod
+    def check(ans: pd.DataFrame) -> Any:
+        return [ans.v3.sum(), ans["count"].sum()]
+
+def run_query(
+    data_name: str,
+    in_rows: int,
+    x: dd.DataFrame,
+    query: type[Query],
+    question: str,
+    machine_type: str,
+    runs: int = 2,
+):
+    logger.info("Running query: '%s'" % question)
+    try:
+        for run in range(1, runs + 1):
+            gc.collect()  # TODO: Able to do this in worker processes? Want to?
+
+            # Calculate ans
+            t_start = timeit.default_timer()
+            ans = query.query(x)
+            logger.debug("Answer shape: %s" % (ans.shape,))
+            t = timeit.default_timer() - t_start
+            m = memory_usage()
+
+            # Calculate chk
+            t_start = timeit.default_timer()
+            chk = query.check(ans)
+            chkt = timeit.default_timer() - t_start
+
+            write_log(
+                task=task,
+                data=data_name,
+                in_rows=in_rows,
+                question=question,
+                out_rows=ans.shape[0],
+                out_cols=ans.shape[1],
+                solution=solution,
+                version=ver,
+                git=git,
+                fun=fun,
+                run=run,
+                time_sec=t,
+                mem_gb=m,
+                cache=cache,
+                chk=make_chk(chk),
+                chk_time_sec=chkt,
+                on_disk=on_disk,
+                machine_type=machine_type
+            )
+            if run == runs:
+                # Print head / tail on last run
+                logger.debug("Answer head:\n%s" % ans.head(3))
+                logger.debug("Answer tail:\n%s" % ans.tail(3))
+            del ans
+    except Exception as err:
+        logger.error("Query '%s' failed!" % question)
+        print(err)
+
+def run_task(
+    data_name: str,
+    src_grp: str,
+    machine_type: str
+):
+    client = dask_client()
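+    # the Client registers itself as the default scheduler, so subsequent
+    # .persist()/.compute() calls run on the process-pool cluster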
+    x = load_dataset(src_grp)
+    in_rows = len(x)
+    logger.info("Input dataset rows: %s" % in_rows)
+
+    task_init = timeit.default_timer()
+    logger.info("Grouping...")
+
+    run_query(
+        data_name=data_name,
+        in_rows=in_rows,
+        x=x,
+        query=QueryOne,
+        question="sum v1 by id1",  # q1
+        machine_type=machine_type,
+    )
+
+    run_query(
+        data_name=data_name,
+        in_rows=in_rows,
+        x=x,
+        query=QueryTwo,
+        question="sum v1 by id1:id2",  # q2
+        machine_type=machine_type,
+    )
+
+    run_query(
+        data_name=data_name,
+        in_rows=in_rows,
+        x=x,
+        query=QueryThree,
+        question="sum v1 mean v3 by id3",  # q3
+        machine_type=machine_type,
+    )
+
+    run_query(
+        data_name=data_name,
+        in_rows=in_rows,
+        x=x,
+        query=QueryFour,
+        question="mean v1:v3 by id4",  # q4
+        machine_type=machine_type,
+    )
+
+    run_query(
+        data_name=data_name,
+        in_rows=in_rows,
+        x=x,
+        query=QueryFive,
+        question="sum v1:v3 by id6",  # q5
+        machine_type=machine_type,
+    )
+
+    run_query(
+        data_name=data_name,
+        in_rows=in_rows,
+        x=x,
+        query=QuerySix,
+        question="median v3 sd v3 by id4 id5",  # q6
+        machine_type=machine_type,
+    )
+
+    run_query(
+        data_name=data_name,
+        in_rows=in_rows,
+        x=x,
+        query=QuerySeven,
+        question="max v1 - min v2 by id3",  # q7
+        machine_type=machine_type,
+    )
+
+    run_query(
+        data_name=data_name,
+        in_rows=in_rows,
+        x=x,
+        query=QueryEight,
+        question="largest two v3 by id6",  # q8
+        machine_type=machine_type,
+    )
+
+    run_query(
+        data_name=data_name,
+        in_rows=in_rows,
+        x=x,
+        query=QueryNine,
+        question="regression v1 v2 by id2 id4",  # q9
+        machine_type=machine_type,
+    )
+
+    run_query(
+        data_name=data_name,
+        in_rows=in_rows,
+        x=x,
+        query=QueryTen,
+        question="sum v3 count by id1:id6",  # q10
+        machine_type=machine_type,
+    )
+
+    logger.info("Grouping finished, took %0.fs" % (timeit.default_timer() - task_init))
+
+if __name__ == '__main__':
+    logger.info("# groupby-dask.py")
+    data_name = os.environ['SRC_DATANAME']
+    machine_type = os.environ['MACHINE_TYPE']
+    # on-disk data storage #126: only used for the 1e9 data on c6id.4xlarge
+    on_disk = data_name.split("_")[1] == "1e9" and machine_type == "c6id.4xlarge"
+    fext = "parquet" if on_disk else "csv"
+    src_grp = os.path.join("data", data_name + "." + fext)
+
+    na_flag = int(data_name.split("_")[3])
+    if na_flag > 0:
+        logger.error("Skip due to na_flag>0: #171")
+        sys.exit(0)  # not yet implemented #171; currently groupby's dropna=False argument is ignored
+
+    run_task(
+        data_name=data_name,
+        src_grp=src_grp,
+        machine_type=machine_type
+    )
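+
+# hypothetical invocation (variable values are illustrative only; the
+# benchmark harness normally sets these environment variables):
+#   SRC_DATANAME=G1_1e7_1e2_0_0 MACHINE_TYPE=c6id.4xlarge python groupby-dask.py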