Skip to content

Commit 7abef3f

Browse files
committed
feat: add object store support with multistorageclient for memmap dataset
1 parent b685967 commit 7abef3f

File tree

1 file changed

+38
-10
lines changed

1 file changed

+38
-10
lines changed

Diff for: nemo/collections/nlp/data/language_modeling/text_memmap_dataset.py

+38-10
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@
2727
from nemo.core import Dataset
2828
from nemo.utils import AppState, logging
2929

30+
try:
31+
import multistorageclient
32+
MULTISTORAGECLIENT_AVAILABLE = True
33+
except (ImportError, ModuleNotFoundError):
34+
MULTISTORAGECLIENT_AVAILABLE = False
35+
3036
__all__ = ["TextMemMapDataset", "CSVMemMapDataset", "build_index_files"]
3137
__idx_version__ = "0.2" # index file version
3238
__idx_suffix__ = "idx" # index file suffix
@@ -40,7 +46,10 @@ def _build_index_from_memdata(fn, newline_int):
4046
Returns a 1D array of ints.
4147
"""
4248
# use memmap to read file
43-
mdata = np.memmap(fn, dtype=np.uint8, mode="r")
49+
if MULTISTORAGECLIENT_AVAILABLE:
50+
mdata = multistorageclient.numpy.memmap(fn, dtype=np.uint8, mode="r")
51+
else:
52+
mdata = np.memmap(fn, dtype=np.uint8, mode="r")
4453
# find newline positions
4554
midx = np.where(mdata == newline_int)[0]
4655
midx_dtype = midx.dtype
@@ -250,18 +259,28 @@ def load_file(self, fn, index_mapping_dir: Optional[str] = None):
250259
idx_fn = _index_fn(fn, index_mapping_dir)
251260

252261
# create data map
253-
mdata = np.memmap(fn, dtype=np.uint8, mode="r")
262+
if MULTISTORAGECLIENT_AVAILABLE:
263+
mdata = multistorageclient.numpy.memmap(fn, dtype=np.uint8, mode="r")
264+
else:
265+
mdata = np.memmap(fn, dtype=np.uint8, mode="r")
254266

255267
if _index_file_exists(idx_fn):
256268
# load index file into memory map
257-
midx = np.load(idx_fn + ".npy", allow_pickle=True, mmap_mode="r")
269+
if MULTISTORAGECLIENT_AVAILABLE:
270+
midx = multistorageclient.numpy.load(idx_fn + ".npy", allow_pickle=True, mmap_mode="r")
271+
else:
272+
midx = np.load(idx_fn + ".npy", allow_pickle=True, mmap_mode="r")
258273
# test for header
259274
if len(midx) < self._header_lines:
260275
raise RuntimeError(f"Missing header, expected {self._header_lines} header lines")
261276

262277
# load meta info
263-
with open(idx_fn + ".info", "rb") as fp:
264-
idx_info_dict = pickle.load(fp)
278+
if MULTISTORAGECLIENT_AVAILABLE:
279+
with multistorageclient.open(idx_fn + ".info", "rb") as fp:
280+
idx_info_dict = multistorageclient.pickle.load(fp)
281+
else:
282+
with open(idx_fn + ".info", "rb") as fp:
283+
idx_info_dict = pickle.load(fp)
265284
# test for mismatch in expected newline_int
266285
if "newline_int" in idx_info_dict:
267286
newline_int = idx_info_dict["newline_int"]
@@ -438,10 +457,12 @@ def _build_data_from_text(self, text):
438457

439458
def _index_file_exists(idx_fn):
    """Return whether both index artifacts for ``idx_fn`` are present.

    An index is considered to exist only when ``idx_fn + ".npy"`` and
    ``idx_fn + ".info"`` both exist. When ``multistorageclient`` was imported
    successfully, its ``os.path.exists`` is used for the check; otherwise the
    standard library ``os.path.exists`` is used.
    """
    # Select the existence check once so both files go through the same backend.
    if MULTISTORAGECLIENT_AVAILABLE:
        path_exists = multistorageclient.os.path.exists
    else:
        path_exists = os.path.exists

    npy_path = idx_fn + ".npy"
    info_path = idx_fn + ".info"
    return path_exists(npy_path) and path_exists(info_path)
445466

446467

447468
def _index_fn(fn: str, index_mapping_dir: str) -> str:
@@ -504,9 +525,16 @@ def _build_memmap_index_files(newline_int, build_index_fn, fn, index_mapping_dir
504525

505526
# save index as numpy array to enable memmap reading
506527
logging.info(f"Saving idx file = {idx_fn}.npy")
507-
np.save(idx_fn + ".npy", midx, allow_pickle=True)
528+
if MULTISTORAGECLIENT_AVAILABLE:
529+
multistorageclient.numpy.save(idx_fn + ".npy", midx, allow_pickle=True)
530+
else:
531+
np.save(idx_fn + ".npy", midx, allow_pickle=True)
532+
508533
logging.info(f"Saving metadata file = {idx_fn}.info")
509-
pickle.dump(data, open(idx_fn + ".info", "wb"))
534+
if MULTISTORAGECLIENT_AVAILABLE:
535+
multistorageclient.pickle.dump(data, idx_fn + ".info")
536+
else:
537+
pickle.dump(data, open(idx_fn + ".info", "wb"))
510538

511539
return True
512540

0 commit comments

Comments
 (0)