
Commit 0e37920

fix(gemma3): update gemma3 multimodal forward implementation (#787)
## Summary

Fix #786, #774.

## Testing Done

- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

---------

Signed-off-by: Tcc0403 <[email protected]>
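For reviewers skimming the diff below: the patched `multimodal_forward` now delegates embedding merging, masking, and position handling to `self.model(...)`, keeping only the `logits_to_keep` slicing and the loss branching in the Liger override. The following is a minimal, self-contained sketch of that branching. The helper is hypothetical, and the real code uses `LigerFusedLinearCrossEntropyLoss` rather than the naive matmul-plus-cross-entropy stand-in used here:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def sketch_loss_or_logits(hidden_states, lm_head, labels=None, logits_to_keep=0, skip_logits=None, training=True):
    """Toy rendition of the post-backbone logic in the patched forward.

    hidden_states: [batch, seq, hidden] backbone output; lm_head: nn.Linear
    to vocab size. The real implementation calls
    LigerFusedLinearCrossEntropyLoss instead of the stand-in below.
    """
    # logits_to_keep == 0 keeps every position, since slice(-0, None) == slice(0, None)
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    kept = hidden_states[:, slice_indices, :]

    if skip_logits is None:
        skip_logits = training and labels is not None  # fuse only when a loss is wanted

    if skip_logits:
        # Fused-style path: shift, then go from hidden states straight to a loss;
        # the fused kernel's point is to avoid materializing [batch, seq, vocab] logits.
        shift_hidden = kept[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        stand_in_logits = shift_hidden @ lm_head.weight.T  # naive stand-in for the fused kernel
        loss = F.cross_entropy(
            stand_in_logits.view(-1, stand_in_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100,
        )
        return loss, None

    # Inference-style path: materialize logits for the kept positions only
    return None, lm_head(kept)


lm_head = nn.Linear(16, 32, bias=False)
h = torch.randn(2, 5, 16)
labels = torch.randint(0, 32, (2, 5))
loss, _ = sketch_loss_or_logits(h, lm_head, labels=labels)     # training: loss, no logits
_, logits = sketch_loss_or_logits(h, lm_head, training=False)  # inference: logits for all positions
```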
1 parent ecdf6de commit 0e37920

File tree

2 files changed: +50 -98 lines changed

src/liger_kernel/transformers/model/gemma3.py

Lines changed: 49 additions & 97 deletions
@@ -1,4 +1,3 @@
-from typing import List
 from typing import Optional
 from typing import Tuple
 from typing import Union
@@ -10,17 +9,14 @@
 from transformers.cache_utils import HybridCache
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.models.gemma3.modeling_gemma3 import Gemma3CausalLMOutputWithPast
-from transformers.utils import is_torchdynamo_compiling
 from transformers.utils import logging
-from transformers.utils.deprecation import deprecate_kwarg

 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss

 logger = logging.get_logger(__name__)


-@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 def causal_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -139,14 +135,13 @@ def causal_forward(
     )


-@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 def multimodal_forward(
     self,
     input_ids: torch.LongTensor = None,
     pixel_values: torch.FloatTensor = None,
     attention_mask: Optional[torch.Tensor] = None,
     position_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
+    past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
     token_type_ids: Optional[torch.LongTensor] = None,
     cache_position: Optional[torch.LongTensor] = None,
     inputs_embeds: Optional[torch.FloatTensor] = None,
@@ -158,21 +153,12 @@ def multimodal_forward(
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
     **lm_kwargs,
-) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
+) -> Union[tuple, Gemma3CausalLMOutputWithPast]:
     r"""
-    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-        config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
-
-    logits_to_keep (`int` or `torch.Tensor`, *optional*):
-        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
-        `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
-        token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
-        If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
-        This is useful when using packed tensor format (single dimension for batch and sequence length).
-
-    Returns:
+    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+        config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.

     Example:

@@ -181,111 +167,76 @@ def multimodal_forward(
     >>> import requests
     >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration

-    >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf")
-    >>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf")
-
-    >>> prompt = "answer en Where is the cow standing?"
-    >>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png"
-    >>> image = Image.open(requests.get(url, stream=True).raw)
-
-    >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
-
+    >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")
+    >>> processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")
+
+    >>> messages = [
+    ...     {
+    ...         "role": "system",
+    ...         "content": [
+    ...             {"type": "text", "text": "You are a helpful assistant."}
+    ...         ]
+    ...     },
+    ...     {
+    ...         "role": "user", "content": [
+    ...             {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
+    ...             {"type": "text", "text": "Where is the cat standing?"},
+    ...         ]
+    ...     },
+    ... ]
+
+    >>> inputs = processor.apply_chat_template(
+    ...     messages,
+    ...     tokenize=True,
+    ...     return_dict=True,
+    ...     return_tensors="pt",
+    ...     add_generation_prompt=True
+    ... )
     >>> # Generate
-    >>> generate_ids = model.generate(**inputs, max_length=30)
+    >>> generate_ids = model.generate(**inputs)
     >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-    "answer en Where is the cow standing?\nbeach"
-    ```"""
-
-    if (input_ids is None) ^ (inputs_embeds is not None):
-        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+    "user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
+    ```
+    """

     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-    is_training = token_type_ids is not None and labels is not None
-
-    # Replace image id woth PAD if the image token if OOV, to avoid index-errors
-    if input_ids is not None and self.config.image_token_index >= self.vocab_size:
-        special_image_mask = input_ids == self.config.image_token_index
-        llm_input_ids = input_ids.clone()
-        llm_input_ids[special_image_mask] = 0
-    else:
-        llm_input_ids = input_ids
-
-    if inputs_embeds is None:
-        inputs_embeds = self.get_input_embeddings()(llm_input_ids)
-
-    if cache_position is None:
-        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-        cache_position = torch.arange(
-            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
-        )
-
-    if position_ids is None:
-        position_ids = cache_position.unsqueeze(0) + 1  # Gemma3 positions are 1-indexed
-
-    # Merge text and images
-    if pixel_values is not None:
-        image_features = self.get_image_features(pixel_values)
-
-        if input_ids is None:
-            special_image_mask = inputs_embeds == self.get_input_embeddings()(
-                torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device)
-            )
-        else:
-            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
-            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
-
-        if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
-            image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
-            raise ValueError(
-                f"Number of images does not match number of special image tokens in the input text. "
-                f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
-                "tokens from image embeddings."
-            )
-        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-
-    # mask out pad-token-ids in labels for BC
-    if labels is not None and self.pad_token_id in labels:
-        logger.warning_once(
-            "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
-            "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
-        )
-        labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
-
-    causal_mask = self._update_causal_mask(
-        attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
-    )
-    outputs = self.language_model.model(
-        attention_mask=causal_mask,
+    outputs = self.model(
+        input_ids=input_ids,
+        pixel_values=pixel_values,
+        token_type_ids=token_type_ids,
+        attention_mask=attention_mask,
         position_ids=position_ids,
         past_key_values=past_key_values,
         inputs_embeds=inputs_embeds,
         use_cache=use_cache,
+        labels=labels,
         output_attentions=output_attentions,
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
-        logits_to_keep=logits_to_keep,
         **lm_kwargs,
     )

     hidden_states = outputs[0]
+
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
+
     loss = None
     logits = None
-
     if skip_logits and labels is None:
         raise ValueError("skip_logits is True, but labels is None")

     if skip_logits is None:
         skip_logits = self.training and (labels is not None)

     if skip_logits:
-        shift_hidden_states = hidden_states[..., :-1, :]
+        shift_hidden_states = kept_hidden_states[..., :-1, :]
         shift_labels = labels[..., 1:]

         hidden_device = shift_hidden_states.device
@@ -306,7 +257,7 @@ def multimodal_forward(
         lce = LigerFusedLinearCrossEntropyLoss()
         loss = lce(self.language_model.lm_head.weight, shift_hidden_states, shift_labels)
     else:
-        logits = self.language_model.lm_head(hidden_states)
+        logits = self.lm_head(kept_hidden_states)
         if labels is not None:
             # Upcast to float if we need to compute the loss to avoid potential precision issues
             logits = logits.float()
@@ -327,6 +278,7 @@ def multimodal_forward(
             flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
             flat_labels = shift_labels.view(-1).to(shift_logits.device)
             loss = loss_fct(flat_logits, flat_labels)
+
     if not return_dict:
         output = (logits,) + outputs[1:]
         return (loss,) + output if loss is not None else output
@@ -337,5 +289,5 @@ def multimodal_forward(
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
-        image_hidden_states=image_features if pixel_values is not None else None,
+        image_hidden_states=outputs.image_hidden_states,
     )
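One detail worth noting in the hunks above: for an `int` value of `logits_to_keep`, the new code builds `slice(-logits_to_keep, None)`, and since `-0 == 0` in Python, the default `logits_to_keep=0` keeps the full sequence. A quick self-contained check of that behavior (toy tensors, not library code):

```python
import torch

hidden_states = torch.randn(2, 5, 16)  # [batch, seq, hidden]

for logits_to_keep in (0, 1, 3):
    # slice(-0, None) == slice(0, None), so 0 keeps every position
    slice_indices = slice(-logits_to_keep, None)
    kept = hidden_states[:, slice_indices, :]
    assert kept.shape[1] == (5 if logits_to_keep == 0 else logits_to_keep)
```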

test/convergence/bf16/test_mini_models_multimodal.py

Lines changed: 1 addition & 1 deletion
@@ -1022,7 +1022,7 @@ def run_mini_model_multimodal(
             5e-2,
             5e-2,
             1e-1,
-            1e-1,
+            1e-2,
             1e-2,
             1e-2,
             marks=[
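The changed entry is one of the tolerance values fed into the convergence test's parameterization. Roughly speaking (an illustrative toy check, not the test's actual harness), tightening an absolute tolerance from 1e-1 to 1e-2 means the compared tensors must now agree an order of magnitude more closely:

```python
import torch

expected = torch.randn(4, 8)
actual = expected + 5e-3  # deviation well under 1e-2

assert torch.allclose(actual, expected, atol=1e-2)  # passes the tightened bound
assert torch.allclose(actual, expected, atol=1e-1)  # trivially passes the old bound
```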
