I recently started using JAX and implemented a simple linear-regression example using a 3-layer MLP. Following "Performance Considerations", I tried various methods and found that a functional training loop (jax.jit with nnx.split/nnx.merge) gave the best performance: 114 seconds for 1M iterations on an RTX 3080 GPU.
Then I tried nnx.fori_loop and got an interesting result: it is much faster than running a Python for loop over a JAX-compiled function (31 seconds for 1M iterations on the same GPU).
Question:
Is it common practice to use nnx.fori_loop for faster training, or might there be other factors contributing to this performance boost?
Code
Using Python for loop:
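The original code block was not included above; the following is a minimal sketch of what a jax.jit + nnx.split/nnx.merge training loop could look like for a 3-layer MLP regression. The model definition, layer sizes, learning rate, and toy data are illustrative assumptions, not the poster's actual code.

```python
import jax
import jax.numpy as jnp
from flax import nnx

# Illustrative 3-layer MLP; names and sizes are assumptions, not the original code.
class MLP(nnx.Module):
    def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
        self.l1 = nnx.Linear(din, dhidden, rngs=rngs)
        self.l2 = nnx.Linear(dhidden, dhidden, rngs=rngs)
        self.l3 = nnx.Linear(dhidden, dout, rngs=rngs)

    def __call__(self, x):
        x = nnx.relu(self.l1(x))
        x = nnx.relu(self.l2(x))
        return self.l3(x)

# Toy regression data and learning rate (assumed).
x = jnp.linspace(-1.0, 1.0, 256).reshape(-1, 1)
y = 3.0 * x + 1.0
lr = 1e-2

model = MLP(1, 32, 1, rngs=nnx.Rngs(0))

# Split into a static graph definition and a pytree of parameters,
# so the jitted step only takes and returns pytrees.
graphdef, state = nnx.split(model)

@jax.jit
def train_step(state):
    def loss_fn(state):
        model = nnx.merge(graphdef, state)  # rebuild the module inside the trace
        pred = model(x)
        return jnp.mean((pred - y) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(state)
    # Plain SGD update applied to the state pytree.
    state = jax.tree.map(lambda p, g: p - lr * g, state, grads)
    return state, loss

# Python-level loop: every iteration dispatches one jitted step from the host.
for _ in range(1_000_000):
    state, loss = train_step(state)

model = nnx.merge(graphdef, state)
```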
Using nnx.fori_loop:
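Again a hedged sketch rather than the original code, reusing the illustrative MLP class from the previous block; here the Python loop is replaced by nnx.fori_loop, so the whole training loop is expressed as a single device-side loop instead of one jitted dispatch per step.

```python
import jax
import jax.numpy as jnp
from flax import nnx

# Toy data and learning rate (assumed), matching the previous sketch.
x = jnp.linspace(-1.0, 1.0, 256).reshape(-1, 1)
y = 3.0 * x + 1.0
lr = 1e-2

# MLP class as defined in the previous sketch.
model = MLP(1, 32, 1, rngs=nnx.Rngs(0))

def body_fun(i, model):
    def loss_fn(model):
        pred = model(x)
        return jnp.mean((pred - y) ** 2)

    grads = nnx.grad(loss_fn)(model)          # gradients w.r.t. the nnx.Param state
    params = nnx.state(model, nnx.Param)
    new_params = jax.tree.map(lambda p, g: p - lr * g, params, grads)
    nnx.update(model, new_params)             # write updated parameters back into the module
    return model

# The entire training loop is carried out by a single lifted fori_loop.
model = nnx.fori_loop(0, 1_000_000, body_fun, model)
```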