
Commit 3ffd0a8

Merge branch 'main' into add_new_machine_type
2 parents: 8fa928f + ea8d1e9

File tree: 2 files changed, +353 -363 lines changed


dask/groupby-dask.py (+353 -3)
The diff (@@ -1,6 +1,356 @@) replaces the file's previous body, which only deferred to a second script:

#!/usr/bin/env python3

# we put the whole test code inside a main guard because processes started by the
# distributed client would execute this file as well and raise errors
if __name__ == "__main__":
    exec(open("./dask/groupby-dask2.py").read())

with a complete, self-contained implementation:

#!/usr/bin/env python3

import os
import gc
import sys
import timeit
import pandas as pd
import dask as dk
import dask.dataframe as dd
import logging
from abc import ABC, abstractmethod
from dask import distributed
from typing import Any

exec(open("./_helpers/helpers.py").read())

logging.basicConfig(
    level=logging.INFO,
    format='{ %(name)s:%(lineno)d @ %(asctime)s } - %(message)s'
)
logger = logging.getLogger(__name__)

# TODO: Case
ver = dk.__version__
git = dk.__git_revision__
task = "groupby"
solution = "dask"
fun = ".groupby"
cache = "TRUE"
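
The exec(open("./_helpers/helpers.py").read()) call injects the benchmark's shared utilities into this module's namespace; at least write_log, make_chk, and memory_usage are referenced later in the file. A minimal sketch of stand-ins for running the script outside the benchmark repo, with signatures inferred from this file's call sites rather than from the real helpers:

# Hypothetical stand-ins for names injected by _helpers/helpers.py.
# Signatures are inferred from this file's call sites, not from the real helpers.
import resource  # Unix-only; used here to approximate peak memory

def memory_usage() -> float:
    # peak resident set size of this process in GB (ru_maxrss is KB on Linux)
    return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / (1024 ** 2)

def make_chk(values) -> str:
    # collapse the per-query check values into one loggable string
    return ";".join(str(v) for v in values)

def write_log(**kwargs) -> None:
    # the real helper appends a row to the benchmark log; printing is enough for a sketch
    print(", ".join("%s=%s" % item for item in kwargs.items()))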

def dask_client() -> distributed.Client:
    # we use process-pool instead of thread-pool due to GIL cost
    return distributed.Client(processes=True, silence_logs=logging.ERROR)

def load_dataset(src_grp: str) -> dd.DataFrame:
    logger.info("Loading dataset %s" % data_name)
    x = dd.read_csv(
        src_grp,
        dtype={"id1": "category", "id2": "category", "id3": "category",
               "id4": "Int32", "id5": "Int32", "id6": "Int32",
               "v1": "Int32", "v2": "Int32", "v3": "float64"},
        engine="pyarrow"
    )
    x = x.persist()
    return x
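
The same client-and-persist pattern, reduced to a self-contained toy example (the data here is illustrative, not the benchmark's): a process-based client sidesteps the GIL for the pandas work inside each partition, and persist() materializes the partitions on the workers so later queries do not re-read the input. Note the main guard, which is exactly what the removed stub's comment warned about.

import logging
import pandas as pd
import dask.dataframe as dd
from dask import distributed

if __name__ == "__main__":  # required: worker processes re-import this module
    client = distributed.Client(processes=True, silence_logs=logging.ERROR)
    pdf = pd.DataFrame({"id1": pd.Categorical(["a", "b", "a", "b"]),
                        "v1": [1, 2, 3, 4]})
    ddf = dd.from_pandas(pdf, npartitions=2).persist()  # keep partitions in worker memory
    print(ddf.groupby("id1", observed=True).v1.sum().compute())
    client.close()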

class Query(ABC):
    @staticmethod
    @abstractmethod
    def query(x: dd.DataFrame) -> dd.DataFrame:
        pass

    @staticmethod
    @abstractmethod
    def check(ans: dd.DataFrame) -> Any:
        pass
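
Each of the ten benchmark questions below implements this contract: query() runs the grouped aggregation and materializes it with .compute(), and check() reduces the answer to a few scalars for result verification. A hypothetical extra query, shown only to illustrate the pattern:

# Hypothetical example, not part of the commit: max of v3 per id1.
class QueryMaxV3(Query):
    @staticmethod
    def query(x: dd.DataFrame) -> dd.DataFrame:
        ans = x.groupby('id1', dropna=False, observed=True).agg({'v3': 'max'}).compute()
        ans.reset_index(inplace=True)
        return ans

    @staticmethod
    def check(ans: dd.DataFrame) -> Any:
        return [ans.v3.sum()]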

class QueryOne(Query):
    @staticmethod
    def query(x: dd.DataFrame) -> dd.DataFrame:
        ans = x.groupby('id1', dropna=False, observed=True).agg({'v1': 'sum'}).compute()
        ans.reset_index(inplace=True)  # #68
        return ans

    @staticmethod
    def check(ans: dd.DataFrame) -> Any:
        return [ans.v1.sum()]

class QueryTwo(Query):
    @staticmethod
    def query(x: dd.DataFrame) -> dd.DataFrame:
        ans = x.groupby(['id1', 'id2'], dropna=False, observed=True).agg({'v1': 'sum'}).compute()
        ans.reset_index(inplace=True)  # #68
        return ans

    @staticmethod
    def check(ans: dd.DataFrame) -> Any:
        return [ans.v1.sum()]

class QueryThree(Query):
    @staticmethod
    def query(x: dd.DataFrame) -> dd.DataFrame:
        ans = x.groupby('id3', dropna=False, observed=True).agg({'v1': 'sum', 'v3': 'mean'}).compute()
        ans.reset_index(inplace=True)  # #68
        return ans

    @staticmethod
    def check(ans: dd.DataFrame) -> Any:
        return [ans.v1.sum(), ans.v3.sum()]

class QueryFour(Query):
    @staticmethod
    def query(x: dd.DataFrame) -> dd.DataFrame:
        ans = x.groupby('id4', dropna=False, observed=True).agg({'v1': 'mean', 'v2': 'mean', 'v3': 'mean'}).compute()
        ans.reset_index(inplace=True)  # #68
        return ans

    @staticmethod
    def check(ans: dd.DataFrame) -> Any:
        return [ans.v1.sum(), ans.v2.sum(), ans.v3.sum()]

class QueryFive(Query):
    @staticmethod
    def query(x: dd.DataFrame) -> dd.DataFrame:
        ans = x.groupby('id6', dropna=False, observed=True).agg({'v1': 'sum', 'v2': 'sum', 'v3': 'sum'}).compute()
        ans.reset_index(inplace=True)  # #68
        return ans

    @staticmethod
    def check(ans: dd.DataFrame) -> Any:
        return [ans.v1.sum(), ans.v2.sum(), ans.v3.sum()]

class QuerySix(Query):
    @staticmethod
    def query(x: dd.DataFrame) -> dd.DataFrame:
        ans = x.groupby(['id4', 'id5'], dropna=False, observed=True).agg({'v3': ['median', 'std']}, shuffle='p2p').compute()
        ans.reset_index(inplace=True)  # #68
        return ans

    @staticmethod
    def check(ans: dd.DataFrame) -> Any:
        return [ans['v3']['median'].sum(), ans['v3']['std'].sum()]

class QuerySeven(Query):
    @staticmethod
    def query(x: dd.DataFrame) -> dd.DataFrame:
        ans = x.groupby('id3', dropna=False, observed=True).agg({'v1': 'max', 'v2': 'min'}).assign(range_v1_v2=lambda x: x['v1'] - x['v2'])[['range_v1_v2']].compute()
        ans.reset_index(inplace=True)  # #68
        return ans

    @staticmethod
    def check(ans: dd.DataFrame) -> Any:
        return [ans['range_v1_v2'].sum()]

class QueryEight(Query):
    @staticmethod
    def query(x: dd.DataFrame) -> dd.DataFrame:
        ans = x[~x['v3'].isna()][['id6', 'v3']].groupby('id6', dropna=False, observed=True).apply(lambda x: x.nlargest(2, columns='v3'), meta={'id6': 'Int64', 'v3': 'float64'})[['v3']].compute()
        ans.reset_index(level='id6', inplace=True)
        ans.reset_index(drop=True, inplace=True)  # drop because nlargest creates some extra new index field
        return ans

    @staticmethod
    def check(ans: dd.DataFrame) -> Any:
        return [ans['v3'].sum()]

class QueryNine(Query):
    @staticmethod
    def query(x: dd.DataFrame) -> dd.DataFrame:
        ans = x[['id2', 'id4', 'v1', 'v2']].groupby(['id2', 'id4'], dropna=False, observed=True).apply(lambda x: pd.Series({'r2': x.corr()['v1']['v2'] ** 2}), meta={'r2': 'float64'}).compute()
        ans.reset_index(inplace=True)
        return ans

    @staticmethod
    def check(ans: dd.DataFrame) -> Any:
        return [ans['r2'].sum()]

class QueryTen(Query):
    @staticmethod
    def query(x: dd.DataFrame) -> dd.DataFrame:
        ans = (
            x.groupby(
                ['id1', 'id2', 'id3', 'id4', 'id5', 'id6'],
                dropna=False,
                observed=True,
            )
            .agg({'v3': 'sum', 'v1': 'size'}, split_out=x.npartitions)
            .rename(columns={"v1": "count"})
            .compute()
        )
        ans.reset_index(inplace=True)
        return ans

    @staticmethod
    def check(ans: dd.DataFrame) -> Any:
        return [ans.v3.sum(), ans["count"].sum()]
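
Because every query() ends in .compute(), the returned object is a plain pandas DataFrame despite the dd.DataFrame annotations, so any query can be cross-checked against pandas on a toy frame. A sketch under that assumption (toy data, illustrative only):

# Illustrative correctness check for QueryOne on toy data (not part of the benchmark).
import pandas as pd
import dask.dataframe as dd

pdf = pd.DataFrame({"id1": pd.Categorical(["a", "b", "a", "b"]),
                    "v1": pd.array([1, 2, 3, 4], dtype="Int32")})
expected = pdf.groupby("id1", dropna=False, observed=True).agg({"v1": "sum"}).reset_index()
got = QueryOne.query(dd.from_pandas(pdf, npartitions=2))
assert got.v1.sum() == expected.v1.sum() == 10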

def run_query(
    data_name: str,
    in_rows: int,
    x: dd.DataFrame,
    query: Query,
    question: str,
    machine_type: str,
    runs: int = 2,
):
    logger.info("Running query: '%s'" % question)
    try:
        for run in range(1, runs + 1):
            gc.collect()  # TODO: Able to do this in worker processes? Want to?

            # Calculate ans
            t_start = timeit.default_timer()
            ans = query.query(x)
            logger.debug("Answer shape: %s" % (ans.shape, ))
            t = timeit.default_timer() - t_start
            m = memory_usage()

            # Calculate chk
            t_start = timeit.default_timer()
            chk = query.check(ans)
            chkt = timeit.default_timer() - t_start

            write_log(
                task=task,
                data=data_name,
                in_rows=in_rows,
                question=question,
                out_rows=ans.shape[0],
                out_cols=ans.shape[1],
                solution=solution,
                version=ver,
                git=git,
                fun=fun,
                run=run,
                time_sec=t,
                mem_gb=m,
                cache=cache,
                chk=make_chk(chk),
                chk_time_sec=chkt,
                on_disk=on_disk,
                machine_type=machine_type
            )
            if run == runs:
                # Print head / tail on last run
                logger.debug("Answer head:\n%s" % ans.head(3))
                logger.debug("Answer tail:\n%s" % ans.tail(3))
            del ans
    except Exception as err:
        logger.error("Query '%s' failed!" % question)
        print(err)

def run_task(
    data_name: str,
    src_grp: str,
    machine_type: str
):
    client = dask_client()
    x = load_dataset(src_grp)
    in_rows = len(x)
    logger.info("Input dataset rows: %s" % in_rows)

    task_init = timeit.default_timer()
    logger.info("Grouping...")

    run_query(
        data_name=data_name,
        in_rows=in_rows,
        x=x,
        query=QueryOne,
        question="sum v1 by id1",  # q1
        machine_type=machine_type,
    )

    run_query(
        data_name=data_name,
        in_rows=in_rows,
        x=x,
        query=QueryTwo,
        question="sum v1 by id1:id2",  # q2
        machine_type=machine_type,
    )

    run_query(
        data_name=data_name,
        in_rows=in_rows,
        x=x,
        query=QueryThree,
        question="sum v1 mean v3 by id3",  # q3
        machine_type=machine_type,
    )

    run_query(
        data_name=data_name,
        in_rows=in_rows,
        x=x,
        query=QueryFour,
        question="mean v1:v3 by id4",  # q4
        machine_type=machine_type,
    )

    run_query(
        data_name=data_name,
        in_rows=in_rows,
        x=x,
        query=QueryFive,
        question="sum v1:v3 by id6",  # q5
        machine_type=machine_type,
    )

    run_query(
        data_name=data_name,
        in_rows=in_rows,
        x=x,
        query=QuerySix,
        question="median v3 sd v3 by id4 id5",  # q6
        machine_type=machine_type,
    )

    run_query(
        data_name=data_name,
        in_rows=in_rows,
        x=x,
        query=QuerySeven,
        question="max v1 - min v2 by id3",  # q7
        machine_type=machine_type,
    )

    run_query(
        data_name=data_name,
        in_rows=in_rows,
        x=x,
        query=QueryEight,
        question="largest two v3 by id6",  # q8
        machine_type=machine_type,
    )

    run_query(
        data_name=data_name,
        in_rows=in_rows,
        x=x,
        query=QueryNine,
        question="regression v1 v2 by id2 id4",  # q9
        machine_type=machine_type,
    )

    run_query(
        data_name=data_name,
        in_rows=in_rows,
        x=x,
        query=QueryTen,
        question="sum v3 count by id1:id6",  # q10
        machine_type=machine_type,
    )

    logger.info("Grouping finished, took %0.fs" % (timeit.default_timer() - task_init))
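
For a local smoke test, the entry point below can be driven by exporting the two environment variables it expects; a hypothetical invocation in which the dataset name, machine type, and CSV location are all placeholders:

# Hypothetical smoke test; dataset name and machine type are placeholders.
# Equivalent shell: SRC_DATANAME=G1_1e7_1e2_0_0 MACHINE_TYPE=local python dask/groupby-dask.py
import os
os.environ["SRC_DATANAME"] = "G1_1e7_1e2_0_0"  # assumes data/G1_1e7_1e2_0_0.csv exists
os.environ["MACHINE_TYPE"] = "local"           # any value other than "c6id.4xlarge" keeps on_disk False
exec(open("./dask/groupby-dask.py").read())    # mirrors the repo's own exec-based launcher style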

if __name__ == '__main__':
    logger.info("# groupby-dask.py")
    data_name = os.environ['SRC_DATANAME']
    machine_type = os.environ['MACHINE_TYPE']
    on_disk = False  # data_name.split("_")[1] == "1e9"  # on-disk data storage #126
    on_disk = data_name.split("_")[1] == "1e9" and os.environ["MACHINE_TYPE"] == "c6id.4xlarge"
    fext = "parquet" if on_disk else "csv"
    src_grp = os.path.join("data", data_name + "." + fext)

    na_flag = int(data_name.split("_")[3])
    if na_flag > 0:
        logger.error("Skip due to na_flag>0: #171")
        exit(0)  # not yet implemented #171, currently groupby's dropna=False argument is ignored

    run_task(
        data_name=data_name,
        src_grp=src_grp,
        machine_type=machine_type
    )
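
The SRC_DATANAME value is parsed positionally by the block above: field 1 is the row count and field 3 is the NA flag. For a hypothetical name "G1_1e9_1e2_0_0":

parts = "G1_1e9_1e2_0_0".split("_")  # hypothetical dataset name
parts[1]       # "1e9": row count; with MACHINE_TYPE "c6id.4xlarge" it selects the on-disk parquet path
int(parts[3])  # 0: na_flag; any value > 0 makes the script log an error and exit early (#171)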
