Use case of multi-GPU sharding, nnx.jit, save/load and performance #4575
jecampagne asked this question in Q&A · Unanswered · 0 replies
Hello,
I have a toy example of a model using Yang Song's denoising score-matching sampling, which I am writing under different conditions to experiment with the Flax NNX & Orbax libraries. To discuss it and ask a few questions, I've set up this notebook on Colab just for reading the code.
My questions:

- What about the necessity (or not) of the `key_scorenet` key?
- In `train_step`, where I am not sure I have done things correctly, I observe that the sharding of `perturbed_x` is the same as that of the data `x`, but `random_t`, which is the second argument of the model call, looks different: a horizontal strip with "GPU 0" separated from "GPU 1". I wonder whether this is correct (see the first sketch below).
- `loss` is a scalar, so I do not know whether it is the mean of the per-sample losses over all the devices (see the second sketch below).
- Why do I have to use the two `state_restored` statements successively to get the model loaded (see the third sketch below)?
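For context, here are a few minimal sketches of what I mean; they are simplified stand-ins, not the actual notebook code. First, how I place the batch on a 1-D "data" mesh and inspect the layout of each argument (the shapes, the `data` axis name and the dummy arrays are only illustrative):

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative shapes only; the real batch comes from my data loader.
mesh = Mesh(jax.devices(), axis_names=("data",))
x = jax.device_put(jnp.ones((64, 2)), NamedSharding(mesh, P("data", None)))
random_t = jax.device_put(jnp.ones((64,)), NamedSharding(mesh, P("data")))

# Prints an ASCII picture of which slice of each array lives on which GPU.
jax.debug.visualize_array_sharding(x)
jax.debug.visualize_array_sharding(random_t)
```

If I read the visualization correctly, a 1-D array such as `random_t` is always drawn as a single horizontal strip (GPU 0 next to GPU 1), so maybe what I see is just the rendering of a 1-D array rather than a wrong sharding, but I would like confirmation.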
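Second, the overall shape of my `train_step` and loss; the weighting details are simplified placeholders and the `optimizer` is an `nnx.Optimizer` as in the Flax NNX basics guide, which may differ from my notebook:

```python
import jax.numpy as jnp
from flax import nnx

@nnx.jit
def train_step(model, optimizer, perturbed_x, random_t, noise, std):
    def loss_fn(model):
        score = model(perturbed_x, random_t)  # ScoreNet forward pass
        # Denoising score-matching loss, averaged over the full (sharded)
        # batch, so the returned value is a single scalar.
        per_sample = jnp.sum((score * std[:, None] + noise) ** 2, axis=-1)
        return jnp.mean(per_sample)

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)  # nnx.Optimizer update, as in the NNX basics guide
    return loss
```

Since `jnp.mean` is taken over the whole globally sharded batch inside the jitted step, I would expect the returned scalar to already be the average over all devices, but I am not sure.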
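Third, the save/restore pattern I tried to follow (after the Flax NNX checkpointing guide); the tiny `nnx.Linear` and the `create_model` helper below are placeholders for my ScoreNet constructor. I would have expected a single restore like this to be enough, hence my question about the two `state_restored` statements:

```python
import orbax.checkpoint as ocp
from flax import nnx

# Tiny stand-in for the ScoreNet used in the notebook.
def create_model(rngs: nnx.Rngs) -> nnx.Module:
    return nnx.Linear(2, 2, rngs=rngs)

model = create_model(nnx.Rngs(0))

ckpt_dir = ocp.test_utils.erase_and_create_empty("/tmp/nnx-ckpt")
checkpointer = ocp.StandardCheckpointer()

# Save: keep only the state (parameters, etc.), not the graph definition.
_, state = nnx.split(model)
checkpointer.save(ckpt_dir / "state", state)

# Restore: rebuild an abstract (shape-only) model with the same structure,
# restore the state into it, then merge it back into a live model.
abstract_model = nnx.eval_shape(lambda: create_model(nnx.Rngs(0)))
graphdef, abstract_state = nnx.split(abstract_model)
state_restored = checkpointer.restore(ckpt_dir / "state", abstract_state)
model_loaded = nnx.merge(graphdef, state_restored)
```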
After that, the sampling looks ok.
Thanks for your attention.