Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow dimension-checking of multiple return values #435

Merged
merged 9 commits into from
Aug 21, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion docs/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -325,10 +325,21 @@ using the :meth:`@accepts <unyt.dimensions.accepts>` and :meth:`@returns <unyt.d
... def foo(a, v):
... return a * v
...
>>> res = foo(a= 2 * u.s, v = 3 * u.m/u.s)
>>> res = foo(a=2*u.s, v=3*u.m/u.s)
>>> print(res)
6 m

:meth:`@accepts <unyt.dimensions.accepts>` can specify the dimensions of any subset of inputs and :meth:`@returns <unyt.dimensions.returns>` must always describe all outputs.

>>> @returns(length, length/time**2)
... @accepts(v=length/time)
... def bar(a, v):
... return a * v, v / a
...
>>> res = bar(a=2*u.s, v=3*u.m/u.s)
>>> print(*res)
6 m 1.5 m/s**2

.. note::
Using these decorators may incur some performance overhead, especially for small arrays.

Expand Down
51 changes: 42 additions & 9 deletions unyt/dimensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""


import warnings
from functools import wraps
from itertools import chain

Expand Down Expand Up @@ -290,14 +291,16 @@ def new_f(*args, **kwargs):
return check_accepts


def returns(r_unit):
def returns(*r_units, r_unit=None):
"""Decorator for checking function return units.

Parameters
----------
r_unit: :py:class:`sympy.core.symbol.Symbol`
*r_units: :py:class:`sympy.core.symbol.Symbol`
SI base unit (or combination of units), eg. length/time
of the value returned by the original function
of the value(s) returned by the original function
r_unit: :py:class:`sympy.core.symbol.Symbol`
Deprecated version of `r_units` which supports only one named return value.

Examples
--------
Expand All @@ -318,9 +321,30 @@ def returns(r_unit):
Traceback (most recent call last):
...
TypeError: result '6 m' does not match (length)/(time)

>>> @returns(length, length/time**2)
... def f(a, v):
... return a * v, v / a
...
>>> res = f(a= 2 * u.s, v = 3 * u.m/u.s)
>>> print(*res)
6 m 1.5 m/s**2
"""

# Convert deprecated arguments into current ones where possible.
if r_unit is not None:
if len(r_units) > 0:
raise ValueError(
"Cannot specify `r_unit` and other return values simultaneously"
)
else:
warnings.warn(
"Use of the @returns(r_unit=...) syntax is deprecated. "
"Please use @returns(...) instead.",
category=DeprecationWarning,
stacklevel=2,
)
db434 marked this conversation as resolved.
Show resolved Hide resolved
r_units = (r_unit,)

def check_returns(f):
"""Decorates original function.

Expand All @@ -338,18 +362,27 @@ def check_returns(f):

@wraps(f)
def new_f(*args, **kwargs):
"""The decorated function, which checks the return unit.
"""The decorated function, which checks the return units.

Raises
------
TypeError
If the units do not match.

"""
result = f(*args, **kwargs)
if not _has_dimensions(result, r_unit):
raise TypeError(f"result '{result}' does not match {r_unit}")
return result
results = f(*args, **kwargs)

# Make results a tuple so we can treat single and multiple return values the
# same way.
if isinstance(results, tuple):
result_tuple = results
else:
result_tuple = (results,)

for result, dimension in zip(result_tuple, r_units):
if not _has_dimensions(result, dimension):
raise TypeError(f"result '{result}' does not match {dimension}")
return results

return new_f

Expand Down
81 changes: 81 additions & 0 deletions unyt/tests/test_unyt_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
"""
import pytest

from unyt import accepts, meter, returns, second
from unyt.array import unyt_array, unyt_quantity
from unyt.dimensions import length, time
from unyt.testing import assert_allclose_units


Expand Down Expand Up @@ -40,3 +42,82 @@ def test_atol_conversion_error():
a2 = unyt_array([1.0, 2.0, 3.0], "cm")
with pytest.raises(AssertionError):
assert_allclose_units(a1, a2, atol=unyt_quantity(0.0, "kg"))


def test_accepts():
@accepts(a=time, v=length / time)
def foo(a, v):
return a * v

foo(a=2 * second, v=3 * meter / second)

with pytest.raises(TypeError):
foo(a=2 * meter, v=3 * meter / second)

with pytest.raises(TypeError):
foo(a=2 * second, v=3 * meter)


def test_accepts_partial():
@accepts(a=time)
def bar(a, v):
return a * v

bar(a=2 * second, v=3 * meter / second)
bar(a=2 * second, v=3 * meter)

with pytest.raises(TypeError):
bar(a=2 * meter, v=3 * meter / second)

@accepts(v=length / time)
def baz(a, v):
return a * v

baz(a=2 * second, v=3 * meter / second)
baz(a=2 * meter, v=3 * meter / second)

with pytest.raises(TypeError):
baz(a=2 * second, v=3 * meter)


def test_returns():
@returns(length)
def foo(a, v):
return a * v

# This usage is deprecated, but we still want to support it for now.
with pytest.deprecated_call():

@returns(r_unit=length)
def bar(a, v):
return a * v

for func in [foo, bar]:
func(a=2 * second, v=3 * meter / second)

with pytest.raises(TypeError):
func(a=2 * meter, v=3 * meter / second)

with pytest.raises(TypeError):
func(a=2 * second, v=3 * meter)

# We don't support a mixture of the two usage styles.
with pytest.raises(ValueError):

@returns(length, r_unit=time)
def _(a, v):
return a, v


def test_returns_multiple():
@returns(time, length / time)
def baz(a, v):
return a, v

baz(a=2 * second, v=3 * meter / second)

with pytest.raises(TypeError):
baz(a=2 * meter, v=3 * meter / second)

with pytest.raises(TypeError):
baz(a=2 * second, v=3 * meter)