|
552 | 552 | "source": [
|
553 | 553 | "from previous_chapters import evaluate_model, generate_and_print_sample\n",
|
554 | 554 | "\n",
|
| 555 | + "BOOK_VERSION = True\n",
| 556 | + "\n",
555 | 557 | "\n",
|
556 | 558 | "def train_model(model, train_loader, val_loader, optimizer, device,\n",
|
557 | 559 | " n_epochs, eval_freq, eval_iter, start_context, tokenizer,\n",
|
|
595 | 597 | " loss.backward()\n",
|
596 | 598 | "\n",
|
597 | 599 | " # Apply gradient clipping after the warmup phase to avoid exploding gradients\n",
|
598 | | - "    if global_step > warmup_steps:\n",
599 | | - "        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
600 | | - "    \n",
| 600 | + "\n",
| 601 | + "    if BOOK_VERSION:\n",
| 602 | + "        if global_step > warmup_steps:\n",
| 603 | + "            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
| 604 | + "    else:\n",
| 605 | + "        if global_step >= warmup_steps:  # the book originally used global_step > warmup_steps, which led to a skipped clipping step after warmup\n",
| 606 | + "            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
| 607 | + "    \n",
601 | 608 | " optimizer.step()\n",
|
602 | 609 | " tokens_seen += input_batch.numel()\n",
|
603 | 610 | "\n",
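
A quick note on what the BOOK_VERSION switch changes: the two conditions differ by exactly one step at the warmup boundary. Assuming global_step counts batches from 0 and steps 0 through warmup_steps-1 use the warmup learning rate, "global_step > warmup_steps" leaves the first post-warmup step unclipped, while "global_step >= warmup_steps" clips it. A minimal sketch (toy parameter and a hypothetical warmup_steps value, not the notebook's actual training loop):

import torch

warmup_steps = 3
param = torch.nn.Parameter(torch.randn(10))

for global_step in range(6):
    loss = (param ** 2).sum()
    loss.backward()

    clips_book = global_step > warmup_steps    # first clip at step 4
    clips_fixed = global_step >= warmup_steps  # first clip at step 3, the first post-warmup step
    if clips_fixed:
        torch.nn.utils.clip_grad_norm_([param], max_norm=1.0)

    print(global_step, clips_book, clips_fixed)
    param.grad = None  # reset gradients, as optimizer.zero_grad() would
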
|
|
691 | 698 | "model = GPTModel(GPT_CONFIG_124M)\n",
|
692 | 699 | "model.to(device)\n",
|
693 | 700 | "\n",
|
694 | | - "peak_lr = 5e-4\n",
695 | | - "optimizer = torch.optim.AdamW(model.parameters(), weight_decay=0.1)\n",
| 701 | + "peak_lr = 0.001  # this was originally set to 5e-4 in the book by mistake\n",
| 702 | + "optimizer = torch.optim.AdamW(model.parameters(), lr=peak_lr, weight_decay=0.1)  # the book accidentally omitted the lr assignment\n",
696 | 703 | "tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
|
697 | 704 | "\n",
|
698 | 705 | "n_epochs = 15\n",
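
For context on the optimizer fix: torch.optim.AdamW defaults to lr=0.001 when no learning rate is passed, so with the lr argument omitted the optimizer ran at 0.001 rather than at the intended 5e-4; setting peak_lr = 0.001 and passing it explicitly keeps the effective behavior while making the value visible. A small sketch of that default (toy linear layer standing in for GPTModel, not the notebook's setup code):

import torch

model = torch.nn.Linear(4, 4)  # stand-in for the real model

opt_default = torch.optim.AdamW(model.parameters(), weight_decay=0.1)
opt_explicit = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.1)

print(opt_default.param_groups[0]["lr"])   # 0.001, AdamW's built-in default
print(opt_explicit.param_groups[0]["lr"])  # 0.001, same value, now explicit

If the training loop derives its peak learning rate from optimizer.param_groups, as a warmup plus cosine-decay schedule typically does, both setups would therefore peak at 0.001 rather than at the 5e-4 the original cell suggested.
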
|
|
817 | 824 | "name": "python",
|
818 | 825 | "nbconvert_exporter": "python",
|
819 | 826 | "pygments_lexer": "ipython3",
|
820 | | - "version": "3.10.6"
| 827 | + "version": "3.11.4"
821 | 828 | }
|
822 | 829 | },
|
823 | 830 | "nbformat": 4,
|
|