Skip to content

Commit 7d67ff5

Browse files
authored
test: update test effected by the optax==0.2.7 release (#2137)
* test: update test effected by the `optax==0.2.7` release * test: remove test skip
1 parent 6751588 commit 7d67ff5

File tree

1 file changed

+1
-6
lines changed

1 file changed

+1
-6
lines changed

test/test_optimizers.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,6 @@ def test_optim_multi_params(optim_class, args, kwargs, uses_value_arg):
100100
+ optax_optimizers,
101101
)
102102
@pytest.mark.filterwarnings("ignore:.*tree_multimap:FutureWarning")
103-
@pytest.mark.skip(reason="Failing on optax==0.2.7. Halting till #2137 is merged.")
104103
def test_numpyrooptim_no_double_jit(optim_class, args, kwargs, uses_value_arg):
105104
opt = optim_class(*args, **kwargs)
106105
if not isinstance(opt, optim._NumPyroOptim):
@@ -124,8 +123,4 @@ def my_fn(state, g):
124123
state = my_fn(state, jnp.ones(10) * 2.0)
125124
state = my_fn(state, jnp.ones(10) * 3.0)
126125

127-
if uses_value_arg:
128-
# Dtype is different on the first call vs the rest of the calls
129-
assert my_fn_calls == 2
130-
else:
131-
assert my_fn_calls == 1
126+
assert my_fn_calls == 1

0 commit comments

Comments
 (0)