5 changes: 4 additions & 1 deletion nemo/collections/llm/gpt/data/__init__.py
@@ -14,8 +14,9 @@

from nemo.collections.llm.gpt.data.alpaca import AlpacaDataModule
from nemo.collections.llm.gpt.data.chat import ChatDataModule
from nemo.collections.llm.gpt.data.cnn_dailymail import CNNDailyMailFineTuningDataModule
from nemo.collections.llm.gpt.data.dolly import DollyDataModule
from nemo.collections.llm.gpt.data.fine_tuning import FineTuningDataModule
from nemo.collections.llm.gpt.data.fine_tuning import FineTuningDataModule, HFFineTuningDataModule
from nemo.collections.llm.gpt.data.hf_dataset import HFDatasetDataModule, HFMockDataModule
from nemo.collections.llm.gpt.data.mlperf_govreport import MLPerfGovReportDataModule
from nemo.collections.llm.gpt.data.mock import MockDataModule
@@ -26,9 +27,11 @@
__all__ = [
"AlpacaDataModule",
"ChatDataModule",
"CNNDailyMailFineTuningDataModule",
"DollyDataModule",
"FineTuningDataModule",
"HFDatasetDataModule",
"HFFineTuningDataModule",
"HFMockDataModule",
"MLPerfGovReportDataModule",
"MockDataModule",
50 changes: 50 additions & 0 deletions nemo/collections/llm/gpt/data/cnn_dailymail.py
@@ -0,0 +1,50 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo.collections.llm.gpt.data.fine_tuning import HFFineTuningDataModule

class CNNDailyMailFineTuningDataModule(HFFineTuningDataModule):
"""A data module for fine-tuning on the CNN / Daily Mail dataset.

This class inherits from `HFFineTuningDataModule`; see that class for the init arguments and the hook methods overridden here.
"""

def _make_splits(self, dset, *args, **kwargs):
"""Maps train/validation/test to standard split names."""
save_splits = {
"training": dset.get("train"),
"validation": dset.get("validation"),
"test": dset.get("test"),
}
return save_splits

def _json_line_from_example(self, example, split_name, *args, **kwargs):
"""Extract data for summarization task."""
json_line = {
"input": example["article"],
"output": example["highlights"],
}
return json_line

@property
def dataset_name(self) -> str:
return "cnn_dailymail"

@property
def hf_load_dataset_kwargs(self) -> dict:
"""Retrieve 1.0.0 version of the dataset."""
kwargs = super().hf_load_dataset_kwargs | {
"name": "1.0.0",
}
return kwargs
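
For orientation, a minimal usage sketch of the new module follows; the argument values are illustrative assumptions, not settings mandated by this PR:

# Hypothetical usage sketch for CNNDailyMailFineTuningDataModule.
# prepare_data() downloads the 1.0.0 CNN/DailyMail config via datasets.load_dataset
# and writes training.jsonl / validation.jsonl / test.jsonl files whose records have
# the form {"input": article, "output": highlights} under the dataset root.
from nemo.collections.llm.gpt.data import CNNDailyMailFineTuningDataModule

data_module = CNNDailyMailFineTuningDataModule(
    seq_length=2048,         # illustrative values, matching the constructor defaults
    micro_batch_size=4,
    global_batch_size=8,
    force_redownload=False,  # reuse an existing download when the jsonl files are present
    delete_raw=True,         # drop the raw HF cache once the jsonl splits are written
)
data_module.prepare_data()   # normally invoked by the Lightning trainer
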
153 changes: 152 additions & 1 deletion nemo/collections/llm/gpt/data/fine_tuning.py
@@ -12,17 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import math
import shutil
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import lightning.pytorch as pl
from datasets import DatasetDict, load_dataset
from torch.utils.data import DataLoader

from nemo.collections.common.tokenizers import AutoTokenizer
from nemo.collections.llm.gpt.data.core import create_sft_dataset
from nemo.collections.llm.gpt.data.core import create_sft_dataset, get_dataset_root
from nemo.lightning.data import WrappedDataLoader
from nemo.lightning.io.mixin import IOMixin
from nemo.lightning.pytorch.plugins import MegatronDataSampler
from nemo.utils import logging

@@ -344,3 +348,150 @@ def _extract_tokenizer_model_name(self) -> str:
else:
tokenizer_model_name = f"unknown_tokenizer_{hash(self.tokenizer)}"
return tokenizer_model_name


class HFFineTuningDataModule(FineTuningDataModule, IOMixin):
"""A generic data module for downloading and preprocessing HF datasets for fine-tuning.

This class inherits from `FineTuningDataModule`; see that class for further argument details.
It handles data download, preprocessing, splitting, and preparing the data in a format suitable for training, validation, and testing.

Args:
dataset_root (Optional[Union[str, Path]]): The root directory containing the training,
validation, and test data. Defaults to None, in which case `get_dataset_root(dataset_name)` is used and the data is downloaded there.
force_redownload (bool, optional): Whether to force re-download the dataset even if it
exists locally. Defaults to False.
delete_raw (bool, optional): Whether to delete the raw downloaded dataset after preprocessing.
Defaults to True.
"""

def __init__(
self,
dataset_root: Optional[Union[str, Path]] = None,
seq_length: int = 2048,
tokenizer: Optional["TokenizerSpec"] = None,
micro_batch_size: int = 4,
global_batch_size: int = 8,
rampup_batch_size: Optional[List[int]] = None,
force_redownload: bool = False,
delete_raw: bool = True,
seed: int = 1234,
memmap_workers: int = 1,
num_workers: int = 8,
pin_memory: bool = True,
persistent_workers: bool = False,
packed_sequence_specs: Optional["PackedSequenceSpecs"] = None,
dataset_kwargs: Optional[Dict[str, Any]] = None,
):
self.force_redownload = force_redownload
self.delete_raw = delete_raw

super().__init__(
dataset_root=dataset_root if dataset_root is not None else get_dataset_root(self.dataset_name),
seq_length=seq_length,
tokenizer=tokenizer,
micro_batch_size=micro_batch_size,
global_batch_size=global_batch_size,
rampup_batch_size=rampup_batch_size,
seed=seed,
memmap_workers=memmap_workers,
num_workers=num_workers,
pin_memory=pin_memory,
persistent_workers=persistent_workers,
packed_sequence_specs=packed_sequence_specs,
dataset_kwargs=dataset_kwargs,
)

def prepare_data(self) -> None:
"""Download and preprocess data as needed."""
# skip download and preprocessing if the processed train file already exists and no re-download is forced
if not self.train_path.exists() or self.force_redownload:
dset = self._download_data()
self._preprocess_and_split_data(dset)
super().prepare_data()

def _download_data(self):
"""Download this dataset from HF."""
logging.info(f"Downloading {self.__class__.__name__}...")
return load_dataset(
self.dataset_hf_path,
**self.hf_load_dataset_kwargs,
)

def _preprocess_and_split_data(
self, dset: DatasetDict, *args, **kwargs,
):
"""Preprocesses and splits the downloaded dataset into training, validation, an
test sets.

Args:
dset (Dataset or DatasetDict): The downloaded dataset object.
"""
logging.info(f"Preprocessing {self.__class__.__name__} to jsonl format and splitting...")

save_splits = self._make_splits(dset, *args, **kwargs)

for split_name, dataset in save_splits.items():
output_file = self.dataset_root / f"{split_name}.jsonl"

with output_file.open("w", encoding="utf-8") as f:
for example in dataset:
f.write(json.dumps(self._json_line_from_example(example, split_name, *args, **kwargs)) + "\n")

logging.info(f"{split_name} split saved to {output_file}")

if self.delete_raw:
for p in self.dataset_root.iterdir():
if p.is_dir():
shutil.rmtree(p)
elif '.jsonl' not in str(p.name):
p.unlink()

def _make_splits(self, dset, *args, **kwargs):
"""Assemble and return a dict with training, validation, and test splits as needed from the given dataset.

To be overridden by subclasses.

Args:
dset (Dataset or DatasetDict): The downloaded dataset object.

Returns:
A dictionary mapping each split's output file stem (used as the jsonl file name) to the corresponding dataset subset.
"""
raise NotImplementedError()

def _json_line_from_example(self, example, split_name, *args, **kwargs):
"""Generate and return a dict with input and output data as needed from the given example.

To be overridden by subclasses.

Args:
example (Dict): A data element from the dataset to be processed.
split_name (str): The string key for this split of the dataset.

Returns:
A dictionary mapping string keys to raw data, to be written in json lines format.
"""
raise NotImplementedError()

@property
def dataset_name(self) -> str:
"""String property. Dataset will be written under this name within `dataset_root`.

To be overridden by subclasses.
"""
raise NotImplementedError()

@property
def dataset_hf_path(self) -> str:
"""String property. Dataset will be retrieved from HF with this path. Defaults to match `dataset_name`."""
return self.dataset_name

@property
def hf_load_dataset_kwargs(self) -> dict:
"""Additional keyword args to pass to `load_dataset` when retrieving the dataset from HF."""
kwargs = {
"cache_dir": str(self.dataset_root),
"download_mode": "force_redownload" if self.force_redownload else None,
}
return kwargs
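
To illustrate the extension points, a hypothetical subclass for another HF summarization dataset might look like the sketch below; the HF path and column names are assumptions for illustration and are not part of this PR:

# Hypothetical subclass sketch: a new dataset only needs to supply the split
# mapping, the per-example jsonl record, and its name (plus the HF path when
# it differs from the local directory name).
from nemo.collections.llm.gpt.data.fine_tuning import HFFineTuningDataModule


class XSumFineTuningDataModule(HFFineTuningDataModule):
    """Illustrative only; the HF path and column names below are assumed."""

    def _make_splits(self, dset, *args, **kwargs):
        # Map HF split names onto the file stems used for the jsonl outputs.
        return {
            "training": dset.get("train"),
            "validation": dset.get("validation"),
            "test": dset.get("test"),
        }

    def _json_line_from_example(self, example, split_name, *args, **kwargs):
        # One jsonl record per example: source document in, reference summary out.
        return {"input": example["document"], "output": example["summary"]}

    @property
    def dataset_name(self) -> str:
        """Directory name used under `dataset_root`."""
        return "xsum"

    @property
    def dataset_hf_path(self) -> str:
        """The dataset is hosted on HF under a different path than its local name."""
        return "EdinburghNLP/xsum"
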
20 changes: 10 additions & 10 deletions nemo/collections/llm/gpt/data/mlperf_govreport.py
@@ -81,11 +81,11 @@ def __init__(
dataset_kwargs=dataset_kwargs,
)

if self.packed_sequence_size != self.seq_length:
if self.packed_sequence_size != self.seq_length or self.packed_sequence_size > 8192:
raise ValueError(
f"{self.__class__.__name__} requires `packed_sequence_specs.packed_sequence_size` to be nonzero "
f"and equal to `seq_length`. Instead got packed_sequence_size = {self.packed_sequence_size} "
f"and seq_length = {self.seq_length}"
f"{self.__class__.__name__} requires `packed_sequence_specs.packed_sequence_size` to be nonzero, "
f"less than or equal to the max of 8192, and equal to `seq_length`. Instead got "
f"packed_sequence_size = {self.packed_sequence_size} and seq_length = {self.seq_length}"
)

def prepare_data(self) -> None:
@@ -133,11 +133,11 @@ def _preprocess_and_split_data(
save_splits['test'] = split_dataset['train']

for split_name, dataset in save_splits.items():
output_file = self.dataset_root / f"{split_name}.npy"
output_file = self.dataset_root / f"{split_name}_{self.seq_length}.npy"
processed_data = [
{
"input_ids": example["input_ids"],
"loss_mask": [int(x != -100) for x in example["labels"]],
"input_ids": list(example["input_ids"])[:self.seq_length],
"loss_mask": [int(x != -100) for x in example["labels"]][:self.seq_length],
"seq_start_id": [0],
}
for example in dataset
@@ -156,17 +156,17 @@
@property
def train_path(self) -> Path:
"""Path to training dataset file"""
return self.dataset_root / "training.npy"
return self.dataset_root / f"training_{self.seq_length}.npy"

@property
def validation_path(self) -> Path:
"""Path to validation dataset file"""
return self.dataset_root / "validation.npy"
return self.dataset_root / f"validation_{self.seq_length}.npy"

@property
def test_path(self) -> Path:
"""Path to test dataset file"""
return self.dataset_root / "test.npy"
return self.dataset_root / f"test_{self.seq_length}.npy"

@property
def default_pack_path(self) -> Path: