Replies: 1 comment
Your train_step without nnx.vmap is missing the 0.5 factor, and you didn't say what actually happens when you run the code (only what's blocking you, not the error output). I believe the problem is that under nnx.vmap the mapped function receives only one item per call, so BatchNorm has only a single element to compute the mean and variance from (the variance will be zero). NB: when sharing your code, use the <> Code option (with 'python' at the start) so it renders properly; also, the nnx.vmap change only affects train_step, so you can leave out the surrounding code.
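To make that diagnosis concrete, here is a minimal sketch (the model, layer sizes, and input shapes are my assumptions, not the poster's code): in training mode BatchNorm needs the full batch to compute statistics, while switching the model to inference mode with `model.eval()` makes it read its running statistics, so a per-example `nnx.vmap` no longer depends on the (now size-1) batch.

```python
import jax.numpy as jnp
from flax import nnx

class Net(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.linear = nnx.Linear(2, 4, rngs=rngs)
        self.bn = nnx.BatchNorm(4, rngs=rngs)

    def __call__(self, x):
        return self.bn(self.linear(x))

net = Net(nnx.Rngs(0))
x = jnp.ones((8, 2))

# Training mode, full batch: BatchNorm computes mean/variance over 8 examples.
y_full = net(x)

# Workaround sketch: use the running statistics so a per-example vmap no
# longer tries to compute batch statistics over a single element.
net.eval()

@nnx.vmap(in_axes=(None, 0), out_axes=0)  # broadcast the model, map the data
def per_example(net, xi):
    # Each call sees one example; add a singleton batch dim for the layers.
    return net(xi[None, :])[0]

y_each = per_example(net, x)
```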
Updated question:
Dear Community,
I encountered a small challenge while using a batch normalization layer with nnx.vmap. To illustrate the issue, I created a minimal example code snippet. Based on my understanding of the documentation for flax.nnx.vmap, the issue seems to stem from the handling of BatchStat, which requires special consideration when using vmap.
Currently, I am struggling to make the final example in the attached code work. Does anyone have suggestions on how to adjust the loss function so that it works correctly with nnx.vmap?
Specifically, the first defined train step, which uses nnx.vmap, does not execute; a hedged reconstruction is sketched below.
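Since the attached snippet is not reproduced in this thread, the following is only a reconstruction of the kind of setup described (the model, loss, and shapes are assumptions): the plain train step works because BatchNorm normalizes over the whole batch, while the nnx.vmap variant hands it one example per call.

```python
import jax.numpy as jnp
from flax import nnx

class Net(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.linear = nnx.Linear(2, 1, rngs=rngs)
        self.bn = nnx.BatchNorm(1, rngs=rngs)

    def __call__(self, x):
        return self.bn(self.linear(x))

def loss_fn(net, x, y):
    pred = net(x)
    return 0.5 * jnp.mean((pred - y) ** 2)  # note the 0.5 factor

net = Net(nnx.Rngs(0))
x = jnp.ones((16, 2))
y = jnp.zeros((16, 1))

# Works: loss_fn receives the full batch, so BatchNorm's statistics are valid.
loss, grads = nnx.value_and_grad(loss_fn)(net, x, y)

# Problematic pattern: mapping over axis 0 gives BatchNorm a single example
# per call, so its batch variance is zero and the training-mode state update
# is ill-defined under the map.
# per_example = nnx.vmap(loss_fn, in_axes=(None, 0, 0), out_axes=0)
# losses = per_example(net, x, y)
```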
Thank you very much!