Skip to content

Commit

Permalink
ConstantValue: Support general dtypes
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Jun 6, 2024
1 parent ea144b0 commit b797746
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
7 changes: 7 additions & 0 deletions test/test_literals.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from ufl import PermutationSymbol, as_matrix, as_vector, indices, product
from ufl.classes import Indexed
from ufl.constantvalue import ComplexValue, FloatValue, IntValue, Zero, as_ufl
import numpy


def test_zero(self):
Expand All @@ -29,13 +30,15 @@ def test_float(self):
f4 = FloatValue(1.0)
f5 = 3 - FloatValue(1) - 1
f6 = 3 * FloatValue(2) / 6
f7 = as_ufl(numpy.ones((1,), dtype="d")[0])

assert f1 == f1
self.assertNotEqual(f1, f2) # IntValue vs FloatValue, == compares representations!
assert f2 == f3
assert f2 == f4
assert f2 == f5
assert f2 == f6
assert f2 == f7


def test_int(self):
Expand All @@ -45,13 +48,15 @@ def test_int(self):
f4 = IntValue(1.0)
f5 = 3 - IntValue(1) - 1
f6 = 3 * IntValue(2) / 6
f7 = as_ufl(numpy.ones((1,), dtype="int")[0])

assert f1 == f1
self.assertNotEqual(f1, f2) # IntValue vs FloatValue, == compares representations!
assert f1 == f3
assert f1 == f4
assert f1 == f5
assert f2 == f6 # Division produces a FloatValue
assert f1 == f7


def test_complex(self):
Expand All @@ -62,6 +67,7 @@ def test_complex(self):
f5 = ComplexValue(1.0 + 1.0j)
f6 = as_ufl(1.0)
f7 = as_ufl(1.0j)
f8 = as_ufl(numpy.array([1+1j], dtype="complex")[0])

assert f1 == f1
assert f1 == f4
Expand All @@ -71,6 +77,7 @@ def test_complex(self):
assert f5 == f2 + f3
assert f4 == f5
assert f6 + f7 == f2 + f3
assert f4 == f8


def test_scalar_sums(self):
Expand Down
11 changes: 6 additions & 5 deletions ufl/constantvalue.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# Modified by Massimiliano Leoni, 2016.

from math import atan2
import numbers

import ufl

Expand Down Expand Up @@ -506,12 +507,12 @@ def as_ufl(expression):
"""Converts expression to an Expr if possible."""
if isinstance(expression, (Expr, ufl.BaseForm)):
return expression
elif isinstance(expression, complex):
return ComplexValue(expression)
elif isinstance(expression, float):
return FloatValue(expression)
elif isinstance(expression, int):
elif isinstance(expression, numbers.Integral):
return IntValue(expression)
elif isinstance(expression, numbers.Real):
return FloatValue(expression)
elif isinstance(expression, numbers.Complex):
return ComplexValue(expression)
else:
raise ValueError(
f"Invalid type conversion: {expression} can not be converted to any UFL type."
Expand Down

0 comments on commit b797746

Please sign in to comment.