How to load PyTorch checkpoints into JAX/Flax? #927
marcvanzee asked this question in Q&A
Answered by marcvanzee, Jan 22, 2021

PyTorch checkpoints contain a `state_dict` with all the weights/parameters for the models, and converting it to Flax involves: [...] `NCHW` dimensions for conv weights.

Often `flax.traverse_util.flatten_dict` is useful, because you only need to operate on a flat dict instead of a nested dict. Once the flat keys align with the Flax model's parameters, you use `unflatten_dict` to get the normal nested form back.

@nikitakit wrote the following code for importing PyTorch BERT checkpoints into a Flax model: https://github.com/nikitakit/flax_bert/blob/master/import_weights.py