@@ -23,8 +23,10 @@
                                           InternVisionModel,
                                           InternVLChatConfig,
                                           InternVLChatModel)
-from internvl.patch import (replace_llama2_attn_with_flash_attn,
-                            replace_llama_rmsnorm_with_fused_rmsnorm)
+from internvl.patch import (pad_data_collator,
+                            replace_llama2_attn_with_flash_attn,
+                            replace_llama_rmsnorm_with_fused_rmsnorm,
+                            replace_train_sampler)
 from internvl.train.dataset import (TCSLoader, WeightedConcatDataset,
                                     build_transform)
 from PIL import Image, ImageFile, PngImagePlugin
@@ -39,6 +41,7 @@
 # Upgrade transformers to v4.36.2, we don't need it anymore
 # replace_llama2_attn_with_flash_attn()
 replace_llama_rmsnorm_with_fused_rmsnorm()
+replace_train_sampler()
 
 try:
     from petrel_client.client import Client
@@ -182,6 +185,7 @@ def preprocess(
     tokenizer: transformers.PreTrainedTokenizer,
     num_image_token: int,
     text_only: bool = False,
+    group_by_length: bool = False,
 ) -> Dict:
     conv = get_conv_template(template_name)
     roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}
@@ -213,7 +217,7 @@ def preprocess(
     input_ids = tokenizer(
         conversations,
         return_tensors='pt',
-        padding='max_length',
+        padding=False if group_by_length else 'max_length',
         max_length=tokenizer.model_max_length,
         truncation=True,
     ).input_ids
@@ -283,6 +287,7 @@ def preprocess_mpt(
     tokenizer: transformers.PreTrainedTokenizer,
     num_image_token: int,
     text_only: bool = False,
+    group_by_length: bool = False,
 ) -> Dict:
     conv = get_conv_template(template_name)
     roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}
@@ -314,7 +319,7 @@ def preprocess_mpt(
     input_ids = tokenizer(
         conversations,
         return_tensors='pt',
-        padding='max_length',
+        padding=False if group_by_length else 'max_length',
         max_length=tokenizer.model_max_length,
         truncation=True,
     ).input_ids
@@ -368,7 +373,7 @@ class LazySupervisedDataset(Dataset):
     """Dataset for supervised fine-tuning."""
 
     def __init__(self, template_name, meta, tokenizer, tcs_loader, num_image_token,
-                 image_size=224, is_train=True, pad2square=False):
+                 image_size=224, is_train=True, pad2square=False, group_by_length=False):
         super(LazySupervisedDataset, self).__init__()
         self.tokenizer = tokenizer
         self.template_name = template_name
@@ -384,6 +389,21 @@ def __init__(self, template_name, meta, tokenizer, tcs_loader, num_image_token,
         self.root = meta['root']
         self.cached_data_dict = {}
         self.tcs_loader = tcs_loader
+        self.group_by_length = group_by_length
+        if self.group_by_length:
+            self.conv2length = {}
+            self.length = []
+            for data_item in self.raw_data:
+                conversations = ''.join(data_item.split('conversations')[1:])
+                str_length = len(conversations)
+                if str_length not in self.conv2length:
+                    token_length = tokenizer(
+                        conversations, return_tensors='pt', padding=False, truncation=False,
+                    ).input_ids.size(1)
+                    self.conv2length[str_length] = token_length
+                else:
+                    token_length = self.conv2length[str_length]
+                self.length.append(token_length)
 
     def __len__(self):
         return len(self.raw_data)
@@ -405,7 +425,7 @@ def multi_modal_get_item(self, data_item):
         else:
             preprocess_function = preprocess
         ret = preprocess_function(self.template_name, [deepcopy(data_item['conversations'])],
-                                  self.tokenizer, self.num_image_token)
+                                  self.tokenizer, self.num_image_token, group_by_length=self.group_by_length)
         ret = dict(
             input_ids=ret['input_ids'][0],
             labels=ret['labels'][0],
@@ -425,7 +445,8 @@ def pure_text_get_item(self, data_item):
         else:
             preprocess_function = preprocess
         ret = preprocess_function(self.template_name, [deepcopy(data_item['conversations'])],
-                                  self.tokenizer, self.num_image_token, text_only=True)
+                                  self.tokenizer, self.num_image_token, text_only=True,
+                                  group_by_length=self.group_by_length)
         ret = dict(
             input_ids=ret['input_ids'][0],
             labels=ret['labels'][0],
@@ -455,7 +476,7 @@ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
         return ret
 
 
-def build_datasets(data_args, tokenizer, tcs_loader, model):
+def build_datasets(data_args, tokenizer, tcs_loader, model, group_by_length=False):
     datasets = []
     lengths = []
     ds_collections = json.loads(open(data_args.meta_path).read())
@@ -469,7 +490,8 @@ def build_datasets(data_args, tokenizer, tcs_loader, model):
                 num_image_token=model.num_image_token,
                 image_size=data_args.force_image_size,
                 is_train=ds_collections[ds_name]['data_augment'],
-                pad2square=data_args.pad2square
+                pad2square=data_args.pad2square,
+                group_by_length=group_by_length
             )
         except Exception:
             logger.info(f'Error in loading dataset: {ds_name}')
@@ -623,7 +645,8 @@ def main():
     if model_args.grad_checkpoint:
         model.language_model._set_gradient_checkpointing()
 
-    train_dataset = build_datasets(data_args, tokenizer, tcs_loader, model)
+    train_dataset = build_datasets(data_args, tokenizer, tcs_loader, model,
+                                   group_by_length=training_args.group_by_length)
 
     def _freeze_params(module):
         for param in module.parameters():
@@ -672,7 +695,7 @@ def _freeze_params(module):
         train_dataset=train_dataset if training_args.do_train else None,
         eval_dataset=None,
         tokenizer=tokenizer,
-        data_collator=default_data_collator,
+        data_collator=default_data_collator if not training_args.group_by_length else pad_data_collator,
     )
 
     # Training
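
Note (not part of the commit): once group_by_length is enabled, samples are no longer padded to tokenizer.model_max_length inside the dataset, so the collator has to pad each batch to its own longest sequence. The sketch below is a minimal, hypothetical batch-level padding collator along the lines of pad_data_collator; the real implementation lives in internvl.patch and may differ in details.

# Illustrative sketch only; not the actual internvl.patch implementation.
import torch


def pad_data_collator_sketch(features, pad_id=0, label_pad_id=-100):
    # Pad input_ids/labels to the longest sequence in this batch and build a
    # matching attention_mask; pass pad_id=tokenizer.pad_token_id in practice.
    max_len = max(f['input_ids'].shape[0] for f in features)
    input_ids, labels, attention_mask = [], [], []
    for f in features:
        ids, lbl = f['input_ids'], f['labels']
        pad = max_len - ids.shape[0]
        input_ids.append(torch.cat([ids, torch.full((pad,), pad_id, dtype=ids.dtype)]))
        labels.append(torch.cat([lbl, torch.full((pad,), label_pad_id, dtype=lbl.dtype)]))
        attention_mask.append(torch.cat([torch.ones(ids.shape[0], dtype=torch.long),
                                         torch.zeros(pad, dtype=torch.long)]))
    batch = {
        'input_ids': torch.stack(input_ids),
        'labels': torch.stack(labels),
        'attention_mask': torch.stack(attention_mask),
    }
    # Fixed-shape tensors (e.g. pixel_values, image_flags) can be stacked as-is.
    for key in features[0]:
        if key not in batch:
            batch[key] = torch.stack([f[key] for f in features])
    return batch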
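A length-grouped sampler is only useful if the Trainer can see the per-sample token lengths that LazySupervisedDataset now precomputes in self.length. The sketch below shows one plausible way replace_train_sampler() might wire this up by monkey-patching Trainer._get_train_sampler; this is an assumption for illustration, not the repo's code, and it ignores the distributed-training case.

# Illustrative sketch only; the actual replace_train_sampler in internvl.patch may differ.
from torch.utils.data import RandomSampler
from transformers import Trainer
from transformers.trainer_pt_utils import LengthGroupedSampler


def _get_train_sampler(self):
    if self.args.group_by_length:
        # Hypothetical: reuse the lengths computed in LazySupervisedDataset.__init__
        lengths = getattr(self.train_dataset, 'length', None)
        return LengthGroupedSampler(
            self.args.train_batch_size * self.args.gradient_accumulation_steps,
            dataset=self.train_dataset,
            lengths=lengths,
        )
    return RandomSampler(self.train_dataset)


def replace_train_sampler_sketch():
    # Monkey-patch the Trainer so batches group samples of similar token length,
    # keeping per-batch padding (and wasted compute) small.
    Trainer._get_train_sampler = _get_train_sampler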