Replies: 1 comment · 2 replies

Just double checking: you're not changing the batch size within the same training loop, right?
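For context on why that matters: `jax.jit` specializes on input shapes, so every distinct batch size triggers a fresh trace and compile. A minimal sketch to illustrate (placeholder function, not code from this thread):

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def f(x):
    global trace_count
    trace_count += 1  # Python side effects run only while JAX traces the function
    return (x ** 2).sum()

f(jnp.ones((2, 4)))   # batch size 2: traces and compiles
f(jnp.ones((2, 4)))   # same shape: served from the cache, no retrace
f(jnp.ones((64, 4)))  # new batch size: traces and compiles again
print(trace_count)    # prints 2
```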
Hi!
When working on a current project, I found that the compile time of my `train_step` function increased drastically when I increased the batch size: it took about 20 s with a batch size of 2, but went up to about 4000 s when I bumped the batch size to 64. Here is the `train_step` function I'm using (it may be a little confusing because I'm working on an IQA task with a custom model that has a state):

As a note, I have a different function to calculate the metrics during validation, and that function doesn't show the same behavior, so I thought the issue might be related to the gradient computation, but I don't really know if that makes sense.
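Schematically, the step does something along these lines; this is only a simplified placeholder sketch (tiny MLP, MSE loss, plain SGD), not the actual function, which uses a custom IQA model:

```python
import jax
import jax.numpy as jnp

def apply_model(params, state, x):
    # Stand-in for a custom stateful model: returns predictions and a new state.
    h = jnp.tanh(x @ params["w1"] + params["b1"])
    preds = h @ params["w2"] + params["b2"]
    return preds, state  # a real model would also update running statistics here

@jax.jit
def train_step(params, state, x, y, lr=1e-3):
    def loss_fn(p):
        preds, new_state = apply_model(p, state, x)
        return jnp.mean((preds - y) ** 2), new_state

    (loss, new_state), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)
    # Plain SGD update just to keep the sketch self-contained.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, new_state, loss
```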
I was under the assumption that changing the batch size shouldn't have this big an influence on compilation time, and since I couldn't narrow down the problem, I tried to replicate it in a very simple MNIST classifier example in Colab (here).
What I found was basically the same: the compilation time goes up with the batch size, as you can see in this quick wandb dashboard I set up for the experiment: https://wandb.ai/jorgvt/JaX_Compile?workspace=user-jorgvt
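For reference, one way to isolate compilation time per batch size is to time JAX's ahead-of-time `lower(...).compile()` on the jitted step. This is only a sketch, reusing the placeholder `train_step` from above and assuming a recent JAX version; the Colab may measure it differently:

```python
import time
import jax
import jax.numpy as jnp

def make_dummy_batch(batch_size):
    # Placeholder MNIST-like shapes: 784 input features, 10 outputs.
    key = jax.random.PRNGKey(0)
    x = jax.random.normal(key, (batch_size, 28 * 28))
    y = jax.random.normal(key, (batch_size, 10))
    return x, y

# Placeholder parameters matching the sketch above.
params = {
    "w1": jnp.zeros((28 * 28, 128)), "b1": jnp.zeros((128,)),
    "w2": jnp.zeros((128, 10)), "b2": jnp.zeros((10,)),
}
state = {}

for batch_size in (2, 8, 32, 64):
    x, y = make_dummy_batch(batch_size)
    start = time.perf_counter()
    # lower() traces for these concrete shapes; compile() runs XLA compilation.
    train_step.lower(params, state, x, y).compile()
    print(f"batch_size={batch_size}: compiled in {time.perf_counter() - start:.1f} s")
```

Because nothing is executed, the measured time covers tracing plus XLA compilation only.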
I'd be more than willing to share more information with anyone who can shed some light on this!