Skip to content

Commit e48d9c1

Browse files
Require UDF return type and update docs (#283)
Require UDF return type and update docs
1 parent 0bf65bb commit e48d9c1

File tree

3 files changed

+33
-3
lines changed

3 files changed

+33
-3
lines changed

dask_sql/datacontainer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,14 +195,20 @@ def __init__(self, func, row_udf: bool, return_type=None):
195195
"""
196196
self.row_udf = row_udf
197197
self.func = func
198+
199+
if return_type is None:
200+
# These UDFs go through apply and without providing
201+
# a return type, dask will attempt to guess it, and
202+
# dask might be wrong.
203+
raise ValueError("Return type must be provided")
198204
self.meta = (None, return_type)
199205

200206
def __call__(self, *args, **kwargs):
201207
if self.row_udf:
202208
df = args[0].to_frame()
203209
for operand in args[1:]:
204210
df[operand.name] = operand
205-
result = df.apply(self.func, axis=1, meta=self.meta)
211+
result = df.apply(self.func, axis=1, meta=self.meta).astype(self.meta[1])
206212
else:
207213
result = self.func(*args, **kwargs)
208214
return result

docs/pages/custom.rst

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,7 @@ These functions may be registered as above and flagged as row UDFs using the `ro
4343
def f(row):
4444
return row['a'] + row['b']
4545
46-
# todo - fix the api
47-
c.register_function(f, "f", [], None, row_udf=True)
46+
c.register_function(f, "f", [("a", np.int64), ("b", np.int64)], np.int64, row_udf=True)
4847
c.sql("SELECT f(a, b) FROM data")
4948
5049
** Note: Row UDFs use `apply` which may have unpredictable performance characteristics, depending on the function and dataframe library **

tests/integration/test_function.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,31 @@ def f(row):
3838
assert_frame_equal(return_df.reset_index(drop=True), df[["a"]] ** 2)
3939

4040

41+
@pytest.mark.parametrize(
42+
"retty",
43+
[None, np.float64, np.float32, np.int64, np.int32, np.int16, np.int8, np.bool_],
44+
)
45+
def test_custom_function_row_return_types(c, df, retty):
46+
def f(row):
47+
return row["a"] ** 2
48+
49+
if retty is None:
50+
with pytest.raises(ValueError):
51+
c.register_function(f, "f", [("x", np.float64)], retty, row_udf=True)
52+
return
53+
54+
c.register_function(f, "f", [("x", np.float64)], retty, row_udf=True)
55+
return_df = c.sql(
56+
"""
57+
SELECT F(a) AS a
58+
FROM df
59+
"""
60+
)
61+
return_df = return_df.compute()
62+
expectation = (df[["a"]] ** 2).astype(retty)
63+
assert_frame_equal(return_df.reset_index(drop=True), expectation)
64+
65+
4166
def test_multiple_definitions(c, df_simple):
4267
def f(x):
4368
return x ** 2

0 commit comments

Comments
 (0)