I'm using PyTorch Lightning with DeepSpeed, but I'm having trouble training my model with ZeRO stage 3, even though stage 2 works.
The model is part of a very large codebase and difficult to share, but it is a VAE with an adversarial loss, so nothing too out of the ordinary. I did have to implement a small workaround to get the adversarial loss to work: basically, detaching the graph at the right spot and adding a bit of redundancy so that a single backward pass computes the right gradients.
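To make the kind of workaround described above concrete, here is one way a single-backward VAE/GAN loss could be written. This is a hypothetical reconstruction, not the author's code: the function name, the hinge-style discriminator loss, the plain MSE reconstruction term and the use of `torch.func.functional_call` (PyTorch 2.0 or later) are all assumptions. The generator's adversarial term runs the discriminator with detached copies of its parameters, and the discriminator term sees a detached reconstruction, so summing everything and calling `backward()` once should give each network only its own gradients. The second pass through the discriminator is the "redundancy".

```python
import torch
import torch.nn as nn
from torch.func import functional_call  # available in PyTorch >= 2.0


def combined_vae_gan_loss(autoencoder: nn.Module, disc: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical single-backward loss: generator and discriminator terms are
    summed, but each term only produces gradients for its own network."""
    recon = autoencoder(x)
    rec_loss = (recon - x).pow(2).mean()  # stand-in for the full VAE loss (KL term omitted)

    # Generator term: D sees the graph-connected reconstruction, but with
    # detached copies of its own parameters, so no gradient reaches D here.
    frozen_disc_params = {name: p.detach() for name, p in disc.named_parameters()}
    g_adv = -functional_call(disc, frozen_disc_params, (recon,)).mean()

    # Discriminator term (hinge loss): the reconstruction is detached, so no
    # gradient reaches the autoencoder here. This second pass through D is
    # the redundant computation.
    d_loss = torch.relu(1.0 - disc(x)).mean() + torch.relu(1.0 + disc(recon.detach())).mean()

    return rec_loss + g_adv + d_loss


# Toy stand-ins for the real VAE and discriminator (assumed shapes).
torch.manual_seed(0)
autoencoder = nn.Linear(4, 4)
disc = nn.Linear(4, 1)
loss = combined_vae_gan_loss(autoencoder, disc, torch.randn(8, 4))
loss.backward()  # one backward pass fills .grad for both networks
print(loss.item())
```

The same effect is often achieved with two optimizers and two backward passes; folding both into one pass is what makes the detaching necessary.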
Everything works great with stage 2, but when I run in stage 3 I get the following error:
RuntimeError: The size of tensor a (0) must match the size of tensor b (3) at non-singleton dimension 4
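The message comes from PyTorch's broadcasting check rather than from DeepSpeed itself: one operand contributes size 0 where the other has size 3, at dimension 4 of a 5-D broadcast result. The sketch below (the helper name is hypothetical) shows that a 1-D, zero-element tensor combined with a 5-D tensor whose last dimension is 3 should raise the same message, so it may be worth looking for an operand that unexpectedly has no elements. One possibility to check under ZeRO stage 3, stated as an assumption rather than a diagnosis, is a partitioned parameter that is read directly (outside its module's forward) while it is not gathered, since such parameters can appear locally as empty tensors.

```python
import torch


def broadcast_error(shape_a, shape_b):
    """Hypothetical helper: return PyTorch's broadcasting error message for two
    shapes, or None if the shapes broadcast fine."""
    try:
        torch.empty(shape_a) + torch.empty(shape_b)
    except RuntimeError as err:
        return str(err)
    return None


# Tensor a: 1-D with zero elements. Tensor b: 5-D with last dimension 3.
print(broadcast_error((0,), (2, 4, 8, 8, 3)))
```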
Are there flags one can turn on to get more informative debug information (for example which tensors have a mismatch)?
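Independent of any DeepSpeed or Lightning setting, a generic PyTorch way to see which module was running when a shape error is raised is to register forward pre-hooks that record each module's input shapes. The helper below is a hypothetical sketch built only on `nn.Module.register_forward_pre_hook`; after an exception, the last entries in the log show which modules were entered most recently. It does not record parameter shapes, and it cannot localize an error raised in code outside any module (for example, in the loss computation).

```python
import torch
import torch.nn as nn


def attach_input_shape_logger(model: nn.Module, log: list) -> list:
    """Hypothetical debugging helper: append (module name, input shapes) to `log`
    each time a module's forward is entered. Returns the hook handles."""
    handles = []
    for name, module in model.named_modules():
        # Bind the current name as a default argument so each hook keeps its own.
        def pre_hook(mod, args, name=name or "<root>"):
            shapes = [tuple(a.shape) for a in args if isinstance(a, torch.Tensor)]
            log.append((name, shapes))
        handles.append(module.register_forward_pre_hook(pre_hook))
    return handles


# Toy model standing in for the real one.
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU())
log = []
handles = attach_input_shape_logger(model, log)
try:
    model(torch.randn(2, 4))
finally:
    for h in handles:
        h.remove()  # remove the hooks once done debugging
for entry in log:
    print(entry)
```

In a real run, the failing training step would go inside the same `try`/`finally`, and the tail of `log` would be printed after the exception.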
Also, does stage 3 make assumptions about the shape of the loss? Is the user supposed to avoid aggregation over the batch, for example?