diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index 71a47a02..1e36a21b 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -4,7 +4,7 @@
 import functools
 import time
 from functools import partial
-from typing import Callable, Iterable, Union
+from typing import Callable, Iterable, Optional, Union
 
 import torch
 import torch.distributed as dist
@@ -201,44 +201,37 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
 
 
 @llm_timeout(func_name="get_train_data_loader")
-def get_train_data_loader(
-    num_worker: int = 0, dataset_generate_func: Callable = None, train_sampler=None, train_collate_fn=None
-):
+def get_train_data_loader(num_worker: int = 0, dataset_generate_func: Optional[Callable] = None):
     """
     Generate and return the training data loader.
 
     Args:
         num_worker (:class:`int`): number of subprocesses used for dataloader.
         dataset_generate_func (:class:`Callable`, optional): generate function for dataset.
-        train_sampler (:class:`torch.utils.data.sampler`, optional): dataset sampler for training dataloader.
-        train_collate_fn (:class:`Callable`, optional): collate function for training dataloader.
 
     Returns:
         A tuple of (train_dl, dataset_types).
     """
     # Get the dataset types
-    dataset_types = None
     data_cfg = gpc.config.data
-
-    # Get the sample weight dictionary
     train_folder = data_cfg.train_folder
     dataset_types = list(get_dataset_type_ids_map(train_folder).keys())
 
-    if not train_folder:
-        dataset_types = ["en", "cn", "code"]
-        train_ds = RandomDataset(num_samples=1000000, max_len=data_cfg.seq_len)
-        if data_cfg.pack_sample_into_one:
-            train_ds = PackedDatasetWithoutCuSeqlen(
-                train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
-            )
-        else:
-            train_ds = PackedDataset(
-                train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
-            )
+    if dataset_generate_func is not None:
+        train_ds, train_sampler, train_collate_fn = dataset_generate_func()
     else:
-        if dataset_generate_func is not None:
-            train_ds = dataset_generate_func()
+        if train_folder is None:
+            dataset_types = ["en", "cn", "code"]
+            train_ds = RandomDataset(num_samples=1000000, max_len=data_cfg.seq_len)
+            if data_cfg.pack_sample_into_one:
+                train_ds = PackedDatasetWithoutCuSeqlen(
+                    train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
+                )
+            else:
+                train_ds = PackedDataset(
+                    train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
+                )
         else:
             train_ds = get_packed_dataset_without_short_length(
                 folder=data_cfg.train_folder,
@@ -249,11 +242,6 @@ def get_train_data_loader(
                 min_length_dict=data_cfg.get("min_length_dict", {}),
                 pack_into_one_sample=data_cfg.pack_sample_into_one,
             )
-
-    if dataset_generate_func is None or not train_folder:
-        # partition already completed
-        assert isinstance(train_ds, (PackedDataset, PackedDatasetWithoutCuSeqlen, ConcatDataset))
-
         # Create the training dataset sampler
         train_sampler = StaticBatchSampler(
             train_ds.datasets if isinstance(train_ds, ConcatDataset) else [train_ds],
             batch_size=data_cfg.micro_num,
@@ -264,8 +252,6 @@ def get_train_data_loader(
             data_rank=gpc.get_local_rank(ParallelMode.DATA),
             data_world_size=gpc.get_world_size(ParallelMode.DATA),
         )
-
-    if dataset_generate_func is None or not train_folder:
         train_collate_fn = partial(packed_collate_fn, packed_length=data_cfg.packed_length)
 
     # Create the training data loader
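
Under the refactored signature, a caller-supplied dataset_generate_func is expected to build all three pieces itself and return them as a (dataset, sampler, collate_fn) tuple, since the train_sampler and train_collate_fn arguments are gone. Below is a minimal sketch of such a factory, assuming it lives alongside the code in training_internlm.py so that gpc, ParallelMode, RandomDataset, PackedDataset, StaticBatchSampler and packed_collate_fn resolve as in that module; the name my_dataset_factory is hypothetical, and the StaticBatchSampler arguments not visible in this hunk are elided.

    # Hypothetical sketch: a dataset_generate_func matching the new contract,
    # returning (dataset, sampler, collate_fn).
    def my_dataset_factory():
        data_cfg = gpc.config.data
        # Any dataset works here; the synthetic RandomDataset/PackedDataset path
        # is reused purely for illustration.
        ds = RandomDataset(num_samples=1000, max_len=data_cfg.seq_len)
        ds = PackedDataset(ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length)
        sampler = StaticBatchSampler(
            [ds],
            batch_size=data_cfg.micro_num,
            # ... remaining StaticBatchSampler arguments as in the default branch ...
            data_rank=gpc.get_local_rank(ParallelMode.DATA),
            data_world_size=gpc.get_world_size(ParallelMode.DATA),
        )
        collate_fn = partial(packed_collate_fn, packed_length=data_cfg.packed_length)
        return ds, sampler, collate_fn

    # get_train_data_loader still returns (train_dl, dataset_types).
    train_dl, dataset_types = get_train_data_loader(num_worker=4, dataset_generate_func=my_dataset_factory)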