Commit 44f56df (parent: a5cf213)

fix tokenizer call

Signed-off-by: Alexandros Koumparoulis <[email protected]>

1 file changed: +10 −10

nemo/collections/llm/gpt/data/hf_dataset.py

@@ -442,20 +442,20 @@ def preprocess_dataset(tokenizer, max_length, dataset, seed=42):
     print("Preprocessing dataset...")
     dataset = dataset.map(HellaSwagHFDataModule.process_doc)
 
-    def preprocess_batch(batch, tokenizer, max_length):
-        ans = tokenizer(
-            batch["text"],
-            max_length=max_length,
-            truncation=True,
+    def preprocess(example, tokenizer, max_length):
+        input_ids = tokenizer.text_to_ids(example["text"])
+        if max_length > 0:
+            input_ids = input_ids[:max_length]
+        return dict(
+            input_ids=input_ids,
+            labels=input_ids[1:] + [-100]
         )
-        ans["labels"] = [x[1:] + [-100] for x in ans["input_ids"]]
-        return ans
 
-    # Apply preprocessing to each batch of the dataset & and remove "conversations" and "text" fields.
-    _preprocessing_function = partial(preprocess_batch, max_length=max_length, tokenizer=tokenizer)
+    # Apply preprocessing to each example of the dataset & and remove "conversations" and "text" fields.
+    _preprocessing_function = partial(preprocess, max_length=max_length, tokenizer=tokenizer)
     dataset = dataset.map(
         _preprocessing_function,
-        batched=True,
+        batched=False,
     ).select_columns(["input_ids", "attention_mask", "labels"])
 
     # Shuffle dataset.
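
For context: the fix swaps a Hugging Face-style callable tokenizer (`tokenizer(batch["text"], ...)`) for NeMo's `text_to_ids()` interface, and switches `dataset.map` to per-example mode (`batched=False`) because `text_to_ids` takes a single string rather than a batch of strings. Below is a minimal runnable sketch of the new per-example path; `_StubTokenizer` is a hypothetical stand-in for a NeMo TokenizerSpec-like object that exposes `text_to_ids`, not `__call__`.

from functools import partial


class _StubTokenizer:
    """Hypothetical stand-in for a NeMo TokenizerSpec-style tokenizer.

    It exposes text_to_ids() (what the patched code calls) rather than the
    Hugging Face __call__ interface the old preprocess_batch relied on.
    """

    def __init__(self):
        self.vocab = {}

    def text_to_ids(self, text):
        # Assign each whitespace-separated token a stable integer id.
        return [self.vocab.setdefault(tok, len(self.vocab)) for tok in text.split()]


def preprocess(example, tokenizer, max_length):
    # Mirrors the patched function: tokenize one example, optionally
    # truncate, then build next-token labels padded with -100.
    input_ids = tokenizer.text_to_ids(example["text"])
    if max_length > 0:
        input_ids = input_ids[:max_length]
    return dict(input_ids=input_ids, labels=input_ids[1:] + [-100])


if __name__ == "__main__":
    tok = _StubTokenizer()
    fn = partial(preprocess, tokenizer=tok, max_length=8)
    out = fn({"text": "the quick brown fox jumps over the lazy dog"})
    print(out["input_ids"])  # [0, 1, 2, 3, 4, 5, 0, 6]  (9 tokens truncated to 8)
    print(out["labels"])     # [1, 2, 3, 4, 5, 0, 6, -100]

The `labels` line shifts `input_ids` left by one position and appends -100 so the final token is ignored by a cross-entropy loss with `ignore_index=-100`, matching the next-token-prediction setup in the patched function. Since the function now receives one example dict at a time via `dataset.map(fn, batched=False)`, the old per-batch list comprehension over `ans["input_ids"]` is no longer needed.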
