
Commit f248529

fix Llava test-bwd failure (#639)
## Summary

fix #638

## Testing Done

- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

<details>
<summary>convergence-test log</summary>

```
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/fp32/test_mini_models.py
============================= test session starts ==============================
platform linux -- Python 3.10.12, pytest-8.3.5, pluggy-1.5.0
rootdir: /root/workspace/jp-liger
configfile: pyproject.toml
----------------------------- live log collection ------------------------------
INFO     datasets:config.py:54 PyTorch version 2.5.1+cu121 available.
collected 13 items

test/convergence/fp32/test_mini_models.py::test_mini_model[mini_llama3-32-0.0001-dtype0-1e-08-2e-05-0.0001-1e-05-0.005-1e-05]
-------------------------------- live log call ---------------------------------
WARNING  liger_kernel.transformers.monkey_patch:monkey_patch.py:209 Support for transformers versions < 4.46.1 will soon be discontinued due to issues with incorrect gradient accumulation. Please consider upgrading to avoid potential issues.
See details: huggingface/transformers#34191
PASSED [  7%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_llava-32-0.0001-dtype1-1e-08-1e-05-0.005-1e-05-0.005-1e-05]
-------------------------------- live log call ---------------------------------
WARNING  liger_kernel.transformers.monkey_patch:monkey_patch.py:209 Support for transformers versions < 4.46.1 will soon be discontinued due to issues with incorrect gradient accumulation. Please consider upgrading to avoid potential issues.
See details: huggingface/transformers#34191
WARNING  liger_kernel.transformers.monkey_patch:monkey_patch.py:267 Support for transformers versions < 4.49.0 will soon be discontinued due to issues with incorrect legacy processing. Please consider upgrading to avoid potential issues.
See details: huggingface/transformers#35526
PASSED [ 15%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_mllama-32-0.0001-dtype2-1e-08-1e-05-0.005-1e-05-0.005-1e-05] SKIPPED [ 23%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_qwen2-32-0.0001-dtype3-1e-08-1e-05-0.005-1e-05-0.005-1e-05]
-------------------------------- live log call ---------------------------------
WARNING  liger_kernel.transformers.monkey_patch:monkey_patch.py:855 Support for transformers versions < 4.46.1 will soon be discontinued due to issues with incorrect gradient accumulation. Please consider upgrading to avoid potential issues.
See details: huggingface/transformers#34191
PASSED [ 30%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_qwen2_vl-32-0.0001-dtype4-1e-05-0.1-0.005-1e-05-0.005-1e-05] SKIPPED [ 38%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_qwen2_5_vl-32-0.0001-dtype5-1e-05-0.1-0.005-1e-05-0.005-1e-05] SKIPPED [ 46%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_olmo2-32-0.0001-dtype6-1e-08-1e-05-0.005-1e-05-0.005-1e-05] SKIPPED [ 53%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_phi3-32-0.0001-dtype7-1e-08-1e-05-0.005-1e-05-0.005-1e-05]
-------------------------------- live log call ---------------------------------
WARNING  liger_kernel.transformers.monkey_patch:monkey_patch.py:1067 Support for transformers versions < 4.46.1 will soon be discontinued due to issues with incorrect gradient accumulation. Please consider upgrading to avoid potential issues.
See details: huggingface/transformers#34191
PASSED [ 61%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_mistral-32-0.0001-dtype8-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED [ 69%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_gemma1-32-0.0001-dtype9-1e-08-0.0001-0.005-1e-05-0.005-1e-05]
-------------------------------- live log call ---------------------------------
WARNING  liger_kernel.transformers.monkey_patch:monkey_patch.py:598 Support for transformers versions < 4.46.1 will soon be discontinued due to issues with incorrect gradient accumulation. Please consider upgrading to avoid potential issues.
See details: huggingface/transformers#34191
PASSED [ 76%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_gemma1.1-32-0.0001-dtype10-1e-08-0.0001-0.005-1e-05-0.005-1e-05]
-------------------------------- live log call ---------------------------------
WARNING  liger_kernel.transformers.monkey_patch:monkey_patch.py:598 Support for transformers versions < 4.46.1 will soon be discontinued due to issues with incorrect gradient accumulation. Please consider upgrading to avoid potential issues.
See details: huggingface/transformers#34191
PASSED [ 84%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_gemma2-32-0.0001-dtype11-1e-08-0.0001-0.005-1e-05-0.005-1e-05]
-------------------------------- live log call ---------------------------------
WARNING  liger_kernel.transformers.monkey_patch:monkey_patch.py:672 Support for transformers versions < 4.46.1 will soon be discontinued due to issues with incorrect gradient accumulation. Please consider upgrading to avoid potential issues.
See details: huggingface/transformers#34191
PASSED [ 92%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_granite3-32-0.0001-dtype12-1e-08-0.0001-0.005-1e-05-0.005-1e-05] SKIPPED [100%]
============== 8 passed, 5 skipped, 1 warning in 69.42s (0:01:09) ==============

HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/fp32/test_mini_models_multimodal.py
============================= test session starts ==============================
platform linux -- Python 3.10.12, pytest-8.3.5, pluggy-1.5.0
rootdir: /root/workspace/jp-liger
configfile: pyproject.toml
----------------------------- live log collection ------------------------------
INFO     datasets:config.py:54 PyTorch version 2.5.1+cu121 available.
collected 6 items

test/convergence/fp32/test_mini_models_multimodal.py::test_mini_model_multimodal[mini_qwen2_vl-32-0.0001-dtype0-1e-08-1e-05-0.005-1e-05-0.005-1e-05] SKIPPED [ 16%]
test/convergence/fp32/test_mini_models_multimodal.py::test_mini_model_multimodal[mini_llava-32-0.0001-dtype1-1e-08-1e-05-0.005-1e-05-0.005-1e-05]
-------------------------------- live log call ---------------------------------
WARNING  liger_kernel.transformers.monkey_patch:monkey_patch.py:209 Support for transformers versions < 4.46.1 will soon be discontinued due to issues with incorrect gradient accumulation. Please consider upgrading to avoid potential issues.
See details: huggingface/transformers#34191
PASSED [ 33%]
test/convergence/fp32/test_mini_models_multimodal.py::test_mini_model_multimodal[mini_qwen2_5_vl-32-0.0001-dtype2-1e-08-1e-05-0.005-1e-05-0.005-1e-05] SKIPPED [ 50%]
test/convergence/fp32/test_mini_models_multimodal.py::test_mini_model_multimodal[mini_mllama-32-0.0001-dtype3-1e-08-1e-05-0.005-1e-05-0.005-1e-05] SKIPPED [ 66%]
test/convergence/fp32/test_mini_models_multimodal.py::test_mini_model_multimodal[mini_paligemma-32-0.0001-dtype4-1e-08-1e-05-0.005-1e-05-0.005-1e-05] SKIPPED [ 83%]
test/convergence/fp32/test_mini_models_multimodal.py::test_mini_model_multimodal[mini_paligemma2-32-0.0001-dtype5-1e-08-1e-05-0.005-1e-05-0.005-1e-05] SKIPPED [100%]
================== 1 passed, 5 skipped, 2 warnings in 30.71s ===================

HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/fp32/test_mini_models_with_logits.py
============================= test session starts ==============================
platform linux -- Python 3.10.12, pytest-8.3.5, pluggy-1.5.0
rootdir: /root/workspace/jp-liger
configfile: pyproject.toml
----------------------------- live log collection ------------------------------
INFO     datasets:config.py:54 PyTorch version 2.5.1+cu121 available.
collected 13 items

test/convergence/fp32/test_mini_models_with_logits.py::test_mini_model[mini_llama3-32-0.0001-dtype0-1e-08-2e-05-0.0001-1e-05-0.005-1e-05] PASSED [  7%]
test/convergence/fp32/test_mini_models_with_logits.py::test_mini_model[mini_llava-32-0.0001-dtype1-1e-08-1e-05-0.005-1e-05-0.005-1e-05]
-------------------------------- live log call ---------------------------------
WARNING  liger_kernel.transformers.monkey_patch:monkey_patch.py:209 Support for transformers versions < 4.46.1 will soon be discontinued due to issues with incorrect gradient accumulation. Please consider upgrading to avoid potential issues.
See details: huggingface/transformers#34191
PASSED [ 15%]
test/convergence/fp32/test_mini_models_with_logits.py::test_mini_model[mini_mllama-32-0.0001-dtype2-1e-08-1e-05-0.005-1e-05-0.005-1e-05] SKIPPED [ 23%]
test/convergence/fp32/test_mini_models_with_logits.py::test_mini_model[mini_qwen2-32-0.0001-dtype3-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED [ 30%]
test/convergence/fp32/test_mini_models_with_logits.py::test_mini_model[mini_qwen2_vl-32-0.0001-dtype4-1e-08-1e-05-0.005-1e-05-0.005-1e-05] SKIPPED [ 38%]
test/convergence/fp32/test_mini_models_with_logits.py::test_mini_model[mini_qwen2_5_vl-32-0.0001-dtype5-1e-08-1e-05-0.005-1e-05-0.005-1e-05] SKIPPED [ 46%]
test/convergence/fp32/test_mini_models_with_logits.py::test_mini_model[mini_olmo2-32-0.0001-dtype6-1e-08-1e-05-0.005-1e-05-0.005-1e-05] SKIPPED [ 53%]
test/convergence/fp32/test_mini_models_with_logits.py::test_mini_model[mini_phi3-32-0.0001-dtype7-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED [ 61%]
test/convergence/fp32/test_mini_models_with_logits.py::test_mini_model[mini_mistral-32-0.0001-dtype8-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED [ 69%]
test/convergence/fp32/test_mini_models_with_logits.py::test_mini_model[mini_gemma1-32-0.0001-dtype9-1e-08-0.0001-0.005-1e-05-0.005-1e-05] PASSED [ 76%]
test/convergence/fp32/test_mini_models_with_logits.py::test_mini_model[mini_gemma1.1-32-0.0001-dtype10-1e-08-0.0001-0.005-1e-05-0.005-1e-05] PASSED [ 84%]
test/convergence/fp32/test_mini_models_with_logits.py::test_mini_model[mini_gemma2-32-0.0001-dtype11-1e-08-0.0001-0.005-1e-05-0.005-1e-05] PASSED [ 92%]
test/convergence/fp32/test_mini_models_with_logits.py::test_mini_model[mini_granite3-32-0.0001-dtype12-1e-08-0.0001-0.005-1e-05-0.005-1e-05] SKIPPED [100%]
============== 8 passed, 5 skipped, 1 warning in 68.41s (0:01:08) ==============

HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/bf16/test_mini_models.py
============================= test session starts ==============================
platform linux -- Python 3.10.12, pytest-8.3.5, pluggy-1.5.0
rootdir: /root/workspace/jp-liger
configfile: pyproject.toml
----------------------------- live log collection ------------------------------
INFO     datasets:config.py:54 PyTorch version 2.5.1+cu121 available.
collected 12 items

test/convergence/bf16/test_mini_models.py::test_mini_model[mini_llama3-32-0.0001-dtype0-0.001-0.01-0.1-0.01-0.01-0.01]
-------------------------------- live log call ---------------------------------
WARNING  liger_kernel.transformers.monkey_patch:monkey_patch.py:209 Support for transformers versions < 4.46.1 will soon be discontinued due to issues with incorrect gradient accumulation. Please consider upgrading to avoid potential issues.
See details: huggingface/transformers#34191
PASSED [  8%]
test/convergence/bf16/test_mini_models.py::test_mini_model[mini_llava-32-0.0001-dtype1-0.001-0.01-0.1-0.01-0.01-0.01]
-------------------------------- live log call ---------------------------------
WARNING  liger_kernel.transformers.monkey_patch:monkey_patch.py:209 Support for transformers versions < 4.46.1 will soon be discontinued due to issues with incorrect gradient accumulation. Please consider upgrading to avoid potential issues.
See details: huggingface/transformers#34191
WARNING  liger_kernel.transformers.monkey_patch:monkey_patch.py:267 Support for transformers versions < 4.49.0 will soon be discontinued due to issues with incorrect legacy processing. Please consider upgrading to avoid potential issues.
See details: huggingface/transformers#35526
PASSED [ 16%]
test/convergence/bf16/test_mini_models.py::test_mini_model[mini_granite3-32-0.0001-dtype2-0.001-0.01-0.1-0.01-0.01-0.01] SKIPPED [ 25%]
test/convergence/bf16/test_mini_models.py::test_mini_model[mini_mllama-32-0.0001-dtype3-0.001-0.01-0.1-0.01-0.01-0.01] SKIPPED [ 33%]
test/convergence/bf16/test_mini_models.py::test_mini_model[mini_qwen2-32-0.0001-dtype4-0.001-0.01-0.1-0.01-0.01-0.01]
-------------------------------- live log call ---------------------------------
WARNING  liger_kernel.transformers.monkey_patch:monkey_patch.py:855 Support for transformers versions < 4.46.1 will soon be discontinued due to issues with incorrect gradient accumulation. Please consider upgrading to avoid potential issues.
See details: huggingface/transformers#34191
PASSED [ 41%]
test/convergence/bf16/test_mini_models.py::test_mini_model[mini_qwen2_vl-32-0.0001-dtype5-0.001-0.05-0.1-0.01-0.01-0.01] SKIPPED [ 50%]
test/convergence/bf16/test_mini_models.py::test_mini_model[mini_qwen2_5_vl-32-0.0001-dtype6-0.001-0.05-0.1-0.01-0.01-0.01] SKIPPED [ 58%]
test/convergence/bf16/test_mini_models.py::test_mini_model[mini_phi3-32-0.0001-dtype7-0.001-0.01-0.1-0.01-0.01-0.01]
-------------------------------- live log call ---------------------------------
WARNING  liger_kernel.transformers.monkey_patch:monkey_patch.py:1067 Support for transformers versions < 4.46.1 will soon be discontinued due to issues with incorrect gradient accumulation. Please consider upgrading to avoid potential issues.
See details: huggingface/transformers#34191
PASSED [ 66%]
test/convergence/bf16/test_mini_models.py::test_mini_model[mini_mistral-32-0.0001-dtype8-0.001-0.01-0.1-0.01-0.01-0.01] PASSED [ 75%]
test/convergence/bf16/test_mini_models.py::test_mini_model[mini_olmo2-32-0.0001-dtype9-0.001-0.01-0.1-0.01-0.01-0.01] SKIPPED [ 83%]
test/convergence/bf16/test_mini_models.py::test_mini_model[mini_gemma1-32-0.0001-dtype10-0.001-0.01-0.1-0.01-0.01-0.01]
-------------------------------- live log call ---------------------------------
WARNING  liger_kernel.transformers.monkey_patch:monkey_patch.py:598 Support for transformers versions < 4.46.1 will soon be discontinued due to issues with incorrect gradient accumulation. Please consider upgrading to avoid potential issues.
See details: huggingface/transformers#34191
PASSED [ 91%]
test/convergence/bf16/test_mini_models.py::test_mini_model[mini_gemma1.1-32-0.0001-dtype11-0.001-0.01-0.1-0.01-0.01-0.01]
-------------------------------- live log call ---------------------------------
WARNING  liger_kernel.transformers.monkey_patch:monkey_patch.py:598 Support for transformers versions < 4.46.1 will soon be discontinued due to issues with incorrect gradient accumulation. Please consider upgrading to avoid potential issues.
See details: huggingface/transformers#34191
PASSED [100%]
=================== 7 passed, 5 skipped, 1 warning in 46.95s ===================

HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/bf16/test_mini_models_multimodal.py
============================= test session starts ==============================
platform linux -- Python 3.10.12, pytest-8.3.5, pluggy-1.5.0
rootdir: /root/workspace/jp-liger
configfile: pyproject.toml
----------------------------- live log collection ------------------------------
INFO     datasets:config.py:54 PyTorch version 2.5.1+cu121 available.
collected 6 items

test/convergence/bf16/test_mini_models_multimodal.py::test_mini_model_multimodal[mini_qwen2_vl-32-0.0001-dtype0-0.001-0.01-0.1-0.01-0.01-0.01] SKIPPED [ 16%]
test/convergence/bf16/test_mini_models_multimodal.py::test_mini_model_multimodal[mini_llava-32-0.0001-dtype1-0.001-0.01-0.1-0.01-0.01-0.01]
-------------------------------- live log call ---------------------------------
WARNING  liger_kernel.transformers.monkey_patch:monkey_patch.py:209 Support for transformers versions < 4.46.1 will soon be discontinued due to issues with incorrect gradient accumulation. Please consider upgrading to avoid potential issues.
See details: huggingface/transformers#34191
PASSED [ 33%]
test/convergence/bf16/test_mini_models_multimodal.py::test_mini_model_multimodal[mini_qwen2_5_vl-32-0.0001-dtype2-0.001-0.01-0.1-0.01-0.01-0.01] SKIPPED [ 50%]
test/convergence/bf16/test_mini_models_multimodal.py::test_mini_model_multimodal[mini_mllama-32-0.0001-dtype3-0.001-0.01-0.1-0.01-0.01-0.01] SKIPPED [ 66%]
test/convergence/bf16/test_mini_models_multimodal.py::test_mini_model_multimodal[mini_paligemma-32-0.0001-dtype4-0.001-0.01-0.1-0.01-0.01-0.01] SKIPPED [ 83%]
test/convergence/bf16/test_mini_models_multimodal.py::test_mini_model_multimodal[mini_paligemma2-32-0.0001-dtype5-0.001-0.01-0.1-0.01-0.01-0.01] SKIPPED [100%]
================== 1 passed, 5 skipped, 2 warnings in 19.27s ===================

HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/bf16/test_mini_models_with_logits.py
============================= test session starts ==============================
platform linux -- Python 3.10.12, pytest-8.3.5, pluggy-1.5.0
rootdir: /root/workspace/jp-liger
configfile: pyproject.toml
----------------------------- live log collection ------------------------------
INFO     datasets:config.py:54 PyTorch version 2.5.1+cu121 available.
collected 12 items

test/convergence/bf16/test_mini_models_with_logits.py::test_mini_model[mini_llama3-32-0.0001-dtype0-0.001-0.01-0.1-0.01-0.01-0.01] PASSED [  8%]
test/convergence/bf16/test_mini_models_with_logits.py::test_mini_model[mini_llava-32-0.0001-dtype1-0.001-0.01-0.1-0.01-0.01-0.01]
-------------------------------- live log call ---------------------------------
WARNING  liger_kernel.transformers.monkey_patch:monkey_patch.py:209 Support for transformers versions < 4.46.1 will soon be discontinued due to issues with incorrect gradient accumulation. Please consider upgrading to avoid potential issues.
See details: huggingface/transformers#34191
PASSED [ 16%]
test/convergence/bf16/test_mini_models_with_logits.py::test_mini_model[mini_granite3-32-0.0001-dtype2-0.001-0.01-0.1-0.01-0.01-0.01] SKIPPED [ 25%]
test/convergence/bf16/test_mini_models_with_logits.py::test_mini_model[mini_mllama-32-0.0001-dtype3-0.001-0.01-0.1-0.01-0.01-0.01] SKIPPED [ 33%]
test/convergence/bf16/test_mini_models_with_logits.py::test_mini_model[mini_qwen2-32-0.0001-dtype4-0.001-0.01-0.1-0.01-0.01-0.01] PASSED [ 41%]
test/convergence/bf16/test_mini_models_with_logits.py::test_mini_model[mini_qwen2_vl-32-0.0001-dtype5-0.001-0.01-0.1-0.01-0.01-0.01] SKIPPED [ 50%]
test/convergence/bf16/test_mini_models_with_logits.py::test_mini_model[mini_qwen2_5_vl-32-0.0001-dtype6-0.001-0.01-0.1-0.01-0.01-0.01] SKIPPED [ 58%]
test/convergence/bf16/test_mini_models_with_logits.py::test_mini_model[mini_phi3-32-0.0001-dtype7-0.001-0.01-0.1-0.01-0.01-0.01] PASSED [ 66%]
test/convergence/bf16/test_mini_models_with_logits.py::test_mini_model[mini_mistral-32-0.0001-dtype8-0.001-0.01-0.1-0.01-0.01-0.01] PASSED [ 75%]
test/convergence/bf16/test_mini_models_with_logits.py::test_mini_model[mini_gemma1-32-0.0001-dtype9-0.001-0.01-0.1-0.01-0.01-0.01] PASSED [ 83%]
test/convergence/bf16/test_mini_models_with_logits.py::test_mini_model[mini_gemma1.1-32-0.0001-dtype10-0.001-0.01-0.1-0.01-0.01-0.01] PASSED [ 91%]
test/convergence/bf16/test_mini_models_with_logits.py::test_mini_model[mini_olmo2-32-0.0001-dtype11-0.001-0.01-0.1-0.01-0.01-0.01] SKIPPED [100%]
=================== 7 passed, 5 skipped, 1 warning in 50.33s ===================
```
</details>

## env

```
transformers   4.44.2
torch          2.5.1+cu121
torchaudio     2.5.1+cu121
torchvision    0.20.1+cu121
```
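The `monkey_patch.py` warnings in the log reflect that Liger chooses between a current and a deprecated forward path based on the installed transformers version (here 4.44.2, below both the 4.46.1 and 4.49.0 thresholds mentioned in the warnings). A minimal sketch of that kind of version-gated dispatch — the function names and the 4.49.0 cutoff are illustrative assumptions, not Liger's actual API:

```python
# Hypothetical sketch of version-gated dispatch, similar in spirit to how
# liger_kernel's monkey patching picks a forward implementation.
# The 4.49.0 cutoff is taken from the legacy-processing warning in the log.

def parse_version(v: str) -> tuple:
    """Parse a dotted version like '4.44.2' or '2.5.1+cu121' into a tuple."""
    return tuple(int(p) for p in v.split("+")[0].split(".")[:3])

def pick_forward(transformers_version: str) -> str:
    """Return which (hypothetical) forward path would be patched in."""
    if parse_version(transformers_version) < parse_version("4.49.0"):
        return "lce_forward_deprecated"  # legacy-processing path
    return "lce_forward"                 # current path

print(pick_forward("4.44.2"))  # lce_forward_deprecated
```

With the env from this PR (transformers 4.44.2), the deprecated path is the one exercised, which is why the fix below touches `lce_forward_deprecated`.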
1 parent f49da8e · commit f248529

File tree

3 files changed (+22, -34 lines)


src/liger_kernel/transformers/model/llava.py (20 additions, 34 deletions)

```diff
@@ -8,7 +8,6 @@
 from transformers.models.llava.modeling_llava import _CONFIG_FOR_DOC
 from transformers.models.llava.modeling_llava import LLAVA_INPUTS_DOCSTRING
 from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast
-from transformers.models.llava.modeling_llava import logger
 from transformers.utils import add_start_docstrings_to_model_forward
 from transformers.utils import is_torchdynamo_compiling
 from transformers.utils import replace_return_docstrings
@@ -34,8 +33,6 @@ def lce_forward_deprecated(
     output_attentions: Optional[bool] = None,
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
-    cache_position: Optional[torch.LongTensor] = None,
-    num_logits_to_keep: int = 0,
 ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
     r"""
     Args:
@@ -96,39 +93,32 @@ def lce_forward_deprecated(
             "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
         )
 
-    legacy_processing = False
     if inputs_embeds is None:
+        # 1. Extra the input embeddings
         inputs_embeds = self.get_input_embeddings()(input_ids)
 
-        # if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing
-        # not very reliable, but we don't expect one to actually pass 500+ images for one prompt
-        # In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True
-        legacy_processing = (
-            (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
-        ) or (input_ids.shape[-1] == 1 and pixel_values is not None)
-
-    image_features = None
-    if pixel_values is not None:
-        image_features = self.get_image_features(
-            pixel_values=pixel_values,
-            vision_feature_layer=vision_feature_layer,
-            vision_feature_select_strategy=vision_feature_select_strategy,
-        )
-
-    if legacy_processing and image_features is not None:
-        logger.warning_once(
-            "Expanding inputs for image tokens in LLaVa should be done in processing. "
-            "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
-            "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
-            "Using processors without these attributes in the config is deprecated and will throw an error in v4.50."
-        )
-        # prefill stage vs decoding stage (legacy behavior copied)
-        if input_ids.shape[1] != 1:
+        # 2. Merge text and images
+        if pixel_values is not None and input_ids.shape[1] != 1:
+            image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
+            # this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated.
+            selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
+
+            if vision_feature_select_strategy == "default":
+                selected_image_feature = selected_image_feature[:, 1:]
+            elif vision_feature_select_strategy == "full":
+                selected_image_feature = selected_image_feature
+            else:
+                raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
+
+            image_features = self.multi_modal_projector(selected_image_feature)
+            inputs_embeds = inputs_embeds.to(image_features.dtype)
             inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
                 image_features, inputs_embeds, input_ids, attention_mask, labels
             )
-            cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
-        else:
+
+        # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
+        # generation with cache
+        elif past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
             # Retrieve the first layer to inspect the logits and mask out the hidden states
             # that are set to 0
             first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
@@ -158,7 +148,6 @@ def lce_forward_deprecated(
 
             attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
             position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
-            cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:]
 
     # TODO: @raushan retain only the new behavior after v4.47
     elif image_features is not None:
@@ -184,8 +173,6 @@ def lce_forward_deprecated(
         output_attentions=output_attentions,
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
-        cache_position=cache_position,
-        num_logits_to_keep=num_logits_to_keep,
     )
     hidden_states = outputs[0]
 
@@ -220,7 +207,6 @@ def lce_forward_deprecated(
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
-        image_hidden_states=image_features if pixel_values is not None else None,
     )
```
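The restored deprecated path selects one hidden-state layer from the vision tower and, under the `"default"` strategy, drops the leading CLS token before projecting the patch features. A toy sketch of that selection logic, using plain lists in place of tensors:

```python
# Toy illustration of the vision_feature_select_strategy branch restored by
# this commit; real code slices a torch tensor, here we slice a list.

def select_image_feature(hidden_state, strategy):
    """hidden_state: per-token features for one image; index 0 is the CLS token."""
    if strategy == "default":
        return hidden_state[1:]   # drop the CLS token, keep patch tokens
    elif strategy == "full":
        return hidden_state       # keep every token, CLS included
    raise ValueError(f"Unexpected select feature strategy: {strategy}")

tokens = ["cls", "patch_0", "patch_1", "patch_2"]
print(select_image_feature(tokens, "default"))  # ['patch_0', 'patch_1', 'patch_2']
```

In the real model the surviving tokens then pass through `multi_modal_projector` and are merged into the text embeddings.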

test/convergence/bf16/test_mini_models_multimodal.py (1 addition, 0 deletions)

```diff
@@ -492,6 +492,7 @@ def create_processor(model_name: str):
     )
 
     fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer_base, **tokenizer_config)
+    fast_tokenizer.model_input_names = ["input_ids", "attention_mask"]
     image_processor = CLIPImageProcessor(**image_processor_config)
 
     return LlavaProcessor(**processor_config, image_processor=image_processor, tokenizer=fast_tokenizer)
```

test/convergence/fp32/test_mini_models_multimodal.py (1 addition, 0 deletions)

```diff
@@ -492,6 +492,7 @@ def create_processor(model_name: str):
     )
 
    fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer_base, **tokenizer_config)
+    fast_tokenizer.model_input_names = ["input_ids", "attention_mask"]
     image_processor = CLIPImageProcessor(**image_processor_config)
 
     return LlavaProcessor(**processor_config, image_processor=image_processor, tokenizer=fast_tokenizer)
```
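The one-line test fix works because a processor forwards only the tokenizer outputs listed in `model_input_names`; pinning it to `["input_ids", "attention_mask"]` keeps stray keys such as `token_type_ids` from reaching a forward() that does not accept them. A toy stand-in for that filtering behavior (the class below is hypothetical, not the transformers API):

```python
# Toy stand-in showing why pinning model_input_names helps: only the keys
# named there survive encoding, so the model never sees unexpected kwargs.

class ToyTokenizer:
    def __init__(self):
        # A fast tokenizer built from a tokenizer object may default to
        # emitting token_type_ids as well.
        self.model_input_names = ["input_ids", "token_type_ids", "attention_mask"]

    def __call__(self, text):
        encoded = {
            "input_ids": [101, 7592, 102],
            "token_type_ids": [0, 0, 0],
            "attention_mask": [1, 1, 1],
        }
        # Keep only the keys the model is declared to accept.
        return {k: v for k, v in encoded.items() if k in self.model_input_names}

tok = ToyTokenizer()
tok.model_input_names = ["input_ids", "attention_mask"]  # mirrors the test fix
print(sorted(tok("hi").keys()))  # ['attention_mask', 'input_ids']
```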

0 commit comments