Catastrophic performance loss in 1st epoch #1094
Replies: 8 comments
-
So interestingly, doing the exact same thing but with a lower warmup leads to better, though still very strange, performance:
-
@alexisdrakopoulos I don't think this is a bug of any sort. FYI, warmup is in steps, so if you have a small dataset that could be a lot of epochs; it means your LR is way too high and things collapse once the LR hits the breaking point. 1e-4 would already be a high fine-tuning LR for AdamW, and effectively even higher for LION, since I vaguely recall the recommendation for that optimizer being ~1/3 to 1/10 of the Adam equivalent?
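As a rough sanity check, it helps to convert the warmup into epochs for your dataset and batch size. The numbers below are placeholders pulled from this thread, not your actual run:

```python
# Rough sanity check: open_clip's warmup is counted in optimizer steps, not epochs.
# Placeholder numbers -- plug in your own run's values.
num_samples = 3_000_000   # image/text pairs (from the thread)
batch_size = 9_000        # per the original post
warmup_steps = 10_000     # hypothetical --warmup value

steps_per_epoch = num_samples // batch_size        # ~333 steps per epoch here
warmup_epochs = warmup_steps / steps_per_epoch     # ~30 epochs spent warming up

print(f"{steps_per_epoch} steps/epoch; warmup spans {warmup_epochs:.1f} epochs")
```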
-
Moving to discussions...
-
I have around 3 million image/text pairs. The warmup was indeed ridiculous, but I have tried many different settings with lower LRs, warmups, etc. I can't figure out why my performance massively improves over the first few epochs and then slowly degrades over the subsequent ones. It's wasting a lot of money and compute; I'm now just using the model from epoch 2 or 3, which is about 30 minutes into the fine-tuning process. Do you have any advice on how to explore this type of behavior and fix it? I am frustrated because I am comparing it to DinoV2 in terms of image-to-image retrieval, and DinoV2 gets these results out of the box: even ResNet fares well: A little frustrating.
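Roughly, the retrieval benchmark is nearest-neighbour search over embeddings; a minimal Recall@K sketch, assuming query/gallery embeddings and integer labels are already computed (tensor names are illustrative, not the actual evaluation code):

```python
import torch
import torch.nn.functional as F

def recall_at_k(query_emb, gallery_emb, query_labels, gallery_labels, k=1):
    """Fraction of queries whose k nearest gallery items contain a matching label."""
    q = F.normalize(query_emb, dim=-1)
    g = F.normalize(gallery_emb, dim=-1)
    sims = q @ g.T                        # cosine similarities, shape [Nq, Ng]
    topk = sims.topk(k, dim=-1).indices   # [Nq, k] indices of nearest gallery items
    hits = (gallery_labels[topk] == query_labels.unsqueeze(-1)).any(dim=-1)
    return hits.float().mean().item()
```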
-
@alexisdrakopoulos can you share any information about the dataset? I think there are two likely causes here (2 being most likely):
-
Hi @jn2clark, I really appreciate you answering! The data shouldn't have too many duplicates, but the webdataset files I use are not pre-shuffled; I use the open_clip shuffling function, which I believe is two-stage. Each webdataset file has 3,000 image/text pairs and I have just under 1,000 of those. The dataset consists of images of antiquities, largely from museums and other places on the internet. Each image is 224 pixels on its longest axis and has one detailed piece of text. Image/text pairs are largely unique; there are some duplicates, but I don't think there are more than a few close ones. I guess I can crunch the data again by computing phashes of the images. You are right that not every text is unique, so I am now implementing an augmentation step where, every time the sampler fetches a label, it gets one random piece of text out of five: each image has five different valid captions, written slightly differently. I need to make sure this is actually running though.
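The augmentation itself is simple; a sketch of what I mean, assuming each sample carries its five captions as a JSON list under the "json" key (adjust the keys to match your shard layout):

```python
import json
import random

def pick_random_caption(sample):
    """Swap the text field for one of the stored captions, chosen at random.

    Assumes each webdataset sample stores its captions as a JSON list, e.g.
    {"captions": ["...", "...", "...", "...", "..."]} under the "json" key.
    """
    meta = sample["json"]
    if isinstance(meta, (bytes, str)):
        meta = json.loads(meta)
    sample["txt"] = random.choice(meta["captions"])
    return sample

# Hook it into the webdataset pipeline before tokenization, e.g.:
# dataset = dataset.map(pick_random_caption)
```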
-
I haven't been able to try your idea out yet since the text tower isn't trivial to lock; I saw there's an open PR for that, so I'll try to implement it. Here are my current evaluation metrics per epoch after adding some more image augmentations and using a learning rate of 1e-6, weight decay of 1e-7, and a short warmup: and here they are compared to the pretrained models:
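In the meantime I may just freeze the text parameters by hand; a rough sketch, with attribute names taken from open_clip's plain CLIP class (the custom-text variants keep the whole tower under `model.text` instead, so adjust accordingly):

```python
import torch

def lock_text_tower(model):
    """Freeze the text tower of an open_clip CLIP model (rough sketch)."""
    for module in (model.token_embedding, model.transformer, model.ln_final):
        for p in module.parameters():
            p.requires_grad = False
    model.positional_embedding.requires_grad = False
    if isinstance(model.text_projection, torch.nn.Parameter):
        model.text_projection.requires_grad = False
    else:  # some variants use an nn.Linear projection instead
        for p in model.text_projection.parameters():
            p.requires_grad = False
```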
-
@jn2clark I tried your second suggestion and sadly it didn't improve things. I'm going to go back to the drawing board and try cleaning up my dataset some more. What really confuses me is why the model's performance jumps from atrocious to 30% Recall@1 after it has seen only about 20% of the dataset, and then slowly gets worse and worse over the following epochs. My learning rates aren't even high according to the literature. I might even try a super slow learning rate like 1e-7 with Adam, I suppose, though that seems like a strange approach.
-
Hi, I'm using the following:
with 224x224 image + detailed text pairs, around 2 million samples in total.
Here is my image retrieval benchmark with the pre-trained model:
and here are my results after 1 test epoch:
and here after 2 epochs:
Should I be using a different optimizer? Should I set WD to 0? Any advice is appreciated.
I am training on a single H200 node, which allows a batch size of around 9000 given the args above.
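For reference, these are the kinds of knobs I've been varying, with illustrative placeholder values (not my exact args above):

```python
# Illustrative placeholders for the hyperparameters being varied -- not the exact args.
hparams = dict(
    model="ViT-B-32",    # whichever pretrained CLIP checkpoint is fine-tuned
    batch_size=9_000,    # single H200 node
    optimizer="adamw",   # vs. Lion, which reportedly wants a noticeably lower LR
    lr=1e-6,             # conservative fine-tuning LR
    wd=1e-7,             # or 0, per the question above
    warmup=200,          # in optimizer steps, i.e. under one epoch at this batch size
    epochs=10,
)
```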