Skip to content

Commit 37d4f27

Browse files
authored
Merge pull request #435 from db434/multiple_return_values
Allow dimension-checking of multiple return values
2 parents b61784a + a7ed932 commit 37d4f27

File tree

3 files changed

+135
-10
lines changed

3 files changed

+135
-10
lines changed

docs/usage.rst

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,10 +325,21 @@ using the :meth:`@accepts <unyt.dimensions.accepts>` and :meth:`@returns <unyt.d
325325
... def foo(a, v):
326326
... return a * v
327327
...
328-
>>> res = foo(a= 2 * u.s, v = 3 * u.m/u.s)
328+
>>> res = foo(a=2*u.s, v=3*u.m/u.s)
329329
>>> print(res)
330330
6 m
331331

332+
:meth:`@accepts <unyt.dimensions.accepts>` can specify the dimensions of any subset of inputs and :meth:`@returns <unyt.dimensions.returns>` must always describe all outputs.
333+
334+
>>> @returns(length, length/time**2)
335+
... @accepts(v=length/time)
336+
... def bar(a, v):
337+
... return a * v, v / a
338+
...
339+
>>> res = bar(a=2*u.s, v=3*u.m/u.s)
340+
>>> print(*res)
341+
6 m 1.5 m/s**2
342+
332343
.. note::
333344
Using these decorators may incur some performance overhead, especially for small arrays.
334345

unyt/dimensions.py

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
from sympy import Rational, Symbol, sympify
1212

13+
from unyt._deprecation import warn_deprecated
14+
1315
#: mass
1416
mass = Symbol("(mass)", positive=True)
1517
#: length
@@ -290,14 +292,16 @@ def new_f(*args, **kwargs):
290292
return check_accepts
291293

292294

293-
def returns(r_unit):
295+
def returns(*r_units, r_unit=None):
294296
"""Decorator for checking function return units.
295297
296298
Parameters
297299
----------
298-
r_unit: :py:class:`sympy.core.symbol.Symbol`
300+
*r_units: :py:class:`sympy.core.symbol.Symbol`
299301
SI base unit (or combination of units), eg. length/time
300-
of the value returned by the original function
302+
of the value(s) returned by the original function
303+
r_unit: :py:class:`sympy.core.symbol.Symbol`
304+
Deprecated version of `r_units` which supports only one named return value.
301305
302306
Examples
303307
--------
@@ -318,9 +322,29 @@ def returns(r_unit):
318322
Traceback (most recent call last):
319323
...
320324
TypeError: result '6 m' does not match (length)/(time)
321-
325+
>>> @returns(length, length/time**2)
326+
... def f(a, v):
327+
... return a * v, v / a
328+
...
329+
>>> res = f(a= 2 * u.s, v = 3 * u.m/u.s)
330+
>>> print(*res)
331+
6 m 1.5 m/s**2
322332
"""
323333

334+
# Convert deprecated arguments into current ones where possible.
335+
if r_unit is not None:
336+
if len(r_units) > 0:
337+
raise ValueError(
338+
"Cannot specify `r_unit` and other return values simultaneously"
339+
)
340+
else:
341+
warn_deprecated(
342+
"@unyt.returns(r_unit=...)",
343+
replacement="use @unyt.returns(...)",
344+
since_version="3.0",
345+
)
346+
r_units = (r_unit,)
347+
324348
def check_returns(f):
325349
"""Decorates original function.
326350
@@ -338,18 +362,27 @@ def check_returns(f):
338362

339363
@wraps(f)
340364
def new_f(*args, **kwargs):
341-
"""The decorated function, which checks the return unit.
365+
"""The decorated function, which checks the return units.
342366
343367
Raises
344368
------
345369
TypeError
346370
If the units do not match.
347371
348372
"""
349-
result = f(*args, **kwargs)
350-
if not _has_dimensions(result, r_unit):
351-
raise TypeError(f"result '{result}' does not match {r_unit}")
352-
return result
373+
results = f(*args, **kwargs)
374+
375+
# Make results a tuple so we can treat single and multiple return values the
376+
# same way.
377+
if isinstance(results, tuple):
378+
result_tuple = results
379+
else:
380+
result_tuple = (results,)
381+
382+
for result, dimension in zip(result_tuple, r_units):
383+
if not _has_dimensions(result, dimension):
384+
raise TypeError(f"result '{result}' does not match {dimension}")
385+
return results
353386

354387
return new_f
355388

unyt/tests/test_unyt_testing.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
"""
55
import pytest
66

7+
from unyt import accepts, meter, returns, second
78
from unyt.array import unyt_array, unyt_quantity
9+
from unyt.dimensions import length, time
810
from unyt.testing import assert_allclose_units
911

1012

@@ -40,3 +42,82 @@ def test_atol_conversion_error():
4042
a2 = unyt_array([1.0, 2.0, 3.0], "cm")
4143
with pytest.raises(AssertionError):
4244
assert_allclose_units(a1, a2, atol=unyt_quantity(0.0, "kg"))
45+
46+
47+
def test_accepts():
48+
@accepts(a=time, v=length / time)
49+
def foo(a, v):
50+
return a * v
51+
52+
foo(a=2 * second, v=3 * meter / second)
53+
54+
with pytest.raises(TypeError):
55+
foo(a=2 * meter, v=3 * meter / second)
56+
57+
with pytest.raises(TypeError):
58+
foo(a=2 * second, v=3 * meter)
59+
60+
61+
def test_accepts_partial():
62+
@accepts(a=time)
63+
def bar(a, v):
64+
return a * v
65+
66+
bar(a=2 * second, v=3 * meter / second)
67+
bar(a=2 * second, v=3 * meter)
68+
69+
with pytest.raises(TypeError):
70+
bar(a=2 * meter, v=3 * meter / second)
71+
72+
@accepts(v=length / time)
73+
def baz(a, v):
74+
return a * v
75+
76+
baz(a=2 * second, v=3 * meter / second)
77+
baz(a=2 * meter, v=3 * meter / second)
78+
79+
with pytest.raises(TypeError):
80+
baz(a=2 * second, v=3 * meter)
81+
82+
83+
def test_returns():
84+
@returns(length)
85+
def foo(a, v):
86+
return a * v
87+
88+
# This usage is deprecated, but we still want to support it for now.
89+
with pytest.deprecated_call():
90+
91+
@returns(r_unit=length)
92+
def bar(a, v):
93+
return a * v
94+
95+
for func in [foo, bar]:
96+
func(a=2 * second, v=3 * meter / second)
97+
98+
with pytest.raises(TypeError):
99+
func(a=2 * meter, v=3 * meter / second)
100+
101+
with pytest.raises(TypeError):
102+
func(a=2 * second, v=3 * meter)
103+
104+
# We don't support a mixture of the two usage styles.
105+
with pytest.raises(ValueError):
106+
107+
@returns(length, r_unit=time)
108+
def _(a, v):
109+
return a, v
110+
111+
112+
def test_returns_multiple():
113+
@returns(time, length / time)
114+
def baz(a, v):
115+
return a, v
116+
117+
baz(a=2 * second, v=3 * meter / second)
118+
119+
with pytest.raises(TypeError):
120+
baz(a=2 * meter, v=3 * meter / second)
121+
122+
with pytest.raises(TypeError):
123+
baz(a=2 * second, v=3 * meter)

0 commit comments

Comments
 (0)