-
Notifications
You must be signed in to change notification settings - Fork 721
Description
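I am not able to use L-BFGS (optax's `lbfgs`) with Flax NNX's `nnx.Optimizer` the same way I use Adam, and I need an example showing how to do it. The traceback below shows where it breaks: `nnx.Optimizer.update(grads)` calls `self.tx.update(grads, self.opt_state, state)` without passing the `value`, `grad`, and `value_fn` keyword arguments, and optax's zoom line search (`scale_by_zoom_linesearch`) requires all three.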
File "/Users/jlperla/Documents/GitHub/ECON622_instructor/lectures/examples/linear_regression_jax_nnx.py", line 52, in <module>
loss = train_step(model, optimizer, X_batch, Y_batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/jlperla/anaconda3/envs/econ622/lib/python3.11/site-packages/flax/nnx/nnx/graph.py", line 1043, in update_context_manager_wrapper
return f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^
File "/Users/jlperla/anaconda3/envs/econ622/lib/python3.11/site-packages/flax/nnx/nnx/transforms/transforms.py", line 359, in jit_wrapper
out, output_state, output_graphdef = jitted_fn(
^^^^^^^^^^
File "/Users/jlperla/anaconda3/envs/econ622/lib/python3.11/site-packages/flax/nnx/nnx/transforms/transforms.py", line 158, in jit_fn
out = f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^
File "/Users/jlperla/Documents/GitHub/ECON622_instructor/lectures/examples/linear_regression_jax_nnx.py", line 43, in train_step
optimizer.update(grads)
File "/Users/jlperla/anaconda3/envs/econ622/lib/python3.11/site-packages/flax/nnx/nnx/training/optimizer.py", line 201, in update
updates, new_opt_state = self.tx.update(grads, self.opt_state, state)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/jlperla/anaconda3/envs/econ622/lib/python3.11/site-packages/optax/transforms/_combining.py", line 73, in update_fn
updates, new_s = fn(updates, s, params, **extra_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: scale_by_zoom_linesearch.<locals>.update_fn() missing 3 required keyword-only arguments: 'value', 'grad', and 'value_fn'