@@ -36,6 +36,7 @@ def __init__(
3636 system_prompt : str | None = None ,
3737 processor_kwargs : Dict [str , Any ] | None = None ,
3838 generation_kwargs : Dict [str , Any ] | None = None ,
39+ image_key : str = "image" ,
3940 ):
4041 """Initialize the HuggingFace model wrapper.
4142
@@ -46,6 +47,7 @@ def __init__(
4647 system_prompt: System prompt to use.
4748 processor_kwargs: Additional processor arguments.
4849 generation_kwargs: Additional generation arguments.
50+ image_key: The key used for image inputs in the chat template.
4951 """
5052 super ().__init__ (system_prompt = system_prompt )
5153
@@ -54,6 +56,7 @@ def __init__(
5456 self .base_model_class = model_class
5557 self .processor_kwargs = processor_kwargs or {}
5658 self .generation_kwargs = generation_kwargs or {}
59+ self .image_key = image_key
5760
5861 self .processor = self .load_processor ()
5962 self .model = self .load_model ()
@@ -106,7 +109,7 @@ def format_inputs(self, batch: TextImageBatch | TextBatch) -> Dict[str, torch.Te
106109 }
107110
108111 if with_images :
109- processor_inputs ["image" ] = [[image ] for image in image_batch ]
112+ processor_inputs [self . image_key ] = [[image ] for image in image_batch ]
110113
111114 return self .processor (** processor_inputs ).to (self .model .device ) # type: ignore
112115
0 commit comments