feat(train): update get_train_data_loader to make logic clearer (InternLM#498)

* update get_train_data_loader

* update get_train_data_loader, del old doc

---------

Co-authored-by: YWMditto <[email protected]>
YWMditto authored Nov 14, 2023
1 parent 2b984ff commit be5b9ea
Showing 1 changed file with 15 additions and 29 deletions.
internlm/train/training_internlm.py (44 changes: 15 additions & 29 deletions)
@@ -4,7 +4,7 @@
import functools
import time
from functools import partial
from typing import Callable, Iterable, Union
from typing import Callable, Iterable, Optional, Union

import torch
import torch.distributed as dist
@@ -201,44 +201,37 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):


@llm_timeout(func_name="get_train_data_loader")
def get_train_data_loader(
num_worker: int = 0, dataset_generate_func: Callable = None, train_sampler=None, train_collate_fn=None
):
def get_train_data_loader(num_worker: int = 0, dataset_generate_func: Optional[Callable] = None):
"""
Generate and return the training data loader.
Args:
num_worker (:class:`int`): number of subprocesses used for dataloader.
dataset_generate_func (:class:`Callable`, optional): generate function for dataset.
train_sampler (:class:`torch.utils.data.sampler`, optional): dataset sampler for training dataloader.
train_collate_fn (:class:`Callable`, optional): collate function for training dataloader.
Returns:
A tuple of (train_dl, dataset_types).
"""

# Get the dataset types
dataset_types = None
data_cfg = gpc.config.data

# Get the sample weight dictionary
train_folder = data_cfg.train_folder
dataset_types = list(get_dataset_type_ids_map(train_folder).keys())

if not train_folder:
dataset_types = ["en", "cn", "code"]
train_ds = RandomDataset(num_samples=1000000, max_len=data_cfg.seq_len)
if data_cfg.pack_sample_into_one:
train_ds = PackedDatasetWithoutCuSeqlen(
train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
)
else:
train_ds = PackedDataset(
train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
)
if dataset_generate_func is not None:
train_ds, train_sampler, train_collate_fn = dataset_generate_func()
else:
if dataset_generate_func is not None:
train_ds = dataset_generate_func()
if train_folder is None:
dataset_types = ["en", "cn", "code"]
train_ds = RandomDataset(num_samples=1000000, max_len=data_cfg.seq_len)
if data_cfg.pack_sample_into_one:
train_ds = PackedDatasetWithoutCuSeqlen(
train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
)
else:
train_ds = PackedDataset(
train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
)
else:
train_ds = get_packed_dataset_without_short_length(
folder=data_cfg.train_folder,
@@ -249,11 +242,6 @@ def get_train_data_loader(
min_length_dict=data_cfg.get("min_length_dict", {}),
pack_into_one_sample=data_cfg.pack_sample_into_one,
)

if dataset_generate_func is None or not train_folder:
# partition already completed
assert isinstance(train_ds, (PackedDataset, PackedDatasetWithoutCuSeqlen, ConcatDataset))
# Create the training dataset sampler
train_sampler = StaticBatchSampler(
train_ds.datasets if isinstance(train_ds, ConcatDataset) else [train_ds],
batch_size=data_cfg.micro_num,
@@ -264,8 +252,6 @@ def get_train_data_loader(
data_rank=gpc.get_local_rank(ParallelMode.DATA),
data_world_size=gpc.get_world_size(ParallelMode.DATA),
)

if dataset_generate_func is None or not train_folder:
train_collate_fn = partial(packed_collate_fn, packed_length=data_cfg.packed_length)

# Create the training data loader
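Taken together, the refactor narrows the signature (train_sampler and train_collate_fn are no longer parameters) and flattens the old nested branching into one top-to-bottom choice of dataset source, after which the sampler and collate function are built inside the function itself. The sketch below is a minimal, self-contained reading of that control flow, not the repository's code: the underscore-prefixed helpers are placeholders for internlm's RandomDataset/PackedDataset, get_packed_dataset_without_short_length, StaticBatchSampler, and packed_collate_fn, and the generator-first branch order is inferred from the diff's rendering, which loses exact nesting.

from functools import partial
from typing import Callable, List, Optional


def _random_packed_dataset(seq_len: int) -> List[List[int]]:
    # Placeholder for RandomDataset wrapped in PackedDataset or
    # PackedDatasetWithoutCuSeqlen (chosen by data_cfg.pack_sample_into_one).
    return [[0] * seq_len for _ in range(8)]


def _folder_packed_dataset(folder: str, seq_len: int) -> List[List[int]]:
    # Placeholder for get_packed_dataset_without_short_length(folder=...).
    return [[1] * seq_len for _ in range(8)]


def _packed_collate(batch: List[List[int]], packed_length: int) -> dict:
    # Placeholder for packed_collate_fn.
    return {"samples": batch, "packed_length": packed_length}


def get_train_data_loader_sketch(
    train_folder: Optional[str],
    seq_len: int = 2048,
    packed_length: int = 4096,
    dataset_generate_func: Optional[Callable] = None,
):
    # 1) A caller-supplied generator takes precedence (assumed order).
    if dataset_generate_func is not None:
        train_ds = dataset_generate_func()
    # 2) No train_folder: fall back to synthetic random data.
    elif train_folder is None:
        train_ds = _random_packed_dataset(seq_len)
    # 3) Otherwise build the packed dataset from the folder.
    else:
        train_ds = _folder_packed_dataset(train_folder, seq_len)

    # The collate function (and, in the real code, the StaticBatchSampler)
    # is now always constructed here rather than accepted as an argument.
    train_collate_fn = partial(_packed_collate, packed_length=packed_length)
    return train_ds, train_collate_fn


if __name__ == "__main__":
    ds, collate_fn = get_train_data_loader_sketch(train_folder=None)
    print(len(ds), collate_fn(ds[:2])["packed_length"])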

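For call sites, the visible change is the narrower API: per the updated docstring the function still returns a tuple of (train_dl, dataset_types), but a sampler or collate function can no longer be injected; only the dataset hook remains. A hypothetical call site under the new signature follows (my_dataset_gen is an illustrative user hook, and the call assumes internlm's runtime, e.g. gpc.config.data and distributed state, is already initialized):

# Illustrative only; requires an initialized internlm runtime to actually run.
from internlm.train.training_internlm import get_train_data_loader


def my_dataset_gen():
    # Hypothetical hook: should return a dataset that the internal
    # StaticBatchSampler and packed_collate_fn can consume.
    raise NotImplementedError


train_dl, dataset_types = get_train_data_loader(num_worker=4, dataset_generate_func=my_dataset_gen)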