Skip to content

Commit 67e0680

Browse files
authored
Disable mask saving as weight in Llama 3 model (#604)
* Disable mask saving as weight * update pixi * update pixi
1 parent f143465 commit 67e0680

File tree

3 files changed

+8
-4
lines changed

3 files changed

+8
-4
lines changed

.github/workflows/basic-tests-pixi.yml

+1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ jobs:
4242
- name: List installed packages
4343
run: |
4444
pixi list --environment tests
45+
pixi run --environment tests pip install "huggingface-hub>=0.30.0,<1.0"
4546
4647
- name: Test Selected Python Scripts
4748
shell: pixi run --environment tests bash -e {0}

ch05/07_gpt_to_llama/standalone-llama32-mem-opt.ipynb

+4-1
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,10 @@
368368
" self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
369369
"\n",
370370
" # Reusuable utilities\n",
371-
" self.register_buffer(\"mask\", torch.triu(torch.ones(cfg[\"context_length\"], cfg[\"context_length\"]), diagonal=1).bool())\n",
371+
" self.register_buffer(\n",
372+
" \"mask\", torch.triu(torch.ones(cfg[\"context_length\"], cfg[\"context_length\"]), diagonal=1).bool(),\n",
373+
" persistent=False\n",
374+
" )\n",
372375
" cfg[\"rope_base\"] = rescale_theta(\n",
373376
" cfg[\"rope_base\"],\n",
374377
" cfg[\"orig_context_length\"],\n",

ch05/07_gpt_to_llama/standalone-llama32.ipynb

+3-3
Original file line numberDiff line numberDiff line change
@@ -266,10 +266,10 @@
266266
"\n",
267267
" # Fetch buffers using SharedBuffers\n",
268268
" mask, cos, sin = SharedBuffers.get_buffers(context_length, self.head_dim, rope_base, rope_config, dtype)\n",
269-
" self.register_buffer(\"mask\", mask)\n",
269+
" self.register_buffer(\"mask\", mask, persistent=False)\n",
270270
"\n",
271-
" self.register_buffer(\"cos\", cos)\n",
272-
" self.register_buffer(\"sin\", sin)\n",
271+
" self.register_buffer(\"cos\", cos, persistent=False)\n",
272+
" self.register_buffer(\"sin\", sin, persistent=False)\n",
273273
"\n",
274274
" def forward(self, x):\n",
275275
" b, num_tokens, d_in = x.shape\n",

0 commit comments

Comments (0)