Skip to content

feat: support memmap dataset with object store via Multi-Storage Client #12870

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits on Apr 14, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 64 additions & 20 deletions nemo/collections/nlp/data/language_modeling/text_memmap_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,24 @@
import pickle
import time
from functools import lru_cache, partial
from typing import Callable, List, Optional, Type
from typing import TYPE_CHECKING, Callable, List, Optional, Type

import numpy as np
import torch

from nemo.core import Dataset
from nemo.utils import AppState, logging

try:
import multistorageclient

MULTISTORAGECLIENT_AVAILABLE = True
except (ImportError, ModuleNotFoundError):
MULTISTORAGECLIENT_AVAILABLE = False

if TYPE_CHECKING:
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec

__all__ = ["TextMemMapDataset", "CSVMemMapDataset", "build_index_files"]
__idx_version__ = "0.2" # index file version
__idx_suffix__ = "idx" # index file suffix
Expand All @@ -40,7 +50,10 @@
Returns a 1D array of ints.
"""
# use memmap to read file
mdata = np.memmap(fn, dtype=np.uint8, mode="r")
if MULTISTORAGECLIENT_AVAILABLE:
mdata = multistorageclient.numpy.memmap(fn, dtype=np.uint8, mode="r")
else:
mdata = np.memmap(fn, dtype=np.uint8, mode="r")
# find newline positions
midx = np.where(mdata == newline_int)[0]
midx_dtype = midx.dtype
Expand Down Expand Up @@ -113,7 +126,7 @@
if sort_dataset_paths:
self._files_list = sorted(self._files_list)

logging.info(f"Building data files")
logging.info("Building data files")
# load all files into memmap
is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()

Expand Down Expand Up @@ -155,11 +168,12 @@
if is_distributed and not _lightning_prepare_data():
torch.distributed.barrier()

logging.info(f"Loading data files")
logging.info("Loading data files")
start_time = time.time()
mdata_midx_list = [self.load_file(fn, index_mapping_dir) for fn in self._files_list]
logging.info(
f"Time loading {len(mdata_midx_list)} mem-mapped files: {datetime.timedelta(seconds=time.time() - start_time)}"
f"Time loading {len(mdata_midx_list)} mem-mapped files: "
f"{datetime.timedelta(seconds=time.time() - start_time)}"
)

logging.info("Computing global indices")
Expand Down Expand Up @@ -213,15 +227,17 @@
data = self._build_data_from_text(sample)
except Exception as e:
logging.error(
f"Error while building data from text, possible issue with sample expected format (see offending sample below): {e}"
"Error while building data from text, possible issue with sample expected format "
f"(see offending sample below): {e}"
)
logging.error(f"sample: {sample}, file_id: {file_id}, file_idx: {file_idx}, i: {i}, j: {j}")
raise e

return data

def _fetch_sample_from_memmap(self, mdata, i, j):
"""Fetches the text sample. Can be overridden by child-classes to support loading of partial samples and alternative decode methods"""
"""Fetches the text sample. Can be overridden by child-classes to support loading of partial samples
and alternative decode methods"""
# load text sample by slicing memmap data[i:j]
text = mdata[i:j].tobytes().decode("utf-8")

Expand Down Expand Up @@ -250,18 +266,28 @@
idx_fn = _index_fn(fn, index_mapping_dir)

# create data map
mdata = np.memmap(fn, dtype=np.uint8, mode="r")
if MULTISTORAGECLIENT_AVAILABLE:
mdata = multistorageclient.numpy.memmap(fn, dtype=np.uint8, mode="r")
else:
mdata = np.memmap(fn, dtype=np.uint8, mode="r")

if _index_file_exists(idx_fn):
# load index file into memory map
midx = np.load(idx_fn + ".npy", allow_pickle=True, mmap_mode="r")
if MULTISTORAGECLIENT_AVAILABLE:
midx = multistorageclient.numpy.load(idx_fn + ".npy", allow_pickle=True, mmap_mode="r")
else:
midx = np.load(idx_fn + ".npy", allow_pickle=True, mmap_mode="r")
# test for header
if len(midx) < self._header_lines:
raise RuntimeError(f"Missing header, expected {self._header_lines} header lines")

# load meta info
with open(idx_fn + ".info", "rb") as fp:
idx_info_dict = pickle.load(fp)
if MULTISTORAGECLIENT_AVAILABLE:
with multistorageclient.open(idx_fn + ".info", "rb") as fp:
idx_info_dict = multistorageclient.pickle.load(fp)
else:
with open(idx_fn + ".info", "rb") as fp:
idx_info_dict = pickle.load(fp)
# test for mismatch in expected newline_int
if "newline_int" in idx_info_dict:
newline_int = idx_info_dict["newline_int"]
Expand All @@ -274,7 +300,8 @@
idx_version = idx_info_dict.get("version", "0.0")
if __idx_version__ != idx_version:
raise RuntimeError(
f"Version mismatch: Please delete existing '.{__idx_suffix__}' files. Expected version = {__idx_version__}, but file version = {idx_version}. File path = {idx_fn}"
f"Version mismatch: Please delete existing '.{__idx_suffix__}' files. Expected version = "
f"{__idx_version__}, but file version = {idx_version}. File path = {idx_fn}"
)
else:
raise ValueError(
Expand Down Expand Up @@ -438,10 +465,14 @@

def _index_file_exists(idx_fn):
"""Helper function to test if index file exists"""
if os.path.exists(idx_fn + ".npy") and os.path.exists(idx_fn + ".info"):
return True
is_exists = False
if MULTISTORAGECLIENT_AVAILABLE:
is_exists = multistorageclient.os.path.exists(idx_fn + ".npy") and multistorageclient.os.path.exists(
idx_fn + ".info"
)
else:
return False
is_exists = os.path.exists(idx_fn + ".npy") and os.path.exists(idx_fn + ".info")
return is_exists


def _index_fn(fn: str, index_mapping_dir: str) -> str:
Expand Down Expand Up @@ -504,9 +535,16 @@

# save index as numpy array to enable memmap reading
logging.info(f"Saving idx file = {idx_fn}.npy")
np.save(idx_fn + ".npy", midx, allow_pickle=True)
if MULTISTORAGECLIENT_AVAILABLE:
multistorageclient.numpy.save(idx_fn + ".npy", midx, allow_pickle=True)
else:
np.save(idx_fn + ".npy", midx, allow_pickle=True)

logging.info(f"Saving metadata file = {idx_fn}.info")
pickle.dump(data, open(idx_fn + ".info", "wb"))
if MULTISTORAGECLIENT_AVAILABLE:
multistorageclient.pickle.dump(data, idx_fn + ".info")
else:
pickle.dump(data, open(idx_fn + ".info", "wb"))

return True

Expand Down Expand Up @@ -541,7 +579,8 @@
)

logging.info(
f"Time building {sum(build_status)} / {len(build_status)} mem-mapped files: {datetime.timedelta(seconds=time.time() - start_time)}"
f"Time building {sum(build_status)} / {len(build_status)} mem-mapped files: "
f"{datetime.timedelta(seconds=time.time() - start_time)}"
)


Expand Down Expand Up @@ -606,7 +645,8 @@
cache_maxsize (int): Maximum size of the blocks cache for the get_sample_block function.
seed (int): Seed for the random number generator used for shuffling.
shuffle (bool): Whether to shuffle the samples.
truncate_to_block_boundary (bool): Whether to truncate the last block to the block boundary (could drop samples).
truncate_to_block_boundary (bool): Whether to truncate the last block to the block boundary
(could drop samples).
"""
self.dataset_size = dataset_size
self.num_samples = num_samples
Expand Down Expand Up @@ -660,7 +700,11 @@
self.get_sample_block = lru_cache(maxsize=cache_maxsize, typed=False)(self.get_sample_block)

def __str__(self):
return f"OnlineSampleMapping(dataset_size={self.dataset_size}, num_samples={self.num_samples}, block_size={self.block_size}, cache_maxsize={self.cache_maxsize}, seed={self.seed}, shuffle={self.shuffle}, truncate_to_block_boundary={self.truncate_to_block_boundary})"
return (
f"OnlineSampleMapping(dataset_size={self.dataset_size}, num_samples={self.num_samples}, "
f"block_size={self.block_size}, cache_maxsize={self.cache_maxsize}, seed={self.seed}, "
f"shuffle={self.shuffle}, truncate_to_block_boundary={self.truncate_to_block_boundary})"
)

def __getitem__(self, idx: int) -> int:
# handle slices
Expand Down
Loading