From e27465c8011998c051403b22b7ccceca15de37db Mon Sep 17 00:00:00 2001 From: Jonathan Mamou Date: Thu, 5 Dec 2024 18:07:33 +0200 Subject: [PATCH 1/2] Adaptive dynamic number of speculative tokens (#34156) * initial commit * update strategy * add tradeoff FPR TPR with cost * all probs * fix * fix * fix style * Update src/transformers/generation/configuration_utils.py shorter docstring Co-authored-by: Joao Gante * import guard * fix style * add is_sklearn_available condition * vectorizing to flatten the for-loop * fix style * disable adaptation for UAG * update doc * add TestAssistedCandidateGeneratorUpdateStrategy * fix style * protect import * fix style --------- Co-authored-by: Joao Gante --- docs/source/en/generation_strategies.md | 2 + .../generation/candidate_generator.py | 57 +++++++++ .../generation/configuration_utils.py | 4 +- tests/generation/test_utils.py | 116 +++++++++++++++++- 4 files changed, 177 insertions(+), 2 deletions(-) diff --git a/docs/source/en/generation_strategies.md b/docs/source/en/generation_strategies.md index 380b39fe62a..47032a2a292 100644 --- a/docs/source/en/generation_strategies.md +++ b/docs/source/en/generation_strategies.md @@ -456,6 +456,8 @@ just like in multinomial sampling. However, in assisted decoding, reducing the t ['Alice and Bob, a couple of friends of mine, who are both in the same office as'] ``` +We recommend to install `scikit-learn` library to enhance the candidate generation strategy and achieve additional speedup. + #### Universal Assisted Decoding Universal Assisted Decoding (UAD) adds support for main and assistant models with different tokenizers. diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 7cab88a4bc2..9a62b5709b5 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -19,6 +19,12 @@ import numpy as np import torch +from ..utils import is_sklearn_available + + +if is_sklearn_available(): + from sklearn.metrics import roc_curve + from ..cache_utils import DynamicCache from ..pytorch_utils import isin_mps_friendly from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor @@ -180,6 +186,14 @@ def __init__( # We need to roll back the cache in assisted generation, only DynamicCache is supported self.generation_config.cache_implementation = None + if ( + is_sklearn_available() + and self.assistant_model.generation_config.assistant_confidence_threshold + and type(self) is AssistedCandidateGenerator + ): + self.probs = [] + self.matches = [] + def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]: """ Fetches the candidates to be tried for the current input. @@ -230,6 +244,17 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, # 3. Update variables for the next round of candidate generation self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values + if ( + is_sklearn_available() + and self.assistant_model.generation_config.assistant_confidence_threshold + and type(self) is AssistedCandidateGenerator + ): + scores_tensor = torch.cat(assistant_output.scores, dim=0) + scores_softmax = torch.softmax(scores_tensor, dim=-1) + ids = assistant_output.sequences[-1, -len(assistant_output.scores) :] + p = scores_softmax[range(len(ids)), ids] + self.probs.extend(p.tolist()) + # 4. Prepare variables for output candidate_logits = torch.stack(assistant_output.scores, dim=1) candidate_ids = assistant_output.sequences @@ -261,6 +286,38 @@ def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.F else: self.num_assistant_tokens = max(1.0, self.num_assistant_tokens - 1.0) + # The assistant's confidence threshold is adjusted throughout the speculative iterations to reduce the number of unnecessary draft and target forward passes. The costs are estimated based on the ROC curve, which considers the probability of the draft token and its match with the target. A cost of 25% is assigned to false positives and 75% to false negatives. + # This adaptation is not compatible with UAG, as it relies on the number of matched tokens based on the draft vocabulary, which is unavailable in UAG. + if ( + is_sklearn_available() + and self.assistant_model.generation_config.assistant_confidence_threshold + and type(self) is AssistedCandidateGenerator + ): + # update self.matches + self.matches.extend([1] * num_matches) + if len(self.probs) > len(self.matches): + self.matches.append(0) + + # update self.probs + excess_length = len(self.probs) - len(self.matches) + if excess_length > 0: + del self.probs[-excess_length:] + + if ( + len(self.probs) > 5 and {0, 1}.issubset(self.matches) + ): # require at least 5 samples to calculate the ROC curve and at least one positive and one negative sample + fpr, tpr, thresholds = roc_curve(self.matches, self.probs) + fnr = 1 - tpr + + # Calculate the cost for each threshold + costs = fpr + 3 * fnr + + # Find the threshold that minimizes the cost + optimal_threshold_index = np.argmin(costs) + best_threshold = thresholds[optimal_threshold_index] + + self.assistant_model.generation_config.assistant_confidence_threshold = best_threshold + class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator): """ diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 486cd2336c3..0a6fdd9fb51 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -353,7 +353,9 @@ class GenerationConfig(PushToHubMixin): assistant_confidence_threshold (`float`, *optional*, defaults to 0.4): The confidence threshold for the assistant model. If the assistant model's confidence in its prediction for the current token is lower than this threshold, the assistant model stops the current token generation iteration, even if the number of _speculative tokens_ - (defined by `num_assistant_tokens`) is not yet reached. It is an unsupervised version of the dynamic speculation lookahead + (defined by `num_assistant_tokens`) is not yet reached. The assistant's confidence threshold is adjusted throughout the speculative iterations to reduce the number of unnecessary draft and target forward passes, biased towards avoiding false negatives. + `assistant_confidence_threshold` value is persistent over multiple generation calls with the same assistant model. + It is an unsupervised version of the dynamic speculation lookahead from Dynamic Speculation Lookahead Accelerates Speculative Decoding of Large Language Models . prompt_lookup_num_tokens (`int`, *optional*): The number of tokens to be output as candidate tokens. diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 063e9a3da8f..12faeb8da92 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -92,9 +92,16 @@ WatermarkDetector, WatermarkingConfig, ) - from transformers.generation.candidate_generator import AssistedCandidateGeneratorDifferentTokenizers + from transformers.generation.candidate_generator import ( + AssistedCandidateGenerator, + AssistedCandidateGeneratorDifferentTokenizers, + ) from transformers.generation.utils import _speculative_sampling +from unittest.mock import patch + +from transformers.utils import is_sklearn_available + class GenerationTesterMixin: input_name = "input_ids" @@ -4312,3 +4319,110 @@ def test_no_new_tokens(self): self.assertEqual(discrep_length, 0) np.testing.assert_array_equal(new_tokens_only, np.array([[]])) np.testing.assert_array_equal(discrep_only, np.array([[]])) + + +class TestAssistedCandidateGeneratorUpdateStrategy(unittest.TestCase): + def setUp(self): + checkpoint = "EleutherAI/pythia-160m-deduped" + self.assistant_model = AutoModelForCausalLM.from_pretrained(checkpoint) + self.assistant_model.generation_config.assistant_confidence_threshold = 0.4 + self.model_kwargs = {} + self.input_ids = torch.randint(1, 10, (1, 9)) + self.candidate_generator = AssistedCandidateGenerator( + input_ids=self.input_ids, + assistant_model=self.assistant_model, + generation_config=self.assistant_model.generation_config, + model_kwargs=self.model_kwargs, + ) + self.candidate_generator.probs = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1] + self.original_probs = self.candidate_generator.probs + self.original_threshold = self.assistant_model.generation_config.assistant_confidence_threshold + + def assert_no_sklearn(self): + with patch("transformers.utils.import_utils._sklearn_available", False): + self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches) + self.assertEqual(self.candidate_generator.matches, self.original_matches) + self.assertEqual(self.candidate_generator.probs, self.original_probs) + self.assertEqual( + self.assistant_model.generation_config.assistant_confidence_threshold, self.original_threshold + ) + + @parameterized.expand([(is_sklearn_available(),), (False,)]) + def test_update_candidate_strategy_no_matches_short(self, sklearn_available): + print("test_update_candidate_strategy_no_matches_short") + self.original_matches = [] + self.candidate_generator.matches = self.original_matches + self.num_matches = 0 + + if sklearn_available: + self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches) + self.assertEqual(self.candidate_generator.matches, [0]) + self.assertEqual(self.candidate_generator.probs, [0.9]) + self.assertEqual(self.assistant_model.generation_config.assistant_confidence_threshold, 0.4) + else: + self.assert_no_sklearn() + + @parameterized.expand([(is_sklearn_available(),), (False,)]) + def test_update_candidate_strategy_with_mix_matches_3(self, sklearn_available): + self.original_matches = [1, 0, 1, 0, 1] + self.candidate_generator.matches = self.original_matches + self.num_matches = 3 + if sklearn_available: + self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches) + self.assertEqual(self.candidate_generator.matches, [1, 0, 1, 0, 1, 1, 1, 1, 0]) + self.assertEqual(self.candidate_generator.probs, [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]) + self.assertEqual(self.assistant_model.generation_config.assistant_confidence_threshold, 0.2) + else: + self.assert_no_sklearn() + + @parameterized.expand([(is_sklearn_available(),), (False,)]) + def test_update_candidate_strategy_with_matches_4(self, sklearn_available): + self.original_matches = [1, 1, 1, 1, 1] + self.candidate_generator.matches = self.original_matches + self.num_matches = 4 + if sklearn_available: + self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches) + self.assertEqual(self.candidate_generator.matches, [1, 1, 1, 1, 1, 1, 1, 1, 1]) + self.assertEqual(self.candidate_generator.probs, [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]) + self.assertEqual(self.assistant_model.generation_config.assistant_confidence_threshold, 0.4) + else: + self.assert_no_sklearn() + + @parameterized.expand([(is_sklearn_available(),), (False,)]) + def test_update_candidate_strategy_with_matches_3(self, sklearn_available): + self.original_matches = [1, 1, 1, 1, 1] + self.candidate_generator.matches = self.original_matches + self.num_matches = 3 + if sklearn_available: + self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches) + self.assertEqual(self.candidate_generator.matches, [1, 1, 1, 1, 1, 1, 1, 1, 0]) + self.assertEqual(self.candidate_generator.probs, [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]) + self.assertEqual(self.assistant_model.generation_config.assistant_confidence_threshold, 0.2) + else: + self.assert_no_sklearn() + + @parameterized.expand([(is_sklearn_available(),), (False,)]) + def test_update_candidate_strategy_with_matches_2(self, sklearn_available): + self.original_matches = [1, 1, 1, 1, 1] + self.candidate_generator.matches = self.original_matches + self.num_matches = 2 + if sklearn_available: + self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches) + self.assertEqual(self.candidate_generator.matches, [1, 1, 1, 1, 1, 1, 1, 0]) + self.assertEqual(self.candidate_generator.probs, [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2]) + self.assertEqual(self.assistant_model.generation_config.assistant_confidence_threshold, 0.3) + else: + self.assert_no_sklearn() + + @parameterized.expand([(is_sklearn_available(),), (False,)]) + def test_update_candidate_strategy_with_matches_1(self, sklearn_available): + self.original_matches = [1, 1, 1, 1, 1] + self.candidate_generator.matches = self.original_matches + self.num_matches = 1 + if sklearn_available: + self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches) + self.assertEqual(self.candidate_generator.matches, [1, 1, 1, 1, 1, 1, 0]) + self.assertEqual(self.candidate_generator.probs, [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3]) + self.assertEqual(self.assistant_model.generation_config.assistant_confidence_threshold, 0.4) + else: + self.assert_no_sklearn() From a5bb52847139bf6ad7489ac62a5fb6d0fa3d2ec6 Mon Sep 17 00:00:00 2001 From: Pablo Montalvo <39954772+molbap@users.noreply.github.com> Date: Thu, 5 Dec 2024 18:15:48 +0100 Subject: [PATCH 2/2] Fix signatures for processing kwargs (#35105) * add conversion script * remove pg2 refs * fixup style * small update * get correct scaling * add back missing bos * fix missing config keys * might revert this pos_embeddings * fixup 9b config * fix 9b * fixup 9b conversion for good + add back num_hidden_layers * add correct query scaling for 2b, 9b, 27b * fixup 27b conversion * Additional variant: 27b-896 * Use CPU for conversion to reduce GPU RAM requirements * fix causal mask generation + formatting * fix in-training causal mask generation edge case * trigger CI * update config * update config * update config * update config * update config * update config * update config * update config * update config * move conversion file to main model dir * handle multi-images + bos token * address comments for input ids * revert ci fixes * [run-slow] paligemma * fix * [run-slow] paligemma * skip end 2 end * [run-slow] paligemma --------- Co-authored-by: Pedro Cuenca Co-authored-by: ydshieh --- .../convert_paligemma2_weights_to_hf.py | 415 ++++++++++++++++++ .../models/paligemma/modeling_paligemma.py | 37 +- .../models/paligemma/processing_paligemma.py | 15 +- .../paligemma/test_modeling_paligemma.py | 5 + .../paligemma/test_processor_paligemma.py | 6 +- 5 files changed, 459 insertions(+), 19 deletions(-) create mode 100644 src/transformers/models/paligemma/convert_paligemma2_weights_to_hf.py diff --git a/src/transformers/models/paligemma/convert_paligemma2_weights_to_hf.py b/src/transformers/models/paligemma/convert_paligemma2_weights_to_hf.py new file mode 100644 index 00000000000..df869fcefb2 --- /dev/null +++ b/src/transformers/models/paligemma/convert_paligemma2_weights_to_hf.py @@ -0,0 +1,415 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert PaliGemma2 checkpoints from the original repository.""" + +import argparse +import collections + +import jax.numpy as jnp +import ml_dtypes +import numpy as np +import torch + +from transformers import ( + AutoTokenizer, + Gemma2Config, + PaliGemmaConfig, + PaliGemmaForConditionalGeneration, + PaliGemmaProcessor, + SiglipImageProcessor, +) +from transformers.tokenization_utils_base import AddedToken +from transformers.utils import logging + + +device = "cpu" + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + +# TODO add sequence length variations here + +PALIGEMMA2_VARIANTS = ["2b-224", "2b-448", "2b-896", "9b-224", "9b-448", "9b-896", "27b-224", "27b-448", "27b-896"] +VARIANT_CONFIGS = { + "2b": { + "num_positions": 256, + "hidden_size": 2304, + "num_hidden_layers": 26, + "intermediate_size": 9216, + "num_key_value_heads": 4, + "num_attention_heads": 8, + "head_dim": 256, + "query_pre_attn_scalar": 256, + }, + "9b": { + "num_positions": 1024, + "hidden_size": 3584, + "num_hidden_layers": 42, + "intermediate_size": 14336, + "num_key_value_heads": 8, + "num_attention_heads": 16, + "head_dim": 256, + "query_pre_attn_scalar": 256, + }, + "27b": { + "num_positions": 4096, + "hidden_size": 4608, + "num_hidden_layers": 46, + "intermediate_size": 36864, + "num_key_value_heads": 16, + "num_attention_heads": 32, + "head_dim": 128, + "query_pre_attn_scalar": 4608 // 32, # scaling is different for the 28b + }, +} + +DTYPES = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16} + + +def get_paligemma2_config(variant: str, precision: str): + config = { + "image_token_index": None, + "pad_token_id": 0, + "bos_token_id": 2, + "eos_token_id": 1, + } + base_variant = variant.split("-")[0] + + if variant in PALIGEMMA2_VARIANTS: + image_size = int(variant.split("-")[1]) + variant_config = VARIANT_CONFIGS[base_variant] + patch_size = 14 + num_image_tokens = (image_size**2) // (patch_size**2) + config["projection_dim"] = variant_config["hidden_size"] + config["image_token_index"] = 257152 + config["num_hidden_layers"] = variant_config["num_hidden_layers"] # For generate + text_config = Gemma2Config.from_pretrained("google/gemma-2-2b-it").to_dict() + sup_text_config = { + "model_type": "gemma2", + "vocab_size": 257152, + "num_hidden_layers": variant_config["num_hidden_layers"], + "num_key_value_heads": variant_config["num_key_value_heads"], + "head_dim": variant_config["head_dim"], + "torch_dtype": precision, + "hidden_size": variant_config["hidden_size"], + "hidden_activation": "gelu_pytorch_tanh", + "num_attention_heads": variant_config["num_attention_heads"], + "intermediate_size": variant_config["intermediate_size"], + "is_encoder_decoder": False, + "query_pre_attn_scalar": variant_config["query_pre_attn_scalar"], + } + text_config.update(sup_text_config) + + vision_config = { + "num_positions": variant_config["num_positions"], # not useful, to remove + "torch_dtype": precision, + "image_size": image_size, + "patch_size": patch_size, + "num_image_tokens": num_image_tokens, + "hidden_size": 1152, + "intermediate_size": 4304, + "num_hidden_layers": 27, + "num_attention_heads": 16, + "projection_dim": variant_config["hidden_size"], + "hidden_act": "gelu_pytorch_tanh", + "vision_use_head": False, + } + final_config = PaliGemmaConfig(text_config=text_config, vision_config=vision_config, **config) + else: + raise ValueError(f"Identifier {variant} not supported. Available: {PALIGEMMA2_VARIANTS}") + return final_config + + +def slice_state_dict(state_dict, config): + # fmt: off + # patch embeddings + state_dict["vision_tower.vision_model.embeddings.patch_embedding.weight"] = state_dict.pop("img/embedding/kernel").transpose( + 3, 2, 0, 1 + ) + state_dict["vision_tower.vision_model.embeddings.patch_embedding.bias"] = state_dict.pop("img/embedding/bias") + # positional embeddings + state_dict["vision_tower.vision_model.embeddings.position_embedding.weight"] = state_dict.pop("img/pos_embedding").reshape( + -1, config.vision_config.hidden_size + ) + + # extract vision layers to be sliced at index 0. There are 27 layers in the base model. + encoderblock_layernorm0_scale = state_dict.pop("img/Transformer/encoderblock/LayerNorm_0/scale") + encoderblock_layernorm0_bias = state_dict.pop("img/Transformer/encoderblock/LayerNorm_0/bias") + encoderblock_layernorm1_scale = state_dict.pop("img/Transformer/encoderblock/LayerNorm_1/scale") + encoderblock_layernorm1_bias = state_dict.pop("img/Transformer/encoderblock/LayerNorm_1/bias") + + encoderblock_mlp_dense0_kernel= state_dict.pop("img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel") + encoderblock_mlp_dense0_bias= state_dict.pop("img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias") + encoderblock_mlp_dense1_kernel= state_dict.pop("img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel") + encoderblock_mlp_dense1_bias= state_dict.pop("img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias") + + encoderblock_attention_0_key_kernel = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel") + encoderblock_attention_0_key_bias = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias") + encoderblock_attention_0_value_kernel = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel") + encoderblock_attention_0_value_bias = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias") + encoderblock_attention_0_query_kernel = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel") + encoderblock_attention_0_query_bias = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias") + encoderblock_attention_0_out_kernel = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel") + encoderblock_attention_0_out_bias = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias") + + for i in range(config.vision_config.num_hidden_layers): + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"] = encoderblock_layernorm0_scale[i].transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"] = encoderblock_layernorm0_bias[i] + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"] = encoderblock_layernorm1_scale[i].transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"] = encoderblock_layernorm1_bias[i] + + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"] = encoderblock_mlp_dense0_kernel[i].transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"] = encoderblock_mlp_dense0_bias[i] + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"] = encoderblock_mlp_dense1_kernel[i].transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"] = encoderblock_mlp_dense1_bias[i] + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) + + state_dict["vision_tower.vision_model.post_layernorm.weight"] = state_dict.pop("img/Transformer/encoder_norm/scale").transpose() + state_dict["vision_tower.vision_model.post_layernorm.bias"] = state_dict.pop("img/Transformer/encoder_norm/bias") + + # multimodal projector + + state_dict['multi_modal_projector.linear.weight'] = state_dict.pop("img/head/kernel").transpose() + state_dict['multi_modal_projector.linear.bias'] = state_dict.pop("img/head/bias") + + # text decoder (gemma) + + embedding_vector = state_dict.pop("llm/embedder/input_embedding") + state_dict["language_model.model.embed_tokens.weight"] = embedding_vector + + # pop the einsum attention + mlp representations. There are 26 layers in gemma2-2b. + + llm_attention_attn_vec_einsum = state_dict.pop("llm/layers/attn/attn_vec_einsum/w") + # (26, 2, 4, 2304, 256) for 2b-224, 4 kv heads and 26 layers + llm_attention_kv_einsum = state_dict.pop("llm/layers/attn/kv_einsum/w") + llm_attention_q_einsum = state_dict.pop("llm/layers/attn/q_einsum/w") + llm_mlp_gating_einsum = state_dict.pop("llm/layers/mlp/gating_einsum") + llm_mlp_linear = state_dict.pop("llm/layers/mlp/linear") + # TODO verify correctness of layer norm loading + llm_input_layernorm = state_dict.pop("llm/layers/pre_attention_norm/scale") + llm_pre_feedforward_layernorm = state_dict.pop("llm/layers/pre_ffw_norm/scale") + + llm_post_attention_layernorm = state_dict.pop("llm/layers/post_attention_norm/scale") + llm_post_feedforward_layernorm = state_dict.pop("llm/layers/post_ffw_norm/scale") + + for i in range(config.text_config.num_hidden_layers): + # llm_attention_q_einsum[i].shape = (8, 2048, 256) + # q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size) + + """ + q shape (8, 2304, 256) + k shape (4, 2304, 256) + v shape (4, 2304, 256) + o shape (8, 256, 2304) + + """ + q_transpose = (0, 2, 1) + k_transpose = (0, 2, 1) + v_transpose = (0, 2, 1) + o_transpose = (2, 0, 1) + + q_weight_matrices = llm_attention_q_einsum[i].transpose(*q_transpose) + q_proj_weight_reshaped = q_weight_matrices + q_proj_weight_reshaped = q_proj_weight_reshaped.reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size) + state_dict[f"language_model.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped + # Shape: (4, 2304, 256) + k_weight_matrices = llm_attention_kv_einsum[i, 0].transpose(*k_transpose) + k_proj_weight_reshaped = k_weight_matrices.reshape( + config.text_config.num_key_value_heads * config.text_config.head_dim, + config.text_config.hidden_size + ) + state_dict[f"language_model.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped + # llm_attention_kv_einsum[i, 1].shape = (num_key_value_heads, hidden_size, head_dim) + v_weight_matrices = llm_attention_kv_einsum[i, 1].transpose(*v_transpose) # Shape: (4, 2304, 256) + v_proj_weight_reshaped = v_weight_matrices.reshape( + config.text_config.num_key_value_heads * config.text_config.head_dim, + config.text_config.hidden_size + ) + state_dict[f"language_model.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped + + # output projection. + + # llm_attention_attn_vec_einsum[i].shape = (8, 256, 2304) + o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].transpose(*o_transpose).reshape(config.text_config.hidden_size, config.text_config.num_attention_heads * config.text_config.head_dim) + state_dict[f"language_model.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped + # mlp layers + gate_proj_weight = llm_mlp_gating_einsum[i, 0] + state_dict[f"language_model.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose() + up_proj_weight = llm_mlp_gating_einsum[i, 1] + state_dict[f"language_model.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose() + state_dict[f"language_model.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose() + state_dict[f"language_model.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i] + state_dict[f"language_model.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i] + state_dict[f"language_model.model.layers.{i}.pre_feedforward_layernorm.weight"] = llm_pre_feedforward_layernorm[i] + state_dict[f"language_model.model.layers.{i}.post_feedforward_layernorm.weight"] = llm_post_feedforward_layernorm[i] + state_dict["language_model.model.norm.weight"] = state_dict.pop("llm/final_norm/scale") + state_dict["language_model.lm_head.weight"] = embedding_vector # weights are tied. + [k for k in state_dict.keys() if not k.startswith('vision') and not k.startswith('language')] + # fmt: on + for key, value in state_dict.items(): + if not isinstance(value, torch.Tensor): + try: + if value.dtype == jnp.bfloat16: + value = jnp.array(value).astype(jnp.float32) + value = np.array(value) + state_dict[key] = torch.from_numpy(value).to(torch.bfloat16) + else: + state_dict[key] = torch.from_numpy(value) + except Exception as initial_exception: + raise ValueError(f"Conversion failed from jax weights with {initial_exception}. Check your inputs.") + return state_dict + + +def flatten_nested_dict(params, parent_key="", sep="/", precision: int = "float32"): + items = [] + + for k, v in params.items(): + k = k.removeprefix("params/") + new_key = parent_key + sep + k if parent_key else k + + if isinstance(v, collections.abc.MutableMapping): + items.extend(flatten_nested_dict(v, parent_key=new_key, sep=sep, precision=precision).items()) + else: + if precision == "bfloat16": + try: + v = v.view(ml_dtypes.bfloat16) + except Exception as initial_exception: + raise ValueError(f"Conversion failed from bfloat16 with {initial_exception}, check your inputs.") + items.append((new_key, v)) + return dict(items) + + +@torch.no_grad() +def convert_paligemma2_checkpoint( + checkpoint_path, + pytorch_dump_folder_path, + variant: str, + precision: str, + do_convert_weights=False, +): + """ + Read checkpoints from flax npz files, rename/reshape, send result to state dict and verify logits if needed. + """ + config = get_paligemma2_config(variant, precision=precision) + if do_convert_weights: + tokenizer_id = "google/paligemma-3b-pt-224" # same tokenizer as paligemma 1 + tokenizer = AutoTokenizer.from_pretrained(tokenizer_id) + image_token = AddedToken("", normalized=False, special=True) + tokens_to_add = {"additional_special_tokens": [image_token]} + tokenizer.add_special_tokens(tokens_to_add) + + # tokenizer.padding_side = 'right' # uncomment for testing purposes only. + + image_processor = SiglipImageProcessor.from_pretrained("google/paligemma-3b-pt-224") + image_processor.size = {"width": config.vision_config.image_size, "height": config.vision_config.image_size} + image_processor.image_seq_length = config.vision_config.num_image_tokens + + processor = PaliGemmaProcessor(image_processor=image_processor, tokenizer=tokenizer) + data = jnp.load(checkpoint_path) + state_dict = flatten_nested_dict(data, precision=precision) + del data + state_dict_transformers = slice_state_dict(state_dict, config) + del state_dict + del config.hidden_size # this key is unused + model = PaliGemmaForConditionalGeneration(config).to(device).eval() + model.load_state_dict(state_dict_transformers) + del state_dict_transformers + model.config.text_config._attn_implementation = "sdpa" + + # model expansion to get random embeds of image tokens + pad_shape = 64 # for performance reasons + pre_expansion_embeddings = model.language_model.model.embed_tokens.weight.data + mu = torch.mean(pre_expansion_embeddings, dim=0).float() + n = pre_expansion_embeddings.size()[0] + sigma = ((pre_expansion_embeddings - mu).T @ (pre_expansion_embeddings - mu)) / n + dist = torch.distributions.multivariate_normal.MultivariateNormal(mu, covariance_matrix=1e-5 * sigma) + + # We add an image token so we resize the model + model.resize_token_embeddings(config.text_config.vocab_size + 2, pad_shape) + model.language_model.model.embed_tokens.weight.data[257152:] = torch.stack( + tuple( + (dist.sample() for _ in range(model.language_model.model.embed_tokens.weight.data[257152:].shape[0])) + ), + dim=0, + ) + model.language_model.lm_head.weight.data[257152:] = torch.stack( + tuple((dist.sample() for _ in range(model.language_model.lm_head.weight.data[257152:].shape[0]))), + dim=0, + ) + # convert to needed precision + + model.to(DTYPES[precision]) + model.save_pretrained(pytorch_dump_folder_path, safe_serialization=True) + processor.save_pretrained(pytorch_dump_folder_path) + + else: + processor = PaliGemmaProcessor.from_pretrained(pytorch_dump_folder_path, do_rescale=False) + model = ( + PaliGemmaForConditionalGeneration.from_pretrained(pytorch_dump_folder_path, attn_implementation="sdpa") + .to(device) + .eval() + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--checkpoint_path", + required=True, + type=str, + help="Path to the .npz checkpoint", + ) + + parser.add_argument( + "--pytorch_dump_folder_path", + required=True, + type=str, + help="Path to the output directory where model and processor will be saved.", + ) + + parser.add_argument( + "--precision", + choices=["float32", "bfloat16", "float16"], + type=str, + help="Precision identifier for model conversion - should match the base checkpoint precision.", + ) + + parser.add_argument( + "--variant", + default="2b-224", + choices=PALIGEMMA2_VARIANTS, + type=str, + help="String identifier of the paligemma2 variant to convert.", + ) + + parser.add_argument( + "--do_convert_weights", action="store_true", help="Whether or not to reload and convert the weights." + ) + + args = parser.parse_args() + convert_paligemma2_checkpoint( + checkpoint_path=args.checkpoint_path, + pytorch_dump_folder_path=args.pytorch_dump_folder_path, + variant=args.variant, + precision=args.precision, + do_convert_weights=args.do_convert_weights, + ) diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index 63cfbb6e6a5..b4a231561ba 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -21,7 +21,7 @@ import torch.utils.checkpoint from torch import nn -from ...cache_utils import Cache, StaticCache +from ...cache_utils import Cache, HybridCache, StaticCache from ...generation import GenerationMixin from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -341,7 +341,14 @@ def tie_weights(self): return self.language_model.tie_weights() def _update_causal_mask( - self, attention_mask, token_type_ids, inputs_embeds, past_key_values, cache_position, is_training: bool = False + self, + attention_mask, + token_type_ids, + past_key_values, + cache_position, + input_ids=None, + inputs_embeds=None, + is_training: bool = False, ): if self.config.text_config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: @@ -349,11 +356,13 @@ def _update_causal_mask( return None using_static_cache = isinstance(past_key_values, StaticCache) - dtype = inputs_embeds.dtype - min_dtype = torch.finfo(dtype).min - sequence_length = inputs_embeds.shape[1] + min_dtype = torch.finfo(self.dtype).min + inputs_lead_dim = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0] + sequence_length = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() + elif isinstance(past_key_values, HybridCache): + target_length = past_key_values.get_max_cache_shape() else: target_length = ( attention_mask.shape[-1] @@ -366,7 +375,7 @@ def _update_causal_mask( return attention_mask causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + (sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device ) # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below if sequence_length != 1: @@ -376,7 +385,7 @@ def _update_causal_mask( causal_mask[:, :sequence_length] = 0.0 causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(inputs_embeds.shape[0], 1, -1, -1) + causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit mask_length = attention_mask.shape[-1] @@ -405,7 +414,7 @@ def get_image_features(self, pixel_values: torch.FloatTensor): image_outputs = self.vision_tower(pixel_values) selected_image_feature = image_outputs.last_hidden_state image_features = self.multi_modal_projector(selected_image_feature) - image_features = image_features / (self.config.hidden_size**0.5) + image_features = image_features / (self.config.text_config.hidden_size**0.5) return image_features @add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING) @@ -516,9 +525,8 @@ def forward( labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels) causal_mask = self._update_causal_mask( - attention_mask, token_type_ids, inputs_embeds, past_key_values, cache_position, is_training + attention_mask, token_type_ids, past_key_values, cache_position, input_ids, inputs_embeds, is_training ) - outputs = self.language_model( attention_mask=causal_mask, position_ids=position_ids, @@ -579,6 +587,7 @@ def prepare_inputs_for_generation( token_type_ids=None, use_cache=True, num_logits_to_keep=None, + labels=None, **kwargs, ): # Overwritten -- custom `position_ids` and `pixel_values` handling @@ -598,10 +607,14 @@ def prepare_inputs_for_generation( # position_ids in Paligemma are 1-indexed if model_inputs.get("position_ids") is not None: model_inputs["position_ids"] += 1 - # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore # Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always if cache_position[0] == 0: model_inputs["pixel_values"] = pixel_values - + is_training = token_type_ids is not None and labels is not None + if cache_position[0] == 0 and isinstance(past_key_values, HybridCache): + causal_mask = self._update_causal_mask( + attention_mask, token_type_ids, past_key_values, cache_position, input_ids, inputs_embeds, is_training + ) + model_inputs["attention_mask"] = causal_mask return model_inputs diff --git a/src/transformers/models/paligemma/processing_paligemma.py b/src/transformers/models/paligemma/processing_paligemma.py index dab0d60ad56..cb35aab66cb 100644 --- a/src/transformers/models/paligemma/processing_paligemma.py +++ b/src/transformers/models/paligemma/processing_paligemma.py @@ -269,7 +269,7 @@ def __call__( logger.warning( "You are passing both `text` and `images` to `PaliGemmaProcessor`. The processor expects special " "image tokens in the text, as many tokens as there are images per each text. It is recommended to " - "add `` tokens in the very beginning of your text and `` token after that. For this call, we will infer how many images " + "add `` tokens in the very beginning of your text. For this call, we will infer how many images " "each text has and add special tokens." ) @@ -304,9 +304,16 @@ def __call__( ] images = make_batched_images(images) else: - text = [sample.replace(IMAGE_TOKEN, IMAGE_TOKEN * self.image_seq_length) for sample in text] - input_strings = [f"{sample}\n" for sample in text] - + expanded_samples = [] + for sample in text: + expanded_sample = sample.replace(IMAGE_TOKEN, IMAGE_TOKEN * self.image_seq_length) + bos_rfind_index = expanded_sample.rfind(IMAGE_TOKEN) + bos_index = bos_rfind_index + len(IMAGE_TOKEN) if bos_rfind_index != -1 else 0 + expanded_sample = ( + expanded_sample[:bos_index] + self.tokenizer.bos_token + expanded_sample[bos_index:] + ) + expanded_samples.append(expanded_sample) + input_strings = [f"{sample}\n" for sample in expanded_samples] pixel_values = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"] # max_length has to account for the image tokens diff --git a/tests/models/paligemma/test_modeling_paligemma.py b/tests/models/paligemma/test_modeling_paligemma.py index ce44436a20a..5ffea7ffe55 100644 --- a/tests/models/paligemma/test_modeling_paligemma.py +++ b/tests/models/paligemma/test_modeling_paligemma.py @@ -347,6 +347,11 @@ def test_flash_attn_2_fp32_ln(self): def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): pass + # TODO (joao, raushan): fix me -- the problem is in `cache_position[0] == 0`, i.e. dynamic control flow + @unittest.skip("PaliGemma is not compatible with end-to-end generation compilation") + def test_generate_compile_fullgraph(self): + pass + @slow @require_torch diff --git a/tests/models/paligemma/test_processor_paligemma.py b/tests/models/paligemma/test_processor_paligemma.py index 245aff59412..e301bf304b1 100644 --- a/tests/models/paligemma/test_processor_paligemma.py +++ b/tests/models/paligemma/test_processor_paligemma.py @@ -63,8 +63,8 @@ def test_text_with_image_tokens(self): tokenizer = self.get_component("tokenizer") processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) - text_multi_images = "Dummy text!" - text_single_image = "Dummy text!" + text_multi_images = "Dummy text!" + text_single_image = "Dummy text!" text_no_image = "Dummy text!" image = self.prepare_image_inputs() @@ -85,7 +85,7 @@ def test_text_with_image_tokens(self): self.assertTrue(out_noimage[k].tolist() == out_multiimages[k].tolist()) text_batched = ["Dummy text!", "Dummy text!"] - text_batched_with_image = ["Dummy text!", "Dummy text!"] + text_batched_with_image = ["Dummy text!", "Dummy text!"] out_images = processor(text=text_batched_with_image, images=[image, image], return_tensors="np") out_noimage_nested = processor(text=text_batched, images=[[image], [image]], return_tensors="np") out_noimage = processor(text=text_batched, images=[image, image], return_tensors="np")