Skip to content

Commit b797746

Browse files
committed
ConstantValue: Support general dtypes
1 parent ea144b0 commit b797746

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

test/test_literals.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from ufl import PermutationSymbol, as_matrix, as_vector, indices, product
55
from ufl.classes import Indexed
66
from ufl.constantvalue import ComplexValue, FloatValue, IntValue, Zero, as_ufl
7+
import numpy
78

89

910
def test_zero(self):
@@ -29,13 +30,15 @@ def test_float(self):
2930
f4 = FloatValue(1.0)
3031
f5 = 3 - FloatValue(1) - 1
3132
f6 = 3 * FloatValue(2) / 6
33+
f7 = as_ufl(numpy.ones((1,), dtype="d")[0])
3234

3335
assert f1 == f1
3436
self.assertNotEqual(f1, f2) # IntValue vs FloatValue, == compares representations!
3537
assert f2 == f3
3638
assert f2 == f4
3739
assert f2 == f5
3840
assert f2 == f6
41+
assert f2 == f7
3942

4043

4144
def test_int(self):
@@ -45,13 +48,15 @@ def test_int(self):
4548
f4 = IntValue(1.0)
4649
f5 = 3 - IntValue(1) - 1
4750
f6 = 3 * IntValue(2) / 6
51+
f7 = as_ufl(numpy.ones((1,), dtype="int")[0])
4852

4953
assert f1 == f1
5054
self.assertNotEqual(f1, f2) # IntValue vs FloatValue, == compares representations!
5155
assert f1 == f3
5256
assert f1 == f4
5357
assert f1 == f5
5458
assert f2 == f6 # Division produces a FloatValue
59+
assert f1 == f7
5560

5661

5762
def test_complex(self):
@@ -62,6 +67,7 @@ def test_complex(self):
6267
f5 = ComplexValue(1.0 + 1.0j)
6368
f6 = as_ufl(1.0)
6469
f7 = as_ufl(1.0j)
70+
f8 = as_ufl(numpy.array([1+1j], dtype="complex")[0])
6571

6672
assert f1 == f1
6773
assert f1 == f4
@@ -71,6 +77,7 @@ def test_complex(self):
7177
assert f5 == f2 + f3
7278
assert f4 == f5
7379
assert f6 + f7 == f2 + f3
80+
assert f4 == f8
7481

7582

7683
def test_scalar_sums(self):

ufl/constantvalue.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
# Modified by Massimiliano Leoni, 2016.
1111

1212
from math import atan2
13+
import numbers
1314

1415
import ufl
1516

@@ -506,12 +507,12 @@ def as_ufl(expression):
506507
"""Converts expression to an Expr if possible."""
507508
if isinstance(expression, (Expr, ufl.BaseForm)):
508509
return expression
509-
elif isinstance(expression, complex):
510-
return ComplexValue(expression)
511-
elif isinstance(expression, float):
512-
return FloatValue(expression)
513-
elif isinstance(expression, int):
510+
elif isinstance(expression, numbers.Integral):
514511
return IntValue(expression)
512+
elif isinstance(expression, numbers.Real):
513+
return FloatValue(expression)
514+
elif isinstance(expression, numbers.Complex):
515+
return ComplexValue(expression)
515516
else:
516517
raise ValueError(
517518
f"Invalid type conversion: {expression} can not be converted to any UFL type."

0 commit comments

Comments
 (0)