Commit a094b9c

Merge pull request #501 from uwescience/bmyerz/partition-optimize-groupbys

Optimize groupby and distinct according to partitioning

2 parents: 6f76965 + ca6513f
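In outline: a Shuffle in front of a GroupBy or Distinct can be elided when its input is already hash-partitioned on the grouping columns, and the resulting un-shuffled, un-decomposed operator can then be pushed down into SQL. Below is a minimal sketch of the partition check only, using stand-in classes rather than Raco's real operator API (the actual helper is check_partition_equality, now imported from raco.rules):

# Sketch only: stand-in classes, not Raco's real operator interfaces.
from collections import namedtuple

Partitioning = namedtuple("Partitioning", ["hash_partitioned"])


class Op(object):
    def __init__(self, hash_cols):
        self._part = Partitioning(hash_partitioned=frozenset(hash_cols))

    def partitioning(self):
        return self._part


def check_partition_equality(op, representation):
    # Same logic as the helper moved by this commit: the operator must be
    # hash-partitioned on exactly the given columns.
    return op.partitioning().hash_partitioned == frozenset(representation)


scan = Op(hash_cols=[0])                    # partitioned on column $0
print(check_partition_equality(scan, [0]))  # True: shuffle can be elided
print(check_partition_equality(scan, [1]))  # False: shuffle still needed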

File tree: 4 files changed (+269, -33 lines)


raco/backends/myria/myria.py (+16, -23)
@@ -1,23 +1,23 @@
 import itertools
 import logging
 from collections import defaultdict, deque
+from functools import reduce
 from operator import mul

 from sqlalchemy.dialects import postgresql

 from raco import algebra, expression, rules, scheme
-from raco.algebra import convertcondition
+from raco import types
 from raco.algebra import Shuffle
-from raco.catalog import Catalog
-from raco.representation import RepresentationProperties
+from raco.algebra import convertcondition
 from raco.backends import Language, Algebra
-from raco.backends.sql.catalog import SQLCatalog
-from raco.expression import WORKERID, COUNTALL
-from raco.expression import UnnamedAttributeRef
+from raco.backends.sql.catalog import SQLCatalog, PostgresSQLFunctionProvider
+from raco.catalog import Catalog
 from raco.datastructure.UnionFind import UnionFind
-from raco import types
-from raco.rules import distributed_group_by
-from functools import reduce
+from raco.expression import UnnamedAttributeRef
+from raco.expression import WORKERID, COUNTALL
+from raco.representation import RepresentationProperties
+from raco.rules import distributed_group_by, check_partition_equality

 LOGGER = logging.getLogger(__name__)

@@ -1133,17 +1133,6 @@ def fire(self, exp):
         return exp


-def check_partition_equality(op, representation):
-    """Check to see if the operator has the required hash partitioning.
-    @param op operator
-    @param representation list of columns hash partitioned by,
-    in the unnamed perspective
-    @return true if the op has an equal hash partitioning to representation
-    """
-
-    return op.partitioning().hash_partitioned == frozenset(representation)
-
-
 class ShuffleBeforeSetop(rules.Rule):

     def fire(self, exp):

@@ -1481,14 +1470,16 @@ def fire(self, op):

 class PushIntoSQL(rules.Rule):

-    def __init__(self, dialect=None):
+    def __init__(self, dialect=None, push_grouping=False):
         self.dialect = dialect or postgresql.dialect()
+        self.push_grouping = push_grouping
         super(PushIntoSQL, self).__init__()

     def fire(self, expr):
         if isinstance(expr, (algebra.Scan, algebra.ScanTemp)):
             return expr
-        cat = SQLCatalog()
+        cat = SQLCatalog(provider=PostgresSQLFunctionProvider(),
+                         push_grouping=self.push_grouping)
         try:
             sql_plan = cat.get_sql(expr)
             sql_string = sql_plan.compile(dialect=self.dialect)

@@ -1713,7 +1704,9 @@ def opt_rules(self, **kwargs):

         if kwargs.get('push_sql', False):
             opt_grps_sequence.append([
-                PushIntoSQL(dialect=kwargs.get('dialect'))])
+                PushIntoSQL(dialect=kwargs.get('dialect'),
+                            push_grouping=kwargs.get(
+                                'push_sql_grouping', False))])

         compile_grps_sequence = [
             myriafy,
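The new push_sql_grouping option travels from opt_rules(**kwargs) into the PushIntoSQL rule, which hands it on to SQLCatalog. A hedged sketch of just that flag plumbing, with stand-in classes rather than the full Raco rule machinery:

# Stand-ins mirroring only the wiring shown in the diff above.
class PushIntoSQL(object):
    def __init__(self, dialect=None, push_grouping=False):
        self.dialect = dialect
        self.push_grouping = push_grouping


def opt_rules(**kwargs):
    opt_grps_sequence = []
    if kwargs.get('push_sql', False):
        # push_sql_grouping defaults to False, as in the diff
        opt_grps_sequence.append([
            PushIntoSQL(dialect=kwargs.get('dialect'),
                        push_grouping=kwargs.get('push_sql_grouping',
                                                 False))])
    return opt_grps_sequence


rules_seq = opt_rules(push_sql=True, push_sql_grouping=True)
assert rules_seq[0][0].push_grouping is True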

raco/backends/sql/catalog.py (+51, -7)
@@ -12,6 +12,7 @@
 import raco.scheme as scheme
 import raco.types as types
 from raco.representation import RepresentationProperties
+import abc


 type_to_raco = {Integer: types.LONG_TYPE,

@@ -30,9 +31,46 @@
                 types.DATETIME_TYPE: DateTime}


+class SQLFunctionProvider(object):
+    """Interface for translating function names. For Raco functions
+    not understood by SQLAlchemy, like stdev, we cannot rely
+    on SQLAlchemy's compiler to translate function
+    names to the given dialect.
+    For functions not understood by SQLAlchemy, the SQLAlchemy compiler
+    just emits them verbatim."""
+
+    __metaclass__ = abc.ABCMeta
+
+    @abc.abstractmethod
+    def convert_unary_expr(self, expr, input):
+        pass
+
+
+class _DefaultSQLFunctionProvider(SQLFunctionProvider):
+    def convert_unary_expr(self, expr, input):
+        # just use the function name without complaining
+        fname = expr.__class__.__name__.lower()
+        return getattr(func, fname)(input)
+
+
+class PostgresSQLFunctionProvider(SQLFunctionProvider):
+    def convert_unary_expr(self, expr, input):
+        fname = expr.__class__.__name__.lower()
+
+        # replacements
+        if fname == "stdev":
+            return func.stddev_samp(input)
+
+        # Warning: may create some functions not available in Postgres
+        return getattr(func, fname)(input)
+
+
 class SQLCatalog(Catalog):
-    def __init__(self, engine=None):
+    def __init__(self, engine=None, push_grouping=False,
+                 provider=_DefaultSQLFunctionProvider()):
         self.engine = engine
+        self.push_grouping = push_grouping
+        self.provider = provider
         self.metadata = MetaData()

     @staticmethod
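The provider classes exist because SQLAlchemy's compiler emits unknown function names verbatim, so any dialect-specific rename (Raco's stdev becoming Postgres's stddev_samp) has to happen before compilation. A self-contained sketch of the same name-mapping idea; PostgresNameProvider and the bare STDEV/MAX classes here are illustrative stand-ins, not Raco's types:

class STDEV(object):    # stand-in for raco.expression.aggregate.STDEV
    pass


class MAX(object):      # stand-in for raco.expression.MAX
    pass


class PostgresNameProvider(object):
    # Raco name -> Postgres name, for functions SQLAlchemy would
    # otherwise pass through unchanged.
    replacements = {"stdev": "stddev_samp"}

    def sql_name(self, expr):
        fname = expr.__class__.__name__.lower()
        return self.replacements.get(fname, fname)


provider = PostgresNameProvider()
print(provider.sql_name(STDEV()))  # stddev_samp
print(provider.sql_name(MAX()))    # max (passed through verbatim)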
@@ -106,10 +144,11 @@ def _convert_zeroary_expr(self, cols, expr, input_scheme):

     def _convert_unary_expr(self, cols, expr, input_scheme):
         input = self._convert_expr(cols, expr.input, input_scheme)
-        if isinstance(expr, expression.MAX):
-            return func.max(input)
-        if isinstance(expr, expression.MIN):
-            return func.min(input)
+
+        c = self.provider.convert_unary_expr(expr, input)
+        if c is not None:
+            return c
+
         raise NotImplementedError("expression {} to sql".format(type(expr)))

     def _convert_binary_expr(self, cols, expr, input_scheme):

@@ -157,9 +196,11 @@ def _get_unary_sql(self, plan):
             return select(clause, from_obj=input)

         elif isinstance(plan, algebra.GroupBy):
-            if len(plan.grouping_list) > 0:
+            if (not self.push_grouping) and len(plan.grouping_list) > 0:
                 raise NotImplementedError(
-                    "convert aggregate with grouping to sql -- Myria faster")
+                    """convert aggregate with grouping to sql
+                    -- Myria is faster. If you want to push group by into
+                    SQL use the flag push_sql_grouping""")
             a = [self._convert_expr(cols, e, input_sch)
                  for e in plan.aggregate_list]
             g = [self._convert_expr(cols, e, input_sch)

@@ -169,6 +210,9 @@ def _get_unary_sql(self, plan):
                 return sel
             return sel.group_by(*g)

+        elif isinstance(plan, algebra.Distinct):
+            return select(['*'], from_obj=input, distinct=True)
+
         raise NotImplementedError("convert {op} to sql".format(op=type(plan)))

     def _get_binary_sql(self, plan):
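For reference, the new Distinct branch builds a SELECT DISTINCT * over the input selectable. A small sketch of what that call produces, using the same legacy SQLAlchemy 1.x select() signature this file uses; the table definition is an assumption for illustration:

from sqlalchemy import Column, Integer, MetaData, Table, select

metadata = MetaData()
rel = Table('rel', metadata, Column('h', Integer), Column('i', Integer))

# Same call shape as the Distinct branch above (legacy SQLAlchemy 1.x API).
stmt = select(['*'], from_obj=rel, distinct=True)
print(stmt)  # SELECT DISTINCT * FROM rel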

raco/myrial/optimizer_tests.py (+172, -1)
@@ -1,14 +1,18 @@
 import collections
 import random
+import sys
+import re

 from raco.algebra import *
 from raco.expression import NamedAttributeRef as AttRef
 from raco.expression import UnnamedAttributeRef as AttIndex
 from raco.expression import StateVar
+from raco.expression import aggregate

 from raco.backends.myria import (
     MyriaShuffleConsumer, MyriaShuffleProducer, MyriaHyperShuffleProducer,
-    MyriaBroadcastConsumer, MyriaQueryScan, MyriaSplitConsumer)
+    MyriaBroadcastConsumer, MyriaQueryScan, MyriaSplitConsumer, MyriaDupElim,
+    MyriaGroupBy)
 from raco.backends.myria import (MyriaLeftDeepTreeAlgebra,
                                  MyriaHyperCubeAlgebra)
 from raco.compile import optimize

@@ -1003,3 +1007,170 @@ def test_projecting_join_maintains_partitioning(self):
         # (in general, info could be h($0) && h($2)
         self.assertEquals(pp.partitioning().hash_partitioned,
                           frozenset([AttIndex(0)]))
+
+    def test_no_shuffle_for_partitioned_distinct(self):
+        """Do not shuffle for Distinct if already partitioned"""
+
+        query = """
+        r = scan({part});
+        t = select distinct r.h from r;
+        store(t, OUTPUT);""".format(part=self.part_key)
+
+        lp = self.get_logical_plan(query)
+        pp = self.logical_to_physical(lp)
+
+        # shuffles should be removed and distinct not decomposed into two
+        self.assertEquals(self.get_count(pp, MyriaShuffleConsumer), 0)
+        self.assertEquals(self.get_count(pp, MyriaShuffleProducer), 0)
+        self.assertEquals(self.get_count(pp, MyriaDupElim), 1)
+
+        self.db.evaluate(pp)
+        result = self.db.get_table('OUTPUT')
+        expected = dict([((h,), 1) for _, h, _ in self.part_data])
+        self.assertEquals(result, expected)
+
+    def test_no_shuffle_for_partitioned_groupby(self):
+        """Do not shuffle for groupby if already partitioned"""
+
+        query = """
+        r = scan({part});
+        t = select r.h, MIN(r.i) from r;
+        store(t, OUTPUT);""".format(part=self.part_key)
+
+        lp = self.get_logical_plan(query)
+        pp = self.logical_to_physical(lp)
+
+        # shuffles should be removed and the groupby not decomposed into two
+        self.assertEquals(self.get_count(pp, MyriaShuffleConsumer), 0)
+        self.assertEquals(self.get_count(pp, MyriaShuffleProducer), 0)
+        self.assertEquals(self.get_count(pp, MyriaGroupBy), 1)
+
+    def test_partition_aware_groupby_into_sql(self):
+        """No shuffle for groupby also causes it to be pushed into sql"""
+
+        query = """
+        r = scan({part});
+        t = select r.h, MIN(r.i) from r;
+        store(t, OUTPUT);""".format(part=self.part_key)
+
+        lp = self.get_logical_plan(query)
+        pp = self.logical_to_physical(lp, push_sql=True,
+                                      push_sql_grouping=True)
+
+        # shuffles should be removed and the groupby not decomposed into two
+        self.assertEquals(self.get_count(pp, MyriaShuffleConsumer), 0)
+        self.assertEquals(self.get_count(pp, MyriaShuffleProducer), 0)
+
+        # should be pushed
+        self.assertEquals(self.get_count(pp, MyriaGroupBy), 0)
+        self.assertEquals(self.get_count(pp, MyriaQueryScan), 1)
+
+        self.db.evaluate(pp)
+        result = self.db.get_table('OUTPUT')
+        temp = dict([(h, sys.maxsize) for _, h, _ in self.part_data])
+        for _, h, i in self.part_data:
+            temp[h] = min(temp[h], i)
+        expected = dict(((h, i), 1) for h, i in temp.items())
+
+        self.assertEquals(result, expected)
+
+    def test_partition_aware_distinct_into_sql(self):
+        """No shuffle for distinct also causes it to be pushed into sql"""
+
+        query = """
+        r = scan({part});
+        t = select distinct r.h from r;
+        store(t, OUTPUT);""".format(part=self.part_key)
+
+        lp = self.get_logical_plan(query)
+        pp = self.logical_to_physical(lp, push_sql=True)
+
+        # shuffles should be removed and the groupby not decomposed into two
+        self.assertEquals(self.get_count(pp, MyriaShuffleConsumer), 0)
+        self.assertEquals(self.get_count(pp, MyriaShuffleProducer), 0)
+
+        # should be pushed
+        self.assertEquals(self.get_count(pp, MyriaGroupBy), 0)  # sanity
+        self.assertEquals(self.get_count(pp, MyriaDupElim), 0)
+        self.assertEquals(self.get_count(pp, MyriaQueryScan), 1)
+
+        self.db.evaluate(pp)
+        result = self.db.get_table('OUTPUT')
+        expected = dict([((h,), 1) for _, h, _ in self.part_data])
+        self.assertEquals(result, expected)
+
+    def test_push_half_groupby_into_sql(self):
+        """Push the first group by of decomposed group by into sql"""
+
+        query = """
+        r = scan({part});
+        t = select r.i, MIN(r.h) from r;
+        store(t, OUTPUT);""".format(part=self.part_key)
+
+        lp = self.get_logical_plan(query)
+        pp = self.logical_to_physical(lp, push_sql=True,
+                                      push_sql_grouping=True)
+
+        # wrong partition, so still has shuffle
+        self.assertEquals(self.get_count(pp, MyriaShuffleConsumer), 1)
+        self.assertEquals(self.get_count(pp, MyriaShuffleProducer), 1)
+
+        # one group by should be pushed
+        self.assertEquals(self.get_count(pp, MyriaGroupBy), 1)
+        self.assertEquals(self.get_count(pp, MyriaQueryScan), 1)
+
+        self.db.evaluate(pp)
+        result = self.db.get_table('OUTPUT')
+        temp = dict([(i, sys.maxsize) for _, _, i in self.part_data])
+        for _, h, i in self.part_data:
+            temp[i] = min(temp[i], h)
+        expected = dict(((k, v), 1) for k, v in temp.items())
+
+        self.assertEquals(result, expected)
+
+    def _check_aggregate_functions_pushed(
+            self,
+            func,
+            expected,
+            override=False):
+        if override:
+            agg = func
+        else:
+            agg = "{func}(r.i)".format(func=func)
+
+        query = """
+        r = scan({part});
+        t = select r.h, {agg} from r;
+        store(t, OUTPUT);""".format(part=self.part_key, agg=agg)
+
+        lp = self.get_logical_plan(query)
+        pp = self.logical_to_physical(lp, push_sql=True,
+                                      push_sql_grouping=True)
+
+        self.assertEquals(self.get_count(pp, MyriaQueryScan), 1)
+
+        for op in pp.walk():
+            if isinstance(op, MyriaQueryScan):
+                self.assertTrue(re.search(expected, op.sql))
+
+    def test_aggregate_AVG_pushed(self):
+        """AVG is translated properly for postgresql. This is
+        a function not in SQLAlchemy"""
+        self._check_aggregate_functions_pushed(
+            aggregate.AVG.__name__, 'avg')
+
+    def test_aggregate_STDDEV_pushed(self):
+        """STDEV is translated properly for postgresql. This is
+        a function that is named differently in Raco and postgresql"""
+        self._check_aggregate_functions_pushed(
+            aggregate.STDEV.__name__, 'stddev_samp')
+
+    def test_aggregate_COUNTALL_pushed(self):
+        """COUNTALL is translated properly for postgresql. This is
+        a function that is expressed differently in Raco and postgresql"""
+
+        # MyriaL parses count(*) to Raco COUNTALL. And COUNTALL
+        # should currently (under the no nulls semantics of Raco/Myria)
+        # translate to COUNT(something)
+        self._check_aggregate_functions_pushed(
+            'count(*)', r'count[(][a-zA-Z.]+[)]', True)
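Two details in these tests are worth spelling out. test_push_half_groupby_into_sql relies on MIN being decomposable: a per-partition MIN (the half pushed into SQL) followed by a global MIN over the partial results after the shuffle gives the same answer as a single-pass MIN. A sketch of that two-phase evaluation; the worker data layout here is an assumption for illustration:

# Phase 1 runs per worker (and is what gets pushed into SQL); phase 2
# combines the partial results after the shuffle.
data_by_worker = {
    0: [(1, 10), (2, 7)],   # (i, h) rows on worker 0 (assumed sample data)
    1: [(1, 3), (2, 9)],    # (i, h) rows on worker 1
}


def local_min(rows):
    partial = {}
    for i, h in rows:
        partial[i] = min(h, partial.get(i, h))
    return partial


final = {}
for partial in (local_min(rows) for rows in data_by_worker.values()):
    for i, h in partial.items():
        final[i] = min(h, final.get(i, h))

print(final)  # {1: 3, 2: 7}, same as MIN(h) GROUP BY i over all rows

Second, the regex in test_aggregate_COUNTALL_pushed, count[(][a-zA-Z.]+[)], matches COUNT over a named column (e.g. count(rel.h)) but not a verbatim count(*), which is how the no-nulls COUNTALL translation is verified.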
