@@ -466,6 +466,73 @@ def get_mm_inputs(
466466 return mm_inputs
467467
468468
@dataclass
class Llama4Plugin(BasePlugin):
    """Multimodal plugin for Llama-4 style processors.

    Expands each ``IMAGE_PLACEHOLDER`` in chat messages into the model-specific
    image token sequence produced by ``processor._prompt_split_image`` and drops
    the auxiliary ``aspect_ratios`` entry from the processor outputs before they
    reach the model.
    """

    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        """Replace image placeholders in ``messages`` with expanded image tokens.

        Returns a deep copy of ``messages``; the caller's list is not mutated.

        Raises:
            ValueError: if the number of ``images`` does not match the number of
                ``IMAGE_PLACEHOLDER`` tokens found across all messages.
        """
        self._validate_input(processor, images, videos, audios)

        # Expansion metadata; stays None when token expansion is disabled or
        # when the processor produced no pixel values (e.g. no image inputs).
        num_patches_per_chunk: Optional[int] = None
        aspect_ratios = None
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            if "pixel_values" in mm_inputs:
                # NOTE(review): assumes pixel_values chunks are laid out (..., H, W).
                image_height, image_width = mm_inputs["pixel_values"][0].shape[-2:]
                num_patches_per_chunk = int(
                    (image_height // processor.patch_size)
                    * (image_width // processor.patch_size)
                    // processor.downsample_ratio
                )
                aspect_ratios = mm_inputs.pop("aspect_ratios")

        num_image_tokens = 0
        messages = deepcopy(messages)
        for message in messages:
            content = message["content"]
            placeholder_count = content.count(IMAGE_PLACEHOLDER)
            if self.expand_mm_tokens and aspect_ratios is not None:
                prompt_splits = content.split(IMAGE_PLACEHOLDER)
                new_content = []
                for local_image_index, split_part in enumerate(prompt_splits):
                    new_content.append(split_part)
                    # split() yields one more part than there are placeholders;
                    # emit an image token block after every part but the last.
                    if local_image_index < placeholder_count:
                        tokens_for_this_image = processor._prompt_split_image(
                            aspect_ratios[num_image_tokens], num_patches_per_chunk
                        )
                        num_image_tokens += 1
                        new_content.append(tokens_for_this_image)

                content = "".join(new_content)
            else:
                # Not expanding (or no image inputs were produced): still count
                # placeholders so the consistency check below remains correct
                # instead of always failing with num_image_tokens == 0.
                num_image_tokens += placeholder_count

            message["content"] = content

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        return messages

    @override
    def get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        imglens: list[int],
        vidlens: list[int],
        audlens: list[int],
        batch_ids: list[list[int]],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        """Return processor outputs for the batch, without ``aspect_ratios``.

        ``aspect_ratios`` is only needed for prompt-side token expansion in
        :meth:`process_messages`, so it is stripped from the model inputs here.
        """
        self._validate_input(processor, images, videos, audios)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        mm_inputs.pop("aspect_ratios", None)
        return mm_inputs
534+
535+
469536@dataclass
470537class LlavaPlugin (BasePlugin ):
471538 @override
@@ -1485,6 +1552,7 @@ def process_messages(
14851552PLUGINS = {
14861553 "base" : BasePlugin ,
14871554 "gemma3" : Gemma3Plugin ,
1555+ "llama4" : Llama4Plugin ,
14881556 "llava" : LlavaPlugin ,
14891557 "llava_next" : LlavaNextPlugin ,
14901558 "llava_next_video" : LlavaNextVideoPlugin ,
0 commit comments