5 changes: 5 additions & 0 deletions MaxText/configs/base.yml
@@ -438,6 +438,11 @@ hf_eval_split: ''
hf_eval_files: ''
hf_access_token: ''
# for Grain input pipeline (dataset_type=grain)
# Path to grain data files. Can be a single pattern or multiple patterns with weights.
# For multiple patterns, use semicolon (;) to separate and colon (:) to specify weights.
# Example: "path/to/data1.array_record*:0.3;path/to/data2.array_record*:0.7"
# Note: When using multiple files (separated by ';'), only ArrayRecord format is supported.
# For more details, see https://github.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#grain-input-pipeline
grain_train_files: ''
grain_eval_files: ''
grain_file_type: 'arrayrecord' # arrayrecord or parquet
20 changes: 16 additions & 4 deletions MaxText/input_pipeline/_grain_data_processing.py
@@ -33,6 +33,12 @@
from MaxText import tokenizer


def find_data_files(data_file_pattern):
data_files = glob.glob(str(Path(data_file_pattern).expanduser().resolve()))
assert len(data_files) > 0, f"No file found with pattern {data_file_pattern}."
return data_files


def get_datasets(
data_file_pattern,
data_file_type,
@@ -44,17 +50,23 @@ def get_datasets(
grain_worker_count,
):
"""Load dataset from array_record files for using with grain"""
data_files = glob.glob(str(Path(data_file_pattern).expanduser().resolve()))
assert len(data_files) > 0, f"No file found with pattern {data_file_pattern}."
max_logging.log(f"Found {len(data_files)} files for train/eval with grain")
if data_file_type == "arrayrecord":
dataset = grain.MapDataset.source(grain.ArrayRecordDataSource(data_files))
if ";" in data_file_pattern:
data_file_patterns, weights = zip(*[pattern.split(":") for pattern in data_file_pattern.split(";")])
assert len(data_file_patterns) == len(weights), "Number of data file patterns and weights must match"
weights = [float(weight) for weight in weights]  # weights are parsed as strings; convert before normalizing
weights = [round(weight / sum(weights), 4) for weight in weights]
dataset_list = [grain.MapDataset.source(grain.ArrayRecordDataSource(find_data_files(pattern))) for pattern in data_file_patterns]
dataset = grain.MapDataset.mix(dataset_list, weights)
else:
data_files = find_data_files(data_file_pattern)
dataset = grain.MapDataset.source(grain.ArrayRecordDataSource(data_files))
if shuffle:
dataset = dataset.shuffle(seed=shuffle_seed)
dataset = dataset.repeat(num_epoch)
dataset = dataset[dataloading_host_index::dataloading_host_count] # sharding
dataset = dataset.to_iter_dataset()
elif data_file_type == "parquet":
data_files = find_data_files(data_file_pattern)
dataset = grain.MapDataset.source(data_files)
if shuffle:
dataset = dataset.shuffle(seed=shuffle_seed)
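As a reference for the branch added above, here is a minimal standalone sketch of the pattern-string parsing and weight normalization it performs. `parse_blended_patterns` is a hypothetical helper name used only for illustration; the snippet needs neither Grain nor any data files.

```
def parse_blended_patterns(data_file_pattern):
  """Split "pattern_a:0.3;pattern_b:0.7" into patterns and normalized float weights."""
  # Assumes the patterns themselves contain no ':' (as in the local-path examples in base.yml).
  patterns, weights = zip(*[entry.split(":") for entry in data_file_pattern.split(";")])
  weights = [float(w) for w in weights]
  total = sum(weights)
  return list(patterns), [round(w / total, 4) for w in weights]


patterns, weights = parse_blended_patterns(
    "path/to/data1.array_record*:1;path/to/data2.array_record*:1;path/to/data3.array_record*:1"
)
print(patterns)  # ['path/to/data1.array_record*', 'path/to/data2.array_record*', 'path/to/data3.array_record*']
print(weights)   # [0.3333, 0.3333, 0.3333]
```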
2 changes: 1 addition & 1 deletion MaxText/input_pipeline/_input_pipeline_utils.py
@@ -332,7 +332,7 @@ def map(self, element):

@dataclasses.dataclass
class Rekey(grain.MapTransform):
"""Rname keys according to a mappign dict"""
"""Rename keys according to a mapping dict"""

def __init__(self, mapping_dict, keep_old_keys=False):
self.mapping_dict = mapping_dict
37 changes: 37 additions & 0 deletions MaxText/tests/grain_data_processing_test.py
@@ -105,6 +105,43 @@ def get_first_batch(iterator):
self.assertTrue((train_batch1["targets"] == train_batch2["targets"]).all())


class GrainArrayRecordProcessingWithMultiSourceBlendingTest(GrainArrayRecordProcessingTest):
def setUp(self):
super().setUp()
temp_dir = tempfile.gettempdir()
# We use the same dataset for testing, but you can use different datasets by changing the file patterns.
grain_train_files = [
f"{temp_dir}/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*:0.3",
f"{temp_dir}/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*:0.7",
]
grain_train_files = ";".join(grain_train_files)
self.config = pyconfig.initialize(
[sys.argv[0], os.path.join(PKG_DIR, "configs", "base.yml")],
per_device_batch_size=1,
run_name="test",
mesh_axes=["data"],
logical_axis_rules=[["batch", "data"]],
data_sharding=["data"],
base_output_directory="gs://max-experiments/",
dataset_type="grain",
grain_train_files=grain_train_files,
tokenizer_path=os.path.join(os.path.dirname(PKG_DIR), "assets", "tokenizer"),
enable_checkpointing=False,
)
self.mesh_shape_1d = (len(jax.devices()),)
self.mesh = Mesh(mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes)
self.process_indices = input_pipeline_interface.get_process_loading_real_data(
self.config.data_sharding,
self.config.global_batch_size_to_load,
self.config.global_batch_size_to_train_on,
self.config.max_target_length,
self.mesh,
)
self.train_iter = _grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices)




class GrainParquetProcessingTest(unittest.TestCase):

@classmethod
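The test class above drives the full pipeline with a two-way blended `grain_train_files` string. As a smaller, hypothetical illustration of what `grain.MapDataset.mix` does with such weights, the sketch below blends two toy in-memory sources; the `import grain.python as grain` alias and the use of plain Python lists as sources are assumptions for illustration, not taken from this PR.

```
import grain.python as grain  # assumed import alias

# Two toy sources standing in for the ArrayRecord datasets.
ds_a = grain.MapDataset.source(["a"] * 100)
ds_b = grain.MapDataset.source(["b"] * 100)

# Blend them 30% / 70%, mirroring grain_train_files="...:0.3;...:0.7".
mixed = grain.MapDataset.mix([ds_a, ds_b], [0.3, 0.7])

# Elements are interleaved deterministically, roughly in proportion to the weights.
print([mixed[i] for i in range(10)])
```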
15 changes: 13 additions & 2 deletions getting_started/Data_Input_Pipeline.md
@@ -102,7 +102,18 @@ bash setup_gcsfuse.sh DATASET_GCS_BUCKET=$BUCKET_NAME MOUNT_PATH=$MOUNT_PATH [FI
```
3. Set `dataset_type=grain` and set `grain_train_files` to match the ArrayRecord files via a local path since the bucket has been mounted.
4. Tune `grain_worker_count` for performance. This parameter controls the number of child processes used by Grain (more details in [behind_the_scene](https://github.com/google/grain/blob/main/docs/behind_the_scenes.md), [code](https://github.com/google/grain/blob/main/grain/_src/python/grain_pool.py)). If you use a large number of workers, please check your gcsfuse config in [setup_gcsfuse.sh](https://github.com/google/maxtext/blob/main/setup_gcsfuse.sh) to avoid gcsfuse throttling.
5. Example command:

5. For multi-source blending, you can specify multiple data sources with their respective weights, using a semicolon (;) to separate sources and a colon (:) to attach each weight. The weights are automatically normalized to sum to 1.0. For example:
```
# Blend two data sources with 30% from first source and 70% from second source
grain_train_files=/tmp/gcsfuse/dataset1.array_record*:0.3;/tmp/gcsfuse/dataset2.array_record*:0.7

# Blend three data sources with equal weights (will be normalized to 0.33 each)
grain_train_files=/tmp/gcsfuse/dataset1.array_record*:1;/tmp/gcsfuse/dataset2.array_record*:1;/tmp/gcsfuse/dataset3.array_record*:1
```
Note: When using multiple data sources, only ArrayRecord format is supported.

6. Example command:
```
bash setup_gcsfuse.sh \
DATASET_GCS_BUCKET=maxtext-dataset \
@@ -114,7 +125,7 @@ grain_file_type=arrayrecord \
grain_train_files=/tmp/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record* \
grain_worker_count=2
```
6. Using validation set for eval
7. Using validation set for eval
When setting eval_interval > 0, eval will be run with a specified eval dataset. Example config:
```
eval_interval: 10000