
Commit
Merge pull request #52 from jmisilo/51-retraining
51 retraining
jmisilo authored Nov 6, 2022
2 parents e7519d4 + 2b22c04 commit ef9922b
Showing 10 changed files with 14 additions and 11 deletions.
Binary file modified examples/23012796.jpg
Binary file added examples/36979.jpg
Binary file removed examples/3787801.jpg
Binary file removed examples/7757242158.jpg
Binary file added examples/89407459.jpg
Binary file added examples/loss_lr.jpg
17 changes: 10 additions & 7 deletions readme.md
@@ -10,17 +10,17 @@ The Model uses prefixes as in the [ClipCap](https://arxiv.org/abs/2111.09734) pa

The Model was trained with a frozen CLIP, a fully trained Mapping Module (6x Transformer Encoder Layers) and with partially frozen GPT-2 (the first and last 14 layers were trained).

-The training process was carried out using the [Kaggle](https://www.kaggle.com/) P100 GPU. Training time is about 2 x 11h (106 epochs) with a linearly changing learning rate (from 0 to 0.0001908) and batch size 64. Originally, the Model was supposed to be trained longer - which results in a non-standard LR. *I also tried a longer training session (150 epochs), but overtraining was noticeable.*
+The training process was carried out using the [Kaggle](https://www.kaggle.com/) P100 GPU. Training time - about 3 x 11h (150 epochs) with a linear learning rate warmup (max LR `3e-3`) and batch size 64.

-### Example results
-
-![Example1](./examples/23012796.jpg)
-
-![Example2](./examples/3787801.jpg)
-
-![Example3](./examples/7757242158.jpg)
-
-As I said, the goal was to test the Model's ability to recognize the situation. In the next phase of the experiments, I will try to improve the Model process and parameters to achieve better captions with the same dataset.
+#### Loss and Learning Rate during training
+
+![LOSSxLR](./examples/loss_lr.jpg)
+
+### Example results
+
+![Example1](./examples/23012796.jpg)
+![Example2](./examples/36979.jpg)
+![Example3](./examples/89407459.jpg)

### Usage

@@ -36,7 +36,10 @@ Create environment and install requirements:

```bash
python -m venv venv
+# for windows
+.\venv\Scripts\activate
+# for linux/mac
source venv/bin/activate

pip install -r requirements.txt
```
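The linear learning rate warmup described in the updated readme (LR climbing from 0 to a maximum of `3e-3`) can be sketched in plain Python. The `warmup_steps` value below is a hypothetical illustration, not taken from the repository:

```python
def linear_warmup_lr(step, max_lr=3e-3, warmup_steps=1000):
    """Linear warmup: LR rises from 0 to max_lr over warmup_steps, then holds.

    warmup_steps is an assumed value for illustration; the actual schedule
    is defined in the repository's training code.
    """
    return max_lr * min(step / warmup_steps, 1.0)

print(linear_warmup_lr(0))     # 0.0 at the start of training
print(linear_warmup_lr(500))   # halfway through warmup: half of max LR
print(linear_warmup_lr(5000))  # past warmup: capped at max_lr
```

In PyTorch, the same shape is typically achieved by passing a lambda like this to `torch.optim.lr_scheduler.LambdaLR`.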
2 changes: 1 addition & 1 deletion src/training.py
@@ -90,7 +90,7 @@
start_epoch, total_train_loss, total_valid_loss = (
    load_ckp(ckp_path, model, optimizer, scheduler, scaler, device)
    if os.path.isfile(ckp_path) else
-    0, [], []
+    (0, [], [])
)

# build train model process with experiment tracking from wandb
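The one-character-pair fix above matters because Python's conditional expression binds tighter than the comma. A minimal sketch (the `load_state` helper and its dummy `load` default are hypothetical, standing in for `load_ckp`):

```python
def load_state(ckp_exists, load=lambda: (5, [1.0], [2.0])):
    # Fixed form: parentheses make the else-branch a single 3-tuple.
    start_epoch, train_loss, valid_loss = (
        load()
        if ckp_exists else
        (0, [], [])
    )
    return start_epoch, train_loss, valid_loss

# Without the parentheses, `x if cond else 0, [], []` parses as a 3-tuple
# whose FIRST element is the conditional -- not a choice between two
# 3-tuples -- so the loaded state ends up nested in position 0:
buggy = (load_state(True) if True else 0, [], [])
print(buggy)  # ((5, [1.0], [2.0]), [], [])
```

With the original spelling, unpacking silently assigned the whole loaded tuple to `start_epoch` and empty lists to the loss histories.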
4 changes: 2 additions & 2 deletions src/utils/config.py
@@ -15,8 +15,8 @@ class Config:
num_workers: int = 2
train_size: int = 0.84
val_size: int = 0.13
-    epochs: int = 200
-    lr: int = 6e-3
+    epochs: int = 150
+    lr: int = 3e-3
k: float = 0.33
batch_size_exp: int = 6
ep_len: int = 4
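Reading the updated values back as a self-contained sketch (the interpretation of `batch_size_exp` as a power-of-two exponent is an assumption, inferred from the readme's batch size 64 = 2**6; the `lr` annotation is written as `float` here, though the repo annotates it `int` even though `3e-3` is a float):

```python
from dataclasses import dataclass

@dataclass
class Config:
    # Mirrors the updated values in src/utils/config.py.
    epochs: int = 150
    lr: float = 3e-3
    batch_size_exp: int = 6  # assumed: batch size = 2 ** batch_size_exp

cfg = Config()
print(cfg.epochs, cfg.lr, 2 ** cfg.batch_size_exp)  # 150 0.003 64
```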
2 changes: 1 addition & 1 deletion src/utils/load_ckp.py
@@ -29,4 +29,4 @@ def download_weights(checkpoint_fpath):
Downloads weights from Google Drive.
'''

-    gdown.download('https://drive.google.com/uc?id=1lEufQVOETFEIhPdFDYaez31uroq_5Lby', checkpoint_fpath, quiet=False)
+    gdown.download('https://drive.google.com/uc?id=10ieSMMJzE9EeiPIF3CMzeT4timiQTjHV', checkpoint_fpath, quiet=False)
