|
103 | 103 | " return self.input_ids[idx], self.target_ids[idx]\n",
|
104 | 104 | "\n",
|
105 | 105 | "\n",
|
106 |
| - "def create_dataloader_v1(txt, batch_size=4, max_length=256, \n", |
107 |
| - " stride=128, shuffle=True, drop_last=True, num_workers=0):\n", |
| 106 | + "def create_dataloader_v1(txt, batch_size, max_length, stride, shuffle=True, drop_last=True, num_workers=0):\n", |
108 | 107 | " # Initialize the tokenizer\n",
|
109 | 108 | " tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
|
110 | 109 | "\n",
|
|
121 | 120 | "with open(\"the-verdict.txt\", \"r\", encoding=\"utf-8\") as f:\n",
|
122 | 121 | " raw_text = f.read()\n",
|
123 | 122 | "\n",
|
124 |
| - "tokenizer = tiktoken.get_encoding(\"gpt2\")\n", |
125 |
| - "encoded_text = tokenizer.encode(raw_text)\n", |
126 |
| - "\n", |
127 | 123 | "vocab_size = 50257\n",
|
128 | 124 | "output_dim = 256\n",
|
129 | 125 | "context_length = 1024\n",
|
|
132 | 128 | "token_embedding_layer = torch.nn.Embedding(vocab_size, output_dim)\n",
|
133 | 129 | "pos_embedding_layer = torch.nn.Embedding(context_length, output_dim)\n",
|
134 | 130 | "\n",
|
| 131 | + "batch_size = 8\n", |
135 | 132 | "max_length = 4\n",
|
136 |
| - "dataloader = create_dataloader_v1(raw_text, batch_size=8, max_length=max_length, stride=max_length)" |
| 133 | + "dataloader = create_dataloader_v1(raw_text, batch_size=batch_size, max_length=max_length, stride=max_length)" |
137 | 134 | ]
|
138 | 135 | },
|
139 | 136 | {
|
|
0 commit comments