Skip to content

Commit de3d5c5

Browse files
committed
add an abstraction for translating function names that SQLAlchemy doesn't know about
1 parent dda09ee commit de3d5c5

File tree

3 files changed

+89
-7
lines changed

3 files changed

+89
-7
lines changed

raco/backends/myria/myria.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from raco.algebra import Shuffle
1212
from raco.algebra import convertcondition
1313
from raco.backends import Language, Algebra
14-
from raco.backends.sql.catalog import SQLCatalog
14+
from raco.backends.sql.catalog import SQLCatalog, PostgresSQLFunctionProvider
1515
from raco.catalog import Catalog
1616
from raco.datastructure.UnionFind import UnionFind
1717
from raco.expression import UnnamedAttributeRef
@@ -1478,7 +1478,8 @@ def __init__(self, dialect=None, push_grouping=False):
14781478
def fire(self, expr):
14791479
if isinstance(expr, (algebra.Scan, algebra.ScanTemp)):
14801480
return expr
1481-
cat = SQLCatalog(push_grouping=self.push_grouping)
1481+
cat = SQLCatalog(provider=PostgresSQLFunctionProvider(),
1482+
push_grouping=self.push_grouping)
14821483
try:
14831484
sql_plan = cat.get_sql(expr)
14841485
sql_string = sql_plan.compile(dialect=self.dialect)

raco/backends/sql/catalog.py

+39-5
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import raco.scheme as scheme
1313
import raco.types as types
1414
from raco.representation import RepresentationProperties
15+
import abc
1516

1617

1718
type_to_raco = {Integer: types.LONG_TYPE,
@@ -30,10 +31,44 @@
3031
types.DATETIME_TYPE: DateTime}
3132

3233

34+
class SQLFunctionProvider(object):
35+
"""Interface for translating function names. For Raco functions
36+
not understood by SQLAlchemy, like stddev, we cannot rely
37+
on SQLAlchemy's compiler to translate function
38+
names to the given dialect.
39+
For functions not understood by SQLAlchemy, it just emits them as
40+
given."""
41+
42+
@abc.abstractmethod
43+
def convert_unary_expr(self, expr, input):
44+
pass
45+
46+
47+
class _DefaultSQLFunctionProvider(SQLFunctionProvider):
48+
def convert_unary_expr(self, expr, input):
49+
# just use the function name without complaining
50+
fname = expr.__class__.__name__.lower()
51+
return getattr(func, fname)(input)
52+
53+
54+
class PostgresSQLFunctionProvider(SQLFunctionProvider):
55+
def convert_unary_expr(self, expr, input):
56+
fname = expr.__class__.__name__.lower()
57+
58+
# replacements
59+
if fname == "stdev":
60+
return func.stddev_samp(input)
61+
62+
# Warning: may create some functions not available in Postgres
63+
return getattr(func, fname)(input)
64+
65+
3366
class SQLCatalog(Catalog):
34-
def __init__(self, engine=None, push_grouping=False):
67+
def __init__(self, engine=None, push_grouping=False,
68+
provider=None):
3569
self.engine = engine
3670
self.push_grouping = push_grouping
71+
self.provider = provider or _DefaultSQLFunctionProvider()
3772
self.metadata = MetaData()
3873

3974
@staticmethod
@@ -108,10 +143,9 @@ def _convert_zeroary_expr(self, cols, expr, input_scheme):
108143
def _convert_unary_expr(self, cols, expr, input_scheme):
109144
input = self._convert_expr(cols, expr.input, input_scheme)
110145

111-
fname = expr.__class__.__name__.lower()
112-
# if SQL has a supported function by this name
113-
if hasattr(func, fname):
114-
return getattr(func, fname)(input)
146+
c = self.provider.convert_unary_expr(expr, input)
147+
if c is not None:
148+
return c
115149

116150
raise NotImplementedError("expression {} to sql".format(type(expr)))
117151

raco/myrial/optimizer_tests.py

+47
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import collections
22
import random
33
import sys
4+
import re
45

56
from raco.algebra import *
67
from raco.expression import NamedAttributeRef as AttRef
78
from raco.expression import UnnamedAttributeRef as AttIndex
89
from raco.expression import StateVar
10+
from raco.expression import aggregate
911

1012
from raco.backends.myria import (
1113
MyriaShuffleConsumer, MyriaShuffleProducer, MyriaHyperShuffleProducer,
@@ -1125,3 +1127,48 @@ def test_push_half_groupby_into_sql(self):
11251127
expected = dict(((k, v), 1) for k, v in temp.items())
11261128

11271129
self.assertEquals(result, expected)
1130+
1131+
def _check_aggregate_functions_pushed(
1132+
self,
1133+
func,
1134+
expected,
1135+
override=False):
1136+
if override:
1137+
agg = func
1138+
else:
1139+
agg = "{func}(r.i)".format(func=func)
1140+
1141+
query = """
1142+
r = scan({part});
1143+
t = select r.h, {agg} from r;
1144+
store(t, OUTPUT);""".format(part=self.part_key, agg=agg)
1145+
print query
1146+
1147+
lp = self.get_logical_plan(query)
1148+
pp = self.logical_to_physical(lp, push_sql=True,
1149+
push_sql_grouping=True)
1150+
1151+
self.assertEquals(self.get_count(pp, MyriaQueryScan), 1)
1152+
1153+
for op in pp.walk():
1154+
if isinstance(op, MyriaQueryScan):
1155+
print op.sql
1156+
self.assertTrue(re.search(expected, op.sql))
1157+
1158+
def test_aggregate_AVG_pushed(self):
1159+
"""AVG is translated properly for postgresql. This is
1160+
a function not in SQLAlchemy"""
1161+
self._check_aggregate_functions_pushed(
1162+
aggregate.AVG.__name__, 'avg')
1163+
1164+
def test_aggregate_STDDEV_pushed(self):
1165+
"""STDEV is translated properly for postgresql. This is
1166+
a function that is named differently in Raco and postgresql"""
1167+
self._check_aggregate_functions_pushed(
1168+
aggregate.STDEV.__name__, 'stddev_samp')
1169+
1170+
def test_aggregate_COUNTALL_pushed(self):
1171+
"""COUNTALL is translated properly for postgresql. This is
1172+
a function that is expressed differently in Raco and postgresql"""
1173+
self._check_aggregate_functions_pushed(
1174+
'count(*)', r'count[(][a-zA-Z.]+[)]', True)

0 commit comments

Comments
 (0)