Commit 4014bdd

Ch06 classifier function asserts (#703)
1 parent f5bc863 commit 4014bdd

File tree

1 file changed: 11 additions, 1 deletion

ch06/01_main-chapter-code/ch06.ipynb

Lines changed: 11 additions & 1 deletion
@@ -2353,7 +2353,17 @@
 "\n",
 " # Truncate sequences if they too long\n",
 " input_ids = input_ids[:min(max_length, supported_context_length)]\n",
-"\n",
+" assert max_length is not None, (\n",
+" \"max_length must be specified. If you want to use the full model context, \"\n",
+" \"pass max_length=model.pos_emb.weight.shape[0].\"\n",
+" )\n",
+" assert max_length <= supported_context_length, (\n",
+" f\"max_length ({max_length}) exceeds model's supported context length ({supported_context_length}).\"\n",
+" )\n",
+" # Alternatively, a more robust version is the following one, which handles the max_length=None case better\n",
+" # max_len = min(max_length,supported_context_length) if max_length else supported_context_length\n",
+" # input_ids = input_ids[:max_len]\n",
+" \n",
 " # Pad sequences to the longest sequence\n",
 " input_ids += [pad_token_id] * (max_length - len(input_ids))\n",
 " input_tensor = torch.tensor(input_ids, device=device).unsqueeze(0) # add batch dimension\n",
