Conversation

@samanklesaria (Collaborator)

What does this PR do?

This PR adds a guide that shows some common techniques for working with Flax models during optimization. These include:

  • Computing exponential moving averages (EMAs) of parameters (a minimal sketch follows this list)
  • Optimizing only a low-rank addition to certain weights (LoRA)
  • Using different learning rates for different parameters to implement the maximal update parametrization (muP)
  • Using second-order optimizers such as L-BFGS
  • Specifying sharding for the optimizer state that differs from the sharding of the parameters
  • Gradient accumulation
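
As a taste of the first item, here is a minimal sketch of the pure-JAX flavor of an EMA update (my own illustration, not lifted from the guide; the `decay` value and the shape of the `params` pytree are placeholders):

```python
import jax

def update_ema(ema, params, decay=0.999):
    # One EMA step, applied leaf-wise over the parameter pytree:
    # ema <- decay * ema + (1 - decay) * params
    return jax.tree.map(lambda e, p: decay * e + (1 - decay) * p, ema, params)

# Typical use inside a training loop, with ema initialized as a copy of params:
# ema = update_ema(ema, params)
```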

This is very much a work in progress: the guide will be fleshed out much further over time. This draft PR exists only to give a window into the types of tasks that will eventually be covered.

This document emphasizes a style as close to pure JAX as possible: to that end, it shows how the Flax version of each technique requires only minor deviations from the often more intuitive pure-JAX version.
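
To give a flavor of what that looks like for per-parameter learning rates, here is a hedged sketch using `optax.multi_transform` (my own illustration; the parameter names, shapes, and learning rates are assumptions, and muP itself prescribes particular width-dependent scales that this routing mechanism does not show):

```python
import jax.numpy as jnp
import optax

# Assumed parameter pytree; in the guide these would be Flax model weights.
params = {'embedding': jnp.zeros((10, 4)), 'hidden': jnp.zeros((4, 4))}

# Label each subtree, then route each label to its own learning rate.
param_labels = {'embedding': 'slow', 'hidden': 'fast'}
tx = optax.multi_transform(
    {'slow': optax.adam(1e-4), 'fast': optax.adam(1e-3)},
    param_labels,
)
opt_state = tx.init(params)
# Inside the training step:
# updates, opt_state = tx.update(grads, opt_state, params)
# params = optax.apply_updates(params, updates)
```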

@review-notebook-app
Check out this pull request on ReviewNB to see visual diffs and provide feedback on the Jupyter notebooks.

@samanklesaria force-pushed the opt_cookbook branch 2 times, most recently from 3ca3395 to c495dc1 on December 1, 2025.