@@ -1,4 +1,3 @@
-from typing import List
 from typing import Optional
 from typing import Tuple
 from typing import Union
@@ -10,17 +9,14 @@
 from transformers.cache_utils import HybridCache
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.models.gemma3.modeling_gemma3 import Gemma3CausalLMOutputWithPast
-from transformers.utils import is_torchdynamo_compiling
 from transformers.utils import logging
-from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 logger = logging.get_logger(__name__)
 
 
-@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 def causal_forward(
     self,
     input_ids: torch.LongTensor = None,
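The dropped `@deprecate_kwarg` import and decorator were only a compatibility shim: they rerouted the old `num_logits_to_keep` kwarg to `logits_to_keep` until the rename completed in transformers 4.50. A minimal sketch of that pattern, assuming the usual rename-and-warn behavior (an illustration, not the transformers implementation):

```python
import functools
import warnings

def rename_kwarg(old_name, new_name, version):
    """Hypothetical stand-in for the removed shim: forward a renamed kwarg."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            if old_name in kwargs:
                warnings.warn(
                    f"`{old_name}` is deprecated, use `{new_name}` (removal in v{version}).",
                    FutureWarning,
                )
                # Keep old callers working by forwarding to the new name
                kwargs.setdefault(new_name, kwargs.pop(old_name))
            return fn(*args, **kwargs)
        return wrapper
    return decorator
```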
@@ -139,14 +135,13 @@ def causal_forward(
     )
 
 
-@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 def multimodal_forward(
     self,
     input_ids: torch.LongTensor = None,
     pixel_values: torch.FloatTensor = None,
     attention_mask: Optional[torch.Tensor] = None,
     position_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
+    past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
     token_type_ids: Optional[torch.LongTensor] = None,
     cache_position: Optional[torch.LongTensor] = None,
     inputs_embeds: Optional[torch.FloatTensor] = None,
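The `List` to `list` change in `past_key_values` (and `Tuple` to `tuple` in the return annotation below) follows PEP 585: on Python 3.9+ the builtin containers are subscriptable, so the `typing` aliases are redundant. Both spellings describe the same type:

```python
from typing import List

# Pre-PEP 585 spelling and its builtin-generic equivalent (Python >= 3.9)
legacy: List[int] = [1, 2, 3]
modern: list[int] = [1, 2, 3]
```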
@@ -158,21 +153,12 @@ def multimodal_forward(
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
     **lm_kwargs,
-) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
+) -> Union[tuple, Gemma3CausalLMOutputWithPast]:
     r"""
-    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-        config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
-
-    logits_to_keep (`int` or `torch.Tensor`, *optional*):
-        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
-        `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
-        token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
-        If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
-        This is useful when using packed tensor format (single dimension for batch and sequence length).
-
-    Returns:
+    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+        config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
 
     Example:
 
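The trimmed docstring drops the `logits_to_keep` description, but its semantics still drive the new slicing further down: an `int` keeps the last N positions, with `0` meaning all of them (since `slice(-0, None)` equals `slice(0, None)`), and a 1-D tensor selects explicit sequence indices for packed layouts. A quick illustration:

```python
import torch

hidden = torch.randn(2, 7, 16)  # (batch, seq_len, hidden_dim)

def keep(hidden, logits_to_keep):
    # Mirrors the slicing added in this diff
    idx = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    return hidden[:, idx, :]

print(keep(hidden, 0).shape)                     # (2, 7, 16): 0 keeps everything
print(keep(hidden, 1).shape)                     # (2, 1, 16): last token only
print(keep(hidden, torch.tensor([2, 5])).shape)  # (2, 2, 16): explicit indices
```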
@@ -181,111 +167,76 @@ def multimodal_forward(
     >>> import requests
     >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
 
-    >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf")
-    >>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf")
-
-    >>> prompt = "answer en Where is the cow standing?"
-    >>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png"
-    >>> image = Image.open(requests.get(url, stream=True).raw)
-
-    >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
-
+    >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")
+    >>> processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")
+
+    >>> messages = [
+    ...     {
+    ...         "role": "system",
+    ...         "content": [
+    ...             {"type": "text", "text": "You are a helpful assistant."}
+    ...         ]
+    ...     },
+    ...     {
+    ...         "role": "user", "content": [
+    ...             {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
+    ...             {"type": "text", "text": "Where is the cat standing?"},
+    ...         ]
+    ...     },
+    ... ]
+
+    >>> inputs = processor.apply_chat_template(
+    ...     messages,
+    ...     tokenize=True,
+    ...     return_dict=True,
+    ...     return_tensors="pt",
+    ...     add_generation_prompt=True
+    ... )
     >>> # Generate
-    >>> generate_ids = model.generate(**inputs, max_length=30)
+    >>> generate_ids = model.generate(**inputs)
     >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-    "answer en Where is the cow standing?\nbeach"
-    ```"""
-
-    if (input_ids is None) ^ (inputs_embeds is not None):
-        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+    "user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
+    ```
+    """
 
     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-    is_training = token_type_ids is not None and labels is not None
-
-    # Replace image id woth PAD if the image token if OOV, to avoid index-errors
-    if input_ids is not None and self.config.image_token_index >= self.vocab_size:
-        special_image_mask = input_ids == self.config.image_token_index
-        llm_input_ids = input_ids.clone()
-        llm_input_ids[special_image_mask] = 0
-    else:
-        llm_input_ids = input_ids
-
-    if inputs_embeds is None:
-        inputs_embeds = self.get_input_embeddings()(llm_input_ids)
-
-    if cache_position is None:
-        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-        cache_position = torch.arange(
-            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
-        )
-
-    if position_ids is None:
-        position_ids = cache_position.unsqueeze(0) + 1  # Gemma3 positions are 1-indexed
-
-    # Merge text and images
-    if pixel_values is not None:
-        image_features = self.get_image_features(pixel_values)
-
-        if input_ids is None:
-            special_image_mask = inputs_embeds == self.get_input_embeddings()(
-                torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device)
-            )
-        else:
-            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
-            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
-
-        if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
-            image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
-            raise ValueError(
-                f"Number of images does not match number of special image tokens in the input text. "
-                f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
-                "tokens from image embeddings."
-            )
-        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-
-    # mask out pad-token-ids in labels for BC
-    if labels is not None and self.pad_token_id in labels:
-        logger.warning_once(
-            "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
-            "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
-        )
-        labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
-
-    causal_mask = self._update_causal_mask(
-        attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
-    )
-    outputs = self.language_model.model(
-        attention_mask=causal_mask,
+    outputs = self.model(
+        input_ids=input_ids,
+        pixel_values=pixel_values,
+        token_type_ids=token_type_ids,
+        attention_mask=attention_mask,
         position_ids=position_ids,
         past_key_values=past_key_values,
         inputs_embeds=inputs_embeds,
         use_cache=use_cache,
+        labels=labels,
         output_attentions=output_attentions,
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
-        logits_to_keep=logits_to_keep,
         **lm_kwargs,
     )
 
     hidden_states = outputs[0]
+
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
+
     loss = None
     logits = None
-
     if skip_logits and labels is None:
         raise ValueError("skip_logits is True, but labels is None")
 
     if skip_logits is None:
         skip_logits = self.training and (labels is not None)
 
     if skip_logits:
-        shift_hidden_states = hidden_states[..., :-1, :]
+        shift_hidden_states = kept_hidden_states[..., :-1, :]
         shift_labels = labels[..., 1:]
 
         hidden_device = shift_hidden_states.device
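The large deletion above is the heart of this refactor: embedding lookup, image-feature merging, causal-mask construction, and label padding now live in `self.model` (the composed multimodal backbone in recent transformers releases), leaving this wrapper to handle only logits and loss. In the fused path, hidden states and labels are aligned shift-by-one before being handed to `LigerFusedLinearCrossEntropyLoss`:

```python
import torch

B, T, H = 2, 5, 8
hidden = torch.randn(B, T, H)
labels = torch.randint(0, 11, (B, T))

# Position t's hidden state predicts token t + 1, so drop the last
# hidden state and the first label before computing the loss.
shift_hidden = hidden[..., :-1, :]  # (B, T - 1, H)
shift_labels = labels[..., 1:]      # (B, T - 1)
```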
@@ -306,7 +257,7 @@ def multimodal_forward(
         lce = LigerFusedLinearCrossEntropyLoss()
         loss = lce(self.language_model.lm_head.weight, shift_hidden_states, shift_labels)
     else:
-        logits = self.language_model.lm_head(hidden_states)
+        logits = self.lm_head(kept_hidden_states)
         if labels is not None:
             # Upcast to float if we need to compute the loss to avoid potential precision issues
             logits = logits.float()
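For intuition, a rough eager-mode equivalent of the fused call `lce(weight, shift_hidden_states, shift_labels)` is sketched below; the point of the Liger kernel is that the full `(N, V)` logits tensor on the first line is never materialized (assumed behavior based on the fused-linear-cross-entropy design, not a claim about kernel internals):

```python
import torch
import torch.nn.functional as F

def eager_linear_cross_entropy(weight, hidden, labels):
    # weight: (V, H) lm_head matrix, hidden: (N, H) flattened tokens, labels: (N,)
    logits = hidden @ weight.t()  # materializes (N, V), the cost the fusion avoids
    return F.cross_entropy(logits.float(), labels)  # -100 targets ignored by default
```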
@@ -327,6 +278,7 @@ def multimodal_forward(
             flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
             flat_labels = shift_labels.view(-1).to(shift_logits.device)
             loss = loss_fct(flat_logits, flat_labels)
+
     if not return_dict:
         output = (logits,) + outputs[1:]
         return (loss,) + output if loss is not None else output
@@ -337,5 +289,5 @@ def multimodal_forward(
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
-        image_hidden_states=image_features if pixel_values is not None else None,
+        image_hidden_states=outputs.image_hidden_states,
     )
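The final hunk is a consequence of the same delegation: `self.model` already merges the vision features and returns them, so the wrapper forwards `outputs.image_hidden_states` instead of recomputing `image_features` locally.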