5 changes: 5 additions & 0 deletions MaxText/configs/base.yml
@@ -438,6 +438,11 @@ hf_eval_split: ''
hf_eval_files: ''
hf_access_token: ''
# for Grain input pipeline (dataset_type=grain)
# Path to grain data files. Can be a single pattern or multiple patterns with weights.
# For multiple patterns, use semicolon (;) to separate and colon (:) to specify weights.
# Example: "path/to/data1.array_record*:0.3;path/to/data2.array_record*:0.7"
# Note: When using multiple files (separated by ';'), only ArrayRecord format is supported.
# For more details, see https://github.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#grain-input-pipeline
grain_train_files: ''
grain_eval_files: ''
grain_file_type: 'arrayrecord' # arrayrecord or parquet
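As an illustration of the pattern syntax documented in these comments, a blended `grain_train_files` value can be assembled from per-source globs and weights; the paths below are placeholders, and the joining mirrors what the new test in this PR does:
```python
# Hypothetical sources: each glob pattern gets a sampling weight after a colon,
# and the entries are joined with semicolons.
sources = {
    "/tmp/gcsfuse/dataset1.array_record*": 0.3,  # placeholder path, ~30% of samples
    "/tmp/gcsfuse/dataset2.array_record*": 0.7,  # placeholder path, ~70% of samples
}
grain_train_files = ";".join(f"{pattern}:{weight}" for pattern, weight in sources.items())
# -> "/tmp/gcsfuse/dataset1.array_record*:0.3;/tmp/gcsfuse/dataset2.array_record*:0.7"
```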
24 changes: 20 additions & 4 deletions MaxText/input_pipeline/_grain_data_processing.py
@@ -33,6 +33,13 @@
from MaxText import tokenizer


def find_data_files(data_file_pattern):
data_files = glob.glob(str(Path(data_file_pattern).expanduser().resolve()))
assert len(data_files) > 0, f"No file found with pattern {data_file_pattern}."
max_logging.log(f"Found {len(data_files)} files for train/eval with grain")
return data_files


def get_datasets(
data_file_pattern,
data_file_type,
@@ -44,17 +51,26 @@ def get_datasets(
grain_worker_count,
):
"""Load dataset from array_record files for using with grain"""
data_files = glob.glob(str(Path(data_file_pattern).expanduser().resolve()))
assert len(data_files) > 0, f"No file found with pattern {data_file_pattern}."
max_logging.log(f"Found {len(data_files)} files for train/eval with grain")
if data_file_type == "arrayrecord":
dataset = grain.MapDataset.source(grain.ArrayRecordDataSource(data_files))
if ";" in data_file_pattern:
data_file_patterns, weights = zip(*[pattern.split(":") for pattern in data_file_pattern.split(";")])
assert len(data_file_patterns) == len(weights), "Number of data file patterns and weights must match"
weights = [float(weight) for weight in weights]
weights = [round(weight / sum(weights), 4) for weight in weights]
dataset_list = [
grain.MapDataset.source(grain.ArrayRecordDataSource(find_data_files(pattern))) for pattern in data_file_patterns
]
dataset = grain.MapDataset.mix(dataset_list, weights)
else:
data_files = find_data_files(data_file_pattern)
dataset = grain.MapDataset.source(grain.ArrayRecordDataSource(data_files))
if shuffle:
dataset = dataset.shuffle(seed=shuffle_seed)
dataset = dataset.repeat(num_epoch)
dataset = dataset[dataloading_host_index::dataloading_host_count] # sharding
dataset = dataset.to_iter_dataset()
elif data_file_type == "parquet":
data_files = find_data_files(data_file_pattern)
dataset = grain.MapDataset.source(data_files)
if shuffle:
dataset = dataset.shuffle(seed=shuffle_seed)
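For readers who want to see the blending branch above in isolation, here is a minimal standalone sketch of the same logic: split the pattern string into per-source globs and weights, normalize the weights, and mix the per-source datasets. The import alias and file paths are assumptions, and it presumes the globs resolve to actual ArrayRecord files on disk:
```python
import glob

import grain.python as grain  # assumed import alias for the Grain API used in the diff

# Hypothetical pattern string; the paths are placeholders.
data_file_pattern = "/tmp/data1.array_record*:0.3;/tmp/data2.array_record*:0.7"

# Split "<glob>:<weight>" entries, then normalize the weights so they sum to ~1.0.
patterns, weights = zip(*[pattern.split(":") for pattern in data_file_pattern.split(";")])
weights = [float(weight) for weight in weights]
weights = [round(weight / sum(weights), 4) for weight in weights]

# Build one MapDataset per source, then a weighted mixture over all sources.
datasets = [
    grain.MapDataset.source(grain.ArrayRecordDataSource(glob.glob(pattern))) for pattern in patterns
]
dataset = grain.MapDataset.mix(datasets, weights)
```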
2 changes: 1 addition & 1 deletion MaxText/input_pipeline/_input_pipeline_utils.py
@@ -332,7 +332,7 @@ def map(self, element):

@dataclasses.dataclass
class Rekey(grain.MapTransform):
"""Rname keys according to a mappign dict"""
"""Rename keys according to a mapping dict"""

def __init__(self, mapping_dict, keep_old_keys=False):
self.mapping_dict = mapping_dict
36 changes: 36 additions & 0 deletions MaxText/tests/grain_data_processing_test.py
@@ -105,6 +105,42 @@ def get_first_batch(iterator):
self.assertTrue((train_batch1["targets"] == train_batch2["targets"]).all())


class GrainArrayRecordProcessingWithMultiSourceBlendingTest(GrainArrayRecordProcessingTest):

def setUp(self):
super().setUp()
temp_dir = tempfile.gettempdir()
# We use the same dataset for testing, but you can use different datasets by changing the file patterns.
grain_train_files = [
f"{temp_dir}/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*:0.3",
f"{temp_dir}/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*:0.7",
]
grain_train_files = ";".join(grain_train_files)
self.config = pyconfig.initialize(
[sys.argv[0], os.path.join(PKG_DIR, "configs", "base.yml")],
per_device_batch_size=1,
run_name="test",
mesh_axes=["data"],
logical_axis_rules=[["batch", "data"]],
data_sharding=["data"],
base_output_directory="gs://max-experiments/",
dataset_type="grain",
grain_train_files=grain_train_files,
tokenizer_path=os.path.join(os.path.dirname(PKG_DIR), "assets", "tokenizer"),
enable_checkpointing=False,
)
self.mesh_shape_1d = (len(jax.devices()),)
self.mesh = Mesh(mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes)
self.process_indices = input_pipeline_interface.get_process_loading_real_data(
self.config.data_sharding,
self.config.global_batch_size_to_load,
self.config.global_batch_size_to_train_on,
self.config.max_target_length,
self.mesh,
)
self.train_iter = _grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices)


class GrainParquetProcessingTest(unittest.TestCase):

@classmethod
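Since this class inherits from `GrainArrayRecordProcessingTest` and only overrides `setUp`, it reruns the parent class's batch checks against the blended iterator. Assuming a standard pytest setup and that the C4 ArrayRecord files are mounted under the gcsfuse path the test expects, it can be selected with something like `python3 -m pytest MaxText/tests/grain_data_processing_test.py -k MultiSourceBlending` (the exact invocation is an assumption, not prescribed by this PR).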
15 changes: 13 additions & 2 deletions getting_started/Data_Input_Pipeline.md
@@ -102,7 +102,18 @@ bash setup_gcsfuse.sh DATASET_GCS_BUCKET=$BUCKET_NAME MOUNT_PATH=$MOUNT_PATH [FI
```
3. Set `dataset_type=grain` and set `grain_train_files` to match the ArrayRecord files via a local path since the bucket has been mounted.
4. Tune `grain_worker_count` for performance. This parameter controls the number of child processes used by Grain (more details in [behind_the_scene](https://github.com/google/grain/blob/main/docs/behind_the_scenes.md), [code](https://github.com/google/grain/blob/main/grain/_src/python/grain_pool.py)). If you use a large number of workers, please check your config for gcsfuse in [setup_gcsfuse.sh](https://github.com/google/maxtext/blob/main/setup_gcsfuse.sh) to avoid gcsfuse throttling.
5. Example command:

5. For multi-source blending, you can specify multiple data sources with their respective weights, using a semicolon (;) to separate sources and a colon (:) to attach each weight. The weights are automatically normalized to sum to 1.0. For example:
```
# Blend two data sources with 30% from first source and 70% from second source
grain_train_files=/tmp/gcsfuse/dataset1.array_record*:0.3;/tmp/gcsfuse/dataset2.array_record*:0.7

# Blend three data sources with equal weights (will be normalized to 0.33 each)
grain_train_files=/tmp/gcsfuse/dataset1.array_record*:1;/tmp/gcsfuse/dataset2.array_record*:1;/tmp/gcsfuse/dataset3.array_record*:1
```
Note: When using multiple data sources, only ArrayRecord format is supported.

6. Example command:
```
bash setup_gcsfuse.sh \
DATASET_GCS_BUCKET=maxtext-dataset \
@@ -114,7 +125,7 @@ grain_file_type=arrayrecord \
grain_train_files=/tmp/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record* \
grain_worker_count=2
```
6. Using validation set for eval
7. Using validation set for eval
When setting eval_interval > 0, eval will be run with a specified eval dataset. Example config:
```
eval_interval: 10000