Skip to content

Commit e39616f

Browse files
authored
Allow function reregistration (#161)
1 parent 20518e2 commit e39616f

File tree

2 files changed

+11
-2
lines changed

2 files changed

+11
-2
lines changed

dask_sql/context.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ def register_function(
224224
name: str,
225225
parameters: List[Tuple[str, type]],
226226
return_type: type,
227+
replace: bool = False,
227228
):
228229
"""
229230
Register a custom function with the given name.
@@ -267,6 +268,7 @@ def f(x):
267268
parameters (:obj:`List[Tuple[str, type]]`): A list ot tuples of parameter name and parameter type.
268269
Use `numpy dtypes <https://numpy.org/doc/stable/reference/arrays.dtypes.html>`_ if possible.
269270
return_type (:obj:`type`): The return type of the function
271+
replace (:obj:`bool`): Do not raise an error if the function is already present
270272
271273
See also:
272274
:func:`register_aggregation`
@@ -277,7 +279,7 @@ def f(x):
277279
)
278280

279281
name = name.lower()
280-
if name in self.functions:
282+
if not replace and name in self.functions:
281283
if self.functions[name] != f:
282284
raise ValueError(
283285
"Registering different functions with the same name is not allowed"
@@ -290,6 +292,7 @@ def register_aggregation(
290292
name: str,
291293
parameters: List[Tuple[str, type]],
292294
return_type: type,
295+
replace: bool = False,
293296
):
294297
"""
295298
Register a custom aggregation with the given name.
@@ -333,6 +336,7 @@ def register_aggregation(
333336
parameters (:obj:`List[Tuple[str, type]]`): A list ot tuples of parameter name and parameter type.
334337
Use `numpy dtypes <https://numpy.org/doc/stable/reference/arrays.dtypes.html>`_ if possible.
335338
return_type (:obj:`type`): The return type of the function
339+
replace (:obj:`bool`): Do not raise an error if the function is already present
336340
337341
See also:
338342
:func:`register_function`
@@ -343,7 +347,7 @@ def register_aggregation(
343347
)
344348

345349
name = name.lower()
346-
if name in self.functions:
350+
if not replace and name in self.functions:
347351
if self.functions[name] != f:
348352
raise ValueError(
349353
"Registering different functions with the same name is not allowed"

tests/integration/test_function.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ def f(x):
6969
with pytest.raises(ValueError):
7070
c.register_function(f, "f", [("x", np.float64)], np.float64)
7171

72+
# only if we replace it
73+
c.register_function(f, "f", [("x", np.float64)], np.float64, replace=True)
74+
7275
fagg = dd.Aggregation("f", lambda x: x.sum(), lambda x: x.sum())
7376
c.register_aggregation(fagg, "fagg", [("x", np.float64)], np.float64)
7477
c.register_aggregation(fagg, "fagg", [("x", np.int64)], np.int64)
@@ -77,3 +80,5 @@ def f(x):
7780

7881
with pytest.raises(ValueError):
7982
c.register_aggregation(fagg, "fagg", [("x", np.float64)], np.float64)
83+
84+
c.register_aggregation(fagg, "fagg", [("x", np.float64)], np.float64, replace=True)

0 commit comments

Comments
 (0)