Started with Batch Size = 2, Seq Length = 1024.
Running on an A4000 (I know it's pretty old, but that's all I have access to :( )
Training for 1 epoch
- Without any Optimizations - Time to train: 48.72043180465698 seconds
- torch.set_float32_matmul_precision("high") (allows TF32 matmuls) - Time to train: 29.83514142036438 seconds
- torch.autocast to bfloat16 + (2) - Time to train: 24.927905797958374 seconds
- With torch.compile + (3) - Time to train: 32.85438013076782 seconds (a minimal setup sketch for these settings follows this list)
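A minimal sketch of how the knobs above are set, assuming a hypothetical `MyGPT` model class and `train_loader` (placeholders, not the actual training script):

```python
import torch

# (2) Allow TF32 for float32 matmuls on Ampere+ GPUs (the A4000 included)
torch.set_float32_matmul_precision("high")

model = MyGPT().cuda()        # hypothetical model class
model = torch.compile(model)  # (4) compile the model into an optimized graph

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

for x, y in train_loader:     # hypothetical DataLoader yielding (input, target) batches
    x, y = x.cuda(), y.cuda()
    optimizer.zero_grad(set_to_none=True)
    # (3) Run the forward pass in bfloat16 via autocast
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        logits, loss = model(x, y)
    loss.backward()
    optimizer.step()
```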
For 5 epochs,
- With torch.compile - Time to train: 81.19576597213745 seconds
- With torch.autocast to bfloat16 and float32 matmul precision set to high - Time to train: 125.25839257240295 seconds
- With torch.compile and Flash Attention - Time to train: 69.03489518165588 seconds
- Changed the vocab size to a "nicer", GPU-friendly number + (3) - Time to train: 65.71632099151611 seconds (see the sketch after this list)
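A sketch of those two changes, assuming the attention switches to PyTorch's fused SDPA kernel for Flash Attention and that the "nice" vocab size means rounding up to a multiple of 64 (both details are my assumptions):

```python
import torch.nn.functional as F

# Flash Attention: replace a manual softmax(QK^T / sqrt(d)) @ V with the fused kernel.
# q, k, v have shape (batch, n_heads, seq_len, head_dim).
def causal_attention(q, k, v):
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

# "Nice" vocab size: round the vocab up to a multiple of 64 (e.g. 50257 -> 50304)
# so the embedding / lm_head matmuls hit efficient kernel tile sizes.
def pad_vocab(vocab_size, multiple=64):
    return ((vocab_size + multiple - 1) // multiple) * multiple
```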
After adding LR schedulers, weight decay, and gradient norm clipping, training for 5 epochs:
- Without Fused AdamW - Time to train: 62.56659150123596 seconds
- With Fused AdamW - Time to train: 57.591166973114014 seconds, Avg Loss - 5.64244685606523 (setup sketched below)
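Roughly what that optimizer setup looks like; the cosine schedule and the specific hyperparameter values are my assumptions, not values from the run:

```python
import torch

optimizer = torch.optim.AdamW(
    model.parameters(),          # model as in the earlier sketch
    lr=3e-4, weight_decay=0.1, betas=(0.9, 0.95),
    fused=True,                  # fused AdamW kernel (set False for the unfused run)
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_steps)  # max_steps: total optimizer steps

for x, y in train_loader:
    x, y = x.cuda(), y.cuda()
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        _, loss = model(x, y)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # gradient norm clipping
    optimizer.step()
    scheduler.step()
```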
Added QK-norm. Training for 5 epochs, Time - 69.68068528175354 seconds. Avg Loss - 5.588010857321999
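QK-norm normalizes the query and key vectors before the attention scores are computed, which keeps the attention logits bounded. A minimal sketch using plain L2 normalization (many implementations use an RMSNorm with a learnable scale instead; the exact variant here is my assumption):

```python
import torch.nn.functional as F

# q, k, v: (batch, n_heads, seq_len, head_dim)
def qk_norm_attention(q, k, v):
    q = F.normalize(q, dim=-1)  # unit-norm each query vector
    k = F.normalize(k, dim=-1)  # unit-norm each key vector
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)
```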
Then added RoPE embeddings (YaRN is available, but the sequence length is fixed for now). Training for 5 epochs, Time - 70.93886852264404 seconds. Avg Loss - 4.063881944887566
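RoPE rotates each (even, odd) pair of query/key dimensions by a position-dependent angle, so relative position falls out of the dot product. A fixed-length sketch without the YaRN scaling (the interleaved-pair layout is my assumption; some implementations rotate half-dimensions instead):

```python
import torch

def rope_cache(seq_len, head_dim, base=10000.0, device="cuda"):
    # One rotation frequency per pair of dimensions
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
    t = torch.arange(seq_len, device=device).float()
    freqs = torch.outer(t, inv_freq)            # (seq_len, head_dim / 2)
    return freqs.cos(), freqs.sin()

def apply_rope(x, cos, sin):
    # x: (batch, n_heads, seq_len, head_dim); rotate each (even, odd) pair by its angle
    x1, x2 = x[..., ::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., ::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out
```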
Added ReLU^2 activation in place of GELU. Training for 5 epochs, Time - 68.68175554275513 seconds. Avg Loss - 3.8532666105212585
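ReLU^2 (squared ReLU) just squares the ReLU output, making it a drop-in replacement for GELU in the MLP block:

```python
import torch
import torch.nn as nn

class ReLUSquared(nn.Module):
    def forward(self, x):
        return torch.relu(x).square()

# e.g. in the feed-forward block (d = model width), replacing nn.GELU():
# mlp = nn.Sequential(nn.Linear(d, 4 * d), ReLUSquared(), nn.Linear(4 * d, d))
```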