Fitting your Flax models made simple! The whole fitting function can even be used inside a jit context.
pip install git+https://github.com/JeyRunner/flaxfit.git

For full examples, see the `examples/` folder.
from flax import nnx
# ...
# Model: a small MLP mapping 1 input feature -> 10 hidden units (ReLU) -> 1 output.
rngs = nnx.Rngs(0)
hidden_layer = nnx.Linear(in_features=1, out_features=10, rngs=rngs)
output_layer = nnx.Linear(in_features=10, out_features=1, rngs=rngs)
model = nnx.Sequential(hidden_layer, nnx.relu, output_layer)
def loss(predictions_y, dataset: Dataset):
    """Compute the mean squared error of the predictions against ``dataset.y``.

    Args:
        predictions_y: model outputs for the samples in ``dataset``.
        dataset: the batch being evaluated; only its ``y`` targets are read.

    Returns:
        A dict of named loss terms; here a single ``mse`` entry.
    """
    return dict(
        mse=jnp.mean((predictions_y - dataset.y) ** 2)
    )
# epoch callback (executed on host)
def callback(epoch: int, metrics: dict):
    """Print the epoch number together with that epoch's metrics."""
    print(f'> epoch {epoch} - {metrics}')
# dataset: y = x^2 sampled at x = 0, 1, ..., 19, shaped (20, 1) to match the
# model's single input feature. Built as float32 (jnp.arange would otherwise
# give integers) so it has the same dtype as the evaluation set below.
x = jnp.arange(20, dtype=jnp.float32)[:, jnp.newaxis]
dataset = DatasetXY(
    x=x,
    y=x**2,
)
# evaluation dataset: held-out inputs offset by 0.5 from the training inputs
# (0.5, 1.5, ..., 19.5); same number of samples as the training set.
x_eval = x + 0.5
dataset_eval = DatasetXY(
    x=x_eval,
    y=x_eval**2,
)
# Set up the fitter with the loss function above and an update batch size of 5.
fitter = FlaxModelFitter(loss_function=loss, update_batch_size=5)
# Fit: create the initial train state from the model, then train for 200 epochs,
# evaluating on `dataset_eval` every epoch and passing each epoch's metrics to `callback`.
train_state = fitter.create_train_state(model)
train_state, history = fitter.train_fit(
    train_state,
    num_epochs=200,
    dataset=dataset,
    dataset_eval=dataset_eval,
    evaluate_each_n_epochs=1,
    epoch_callback_fn=callback,
)
print(history)

Install the development and test dependencies:

pip install ".[dev,test]"