@@ -56,14 +56,25 @@ def fn_const(x):
56
56
("jax" , "float64" , "float32" ),
57
57
]:
58
58
continue
59
+
59
60
integrator_name = type (integrator ).__name__
61
+
60
62
# VEGAS supports only numpy and torch
61
63
if integrator_name == "VEGAS" and backend in ["jax" , "tensorflow" ]:
62
64
continue
63
65
64
66
# Set the global precision
65
67
set_precision (dtype_global , backend = backend )
66
68
69
+ # Determine expected dtype
70
+ if backend == "tensorflow" :
71
+ import tensorflow as tf
72
+
73
+ expected_dtype_name = dtype_arg if dtype_arg else tf .keras .backend .floatx ()
74
+ else :
75
+ expected_dtype_name = dtype_arg if dtype_arg else dtype_global
76
+
77
+ # Set integration domain
67
78
integration_domain = [[0.0 , 1.0 ], [- 2.0 , 0.0 ]]
68
79
if dtype_arg is not None :
69
80
# Set the integration_domain dtype which should have higher priority
@@ -75,18 +86,18 @@ def fn_const(x):
75
86
)
76
87
assert infer_backend (integration_domain ) == backend
77
88
assert get_dtype_name (integration_domain ) == dtype_arg
78
- expected_dtype_name = dtype_arg
79
- else :
80
- expected_dtype_name = dtype_global
81
89
82
90
print (
83
- f"[2mTesting { integrator_name } with { backend } , argument dtype"
84
- f" { dtype_arg } , global/default dtype { dtype_global } [m "
91
+ f"Testing { integrator_name } with { backend } , argument dtype"
92
+ f" { dtype_arg } , global/default dtype { dtype_global } "
85
93
)
94
+
95
+ # Integration
86
96
if integrator_name in ["MonteCarlo" , "VEGAS" ]:
87
97
extra_kwargs = {"seed" : 0 }
88
98
else :
89
99
extra_kwargs = {}
100
+
90
101
result = integrator .integrate (
91
102
fn = fn_const ,
92
103
dim = 2 ,
@@ -95,8 +106,12 @@ def fn_const(x):
95
106
backend = backend ,
96
107
** extra_kwargs ,
97
108
)
109
+
98
110
assert infer_backend (result ) == backend
99
- assert get_dtype_name (result ) == expected_dtype_name
111
+ assert (
112
+ get_dtype_name (result ) == expected_dtype_name
113
+ ), f"Expected dtype { expected_dtype_name } , got { get_dtype_name (result )} "
114
+
100
115
# VEGAS seems to be bad at integrating constant functions currently
101
116
max_error = 0.03 if integrator_name == "VEGAS" else 1e-5
102
117
assert anp .abs (result - (- 4.0 )) < max_error
0 commit comments