103 | 103 | "        return self.input_ids[idx], self.target_ids[idx]\n",
104 | 104 | "\n",
105 | 105 | "\n",
106 |     | - "def create_dataloader_v1(txt, batch_size=4, max_length=256, \n",
107 |     | - "                         stride=128, shuffle=True, drop_last=True, num_workers=0):\n",
    | 106 | + "def create_dataloader_v1(txt, batch_size, max_length, stride,\n",
    | 107 | + "                         shuffle=True, drop_last=True, num_workers=0):\n",
108 | 108 | "    # Initialize the tokenizer\n",
109 | 109 | "    tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
110 | 110 | "\n",

121 | 121 | "with open(\"the-verdict.txt\", \"r\", encoding=\"utf-8\") as f:\n",
122 | 122 | "    raw_text = f.read()\n",
123 | 123 | "\n",
124 |     | - "tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
125 |     | - "encoded_text = tokenizer.encode(raw_text)\n",
126 |     | - "\n",
127 | 124 | "vocab_size = 50257\n",
128 | 125 | "output_dim = 256\n",
129 | 126 | "context_length = 1024\n",

132 | 129 | "token_embedding_layer = torch.nn.Embedding(vocab_size, output_dim)\n",
133 | 130 | "pos_embedding_layer = torch.nn.Embedding(context_length, output_dim)\n",
134 | 131 | "\n",
    | 132 | + "batch_size = 8\n",
135 | 133 | "max_length = 4\n",
136 |     | - "dataloader = create_dataloader_v1(raw_text, batch_size=8, max_length=max_length, stride=max_length)"
    | 134 | + "dataloader = create_dataloader_v1(\n",
    | 135 | + "    raw_text,\n",
    | 136 | + "    batch_size=batch_size,\n",
    | 137 | + "    max_length=max_length,\n",
    | 138 | + "    stride=max_length\n",
    | 139 | + ")"
137 | 140 | ]
138 | 141 | },
139 | 142 | {
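A minimal sketch of how the updated cell is typically used downstream: draw one batch from the dataloader and add token and positional embeddings to form the input embeddings. It assumes the names from the diff (dataloader, token_embedding_layer, pos_embedding_layer, max_length) and that each batch is a pair of token-ID tensors of shape (batch_size, max_length); it is an illustration, not part of the commit.

import torch

# Assumes dataloader, token_embedding_layer, pos_embedding_layer, and
# max_length from the cell above (an illustrative sketch, not the commit).
data_iter = iter(dataloader)
inputs, targets = next(data_iter)  # each of shape (batch_size, max_length) = (8, 4)

token_embeddings = token_embedding_layer(inputs)                # (8, 4, 256)
pos_embeddings = pos_embedding_layer(torch.arange(max_length))  # (4, 256)

# Broadcasting adds the same positional vectors to every sequence in the batch.
input_embeddings = token_embeddings + pos_embeddings            # (8, 4, 256)
print(input_embeddings.shape)  # torch.Size([8, 4, 256])

Passing batch_size, max_length, and stride explicitly, rather than relying on the old defaults of 4, 256, and 128, makes the sliding-window configuration visible at the call site; stride=max_length produces non-overlapping input chunks.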
189 | 192 | "name": "python",
190 | 193 | "nbconvert_exporter": "python",
191 | 194 | "pygments_lexer": "ipython3",
192 |     | - "version": "3.10.6"
    | 195 | + "version": "3.11.4"
193 | 196 | }
194 | 197 | },
195 | 198 | "nbformat": 4,