From c3654514a908eed6a119acdface07fcdcb540fd6 Mon Sep 17 00:00:00 2001 From: xujunda2024 Date: Sun, 26 Jan 2025 23:58:25 +0800 Subject: [PATCH] add SigLIP Model --- mindnlp/transformers/models/__init__.py | 2 + .../models/auto/tokenization_auto.py | 2 +- mindnlp/transformers/models/mpnet/__init__.py | 4 +- mindnlp/transformers/models/mt5/__init__.py | 4 +- .../transformers/models/siglip/__init__.py | 31 + .../models/siglip/configuration_siglip.py | 358 +++++ .../models/siglip/image_processing_siglip.py | 244 ++++ .../models/siglip/modeling_siglip.py | 1269 +++++++++++++++++ .../models/siglip/processing_siglip.py | 145 ++ .../models/siglip/tokenization_siglip.py | 393 +++++ mindnlp/utils/generic.py | 83 +- tests/transformers/models/siglip/__init__.py | 0 .../siglip/test_image_processing_siglip.py | 127 ++ .../models/siglip/test_modeling_siglip.py | 989 +++++++++++++ .../models/siglip/test_tokenization_siglip.py | 455 ++++++ 15 files changed, 4102 insertions(+), 4 deletions(-) create mode 100644 mindnlp/transformers/models/siglip/__init__.py create mode 100644 mindnlp/transformers/models/siglip/configuration_siglip.py create mode 100644 mindnlp/transformers/models/siglip/image_processing_siglip.py create mode 100644 mindnlp/transformers/models/siglip/modeling_siglip.py create mode 100644 mindnlp/transformers/models/siglip/processing_siglip.py create mode 100644 mindnlp/transformers/models/siglip/tokenization_siglip.py create mode 100644 tests/transformers/models/siglip/__init__.py create mode 100644 tests/transformers/models/siglip/test_image_processing_siglip.py create mode 100644 tests/transformers/models/siglip/test_modeling_siglip.py create mode 100644 tests/transformers/models/siglip/test_tokenization_siglip.py diff --git a/mindnlp/transformers/models/__init__.py b/mindnlp/transformers/models/__init__.py index 722aa0f7d..f86a85dc6 100644 --- a/mindnlp/transformers/models/__init__.py +++ b/mindnlp/transformers/models/__init__.py @@ -449,6 +449,7 @@ from .seggpt import * from .sew import * from .sew_d import * +from .siglip import * from .speech_encoder_decoder import * from .speech_to_text import * from .speech_to_text_2 import * @@ -694,6 +695,7 @@ __all__.extend(seggpt.__all__) __all__.extend(sew.__all__) __all__.extend(sew_d.__all__) +__all__.extend(siglip.__all__) __all__.extend(speech_encoder_decoder.__all__) __all__.extend(speech_to_text.__all__) __all__.extend(speech_to_text_2.__all__) diff --git a/mindnlp/transformers/models/auto/tokenization_auto.py b/mindnlp/transformers/models/auto/tokenization_auto.py index 1ad0adbe4..3b849dc40 100644 --- a/mindnlp/transformers/models/auto/tokenization_auto.py +++ b/mindnlp/transformers/models/auto/tokenization_auto.py @@ -287,7 +287,7 @@ ("mega", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)), ("megatron-bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ("mgp-str", ("MgpstrTokenizer", None)), - ("minicpm3", ("MiniCPMTokenizer", None)), + ("minicpm3", ("MiniCPM3Tokenizer", None)), ( "mistral", ( diff --git a/mindnlp/transformers/models/mpnet/__init__.py b/mindnlp/transformers/models/mpnet/__init__.py index 7e90bc331..6d85f2cbd 100644 --- a/mindnlp/transformers/models/mpnet/__init__.py +++ b/mindnlp/transformers/models/mpnet/__init__.py @@ -15,12 +15,14 @@ """ MPNet Model. """ -from . import configuration_mpnet, modeling_mpnet, tokenization_mpnet +from . 
import configuration_mpnet, modeling_mpnet, tokenization_mpnet, tokenization_mpnet_fast from .configuration_mpnet import * from .modeling_mpnet import * from .tokenization_mpnet import * +from .tokenization_mpnet_fast import * __all__ = [] __all__.extend(configuration_mpnet.__all__) __all__.extend(modeling_mpnet.__all__) __all__.extend(tokenization_mpnet.__all__) +__all__.extend(tokenization_mpnet_fast.__all__) diff --git a/mindnlp/transformers/models/mt5/__init__.py b/mindnlp/transformers/models/mt5/__init__.py index 26483cb2c..4b8807d77 100644 --- a/mindnlp/transformers/models/mt5/__init__.py +++ b/mindnlp/transformers/models/mt5/__init__.py @@ -15,10 +15,12 @@ """ T5 Model init """ -from . import configuration_mt5, modeling_mt5 +from . import configuration_mt5, modeling_mt5, tokenization_mt5 from .configuration_mt5 import * from .modeling_mt5 import * +from .tokenization_mt5 import * __all__ = [] __all__.extend(modeling_mt5.__all__) __all__.extend(configuration_mt5.__all__) +__all__.extend(tokenization_mt5.__all__) diff --git a/mindnlp/transformers/models/siglip/__init__.py b/mindnlp/transformers/models/siglip/__init__.py new file mode 100644 index 000000000..b5264e188 --- /dev/null +++ b/mindnlp/transformers/models/siglip/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +SigLip Model. +""" + +from . import configuration_siglip, image_processing_siglip, modeling_siglip, processing_siglip, tokenization_siglip + +from .configuration_siglip import * +from .image_processing_siglip import * +from .modeling_siglip import * +from .processing_siglip import * +from .tokenization_siglip import * + +__all__ = [] +__all__.extend(modeling_siglip.__all__) +__all__.extend(configuration_siglip.__all__) +__all__.extend(image_processing_siglip.__all__) +__all__.extend(processing_siglip.__all__) +__all__.extend(tokenization_siglip.__all__) \ No newline at end of file diff --git a/mindnlp/transformers/models/siglip/configuration_siglip.py b/mindnlp/transformers/models/siglip/configuration_siglip.py new file mode 100644 index 000000000..c67163e5f --- /dev/null +++ b/mindnlp/transformers/models/siglip/configuration_siglip.py @@ -0,0 +1,358 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
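The configuration module that follows defines `SiglipTextConfig`, `SiglipVisionConfig` and the composite `SiglipConfig`; the sub-config `from_pretrained` overrides pull the matching `text_config`/`vision_config` entry out of a full SigLIP checkpoint. A minimal usage sketch of how these classes compose, assuming they are re-exported from `mindnlp.transformers` via the `__init__.py` updates above (the checkpoint name is only illustrative):

```python
from mindnlp.transformers import SiglipConfig, SiglipTextConfig, SiglipVisionConfig

# Compose a full config from the two sub-configs (values shown are the defaults).
text_config = SiglipTextConfig(vocab_size=32000, hidden_size=768, max_position_embeddings=64)
vision_config = SiglipVisionConfig(image_size=224, patch_size=16)
config = SiglipConfig.from_text_vision_configs(text_config, vision_config)
print(config.text_config.hidden_size, config.vision_config.patch_size)

# Loading a sub-config from a composite SigLIP checkpoint extracts the matching
# "text_config" entry from its config.json before instantiation.
text_only = SiglipTextConfig.from_pretrained("google/siglip-base-patch16-224")
```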
+"""Siglip model configuration""" + +import os +from typing import Union + +from mindnlp.utils import logging + +from ...configuration_utils import PretrainedConfig + + +logger = logging.get_logger(__name__) + +SIGLIP_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "google/siglip-base-patch16-224": "https://hf-mirror.com/google/siglip-base-patch16-224/resolve/main/config.json", + # See all Siglip models at https://hf-mirror.com/models?filter=siglip +} + +class SiglipTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SiglipTextModel`]. It is used to instantiate a + Siglip text encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the text encoder of the Siglip + [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Siglip text model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`SiglipModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + max_position_embeddings (`int`, *optional*, defaults to 64): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + pad_token_id (`int`, *optional*, defaults to 1): + The id of the padding token in the vocabulary. + bos_token_id (`int`, *optional*, defaults to 49406): + The id of the beginning-of-sequence token in the vocabulary. + eos_token_id (`int`, *optional*, defaults to 49407): + The id of the end-of-sequence token in the vocabulary. 
+ + Example: + + ```python + >>> from transformers import SiglipTextConfig, SiglipTextModel + + >>> # Initializing a SiglipTextConfig with google/siglip-base-patch16-224 style configuration + >>> configuration = SiglipTextConfig() + + >>> # Initializing a SiglipTextModel (with random weights) from the google/siglip-base-patch16-224 style configuration + >>> model = SiglipTextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "siglip_text_model" + base_config_key = "text_config" + + def __init__( + self, + vocab_size=32000, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + max_position_embeddings=64, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + # This differs from `CLIPTokenizer`'s default and from openai/siglip + # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538 + pad_token_id=1, + bos_token_id=49406, + eos_token_id=49407, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.max_position_embeddings = max_position_embeddings + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.attention_dropout = attention_dropout + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + """ + Creates a SiglipTextConfig instance from a pretrained model. + + Args: + cls (type): The class object. + pretrained_model_name_or_path (Union[str, os.PathLike]): The name or path of the pretrained model. + + Returns: + PretrainedConfig: A SiglipTextConfig instance initialized with the configuration specified by the pretrained model. + + Raises: + TypeError: If the input parameters are not of the expected types. + ValueError: If the configuration dictionary does not contain the required information. + Warning: If the model type being used for instantiation does not match the class's model type, which may lead to errors. + """ + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the text config dict if we are loading from SiglipConfig + if config_dict.get("model_type") == "siglip": + config_dict = config_dict["text_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + +class SiglipVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a + Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip + [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. 
Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of channels in the input images. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + + Example: + + ```python + >>> from transformers import SiglipVisionConfig, SiglipVisionModel + + >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration + >>> configuration = SiglipVisionConfig() + + >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration + >>> model = SiglipVisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "siglip_vision_model" + base_config_key = "vision_config" + + def __init__( + self, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + image_size=224, + patch_size=16, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + """ + Load a pretrained configuration from a given model name or path. + + Args: + cls (class): The class object. + pretrained_model_name_or_path (Union[str, os.PathLike]): The name or path of the pretrained model. + It can be either a string representing the name of the model or a path-like object pointing to the model location. + + Returns: + PretrainedConfig: The loaded pretrained configuration. + + Raises: + None. + + This method is a class method that allows loading a pretrained configuration. It takes in the class object 'cls' + and the name or path of the pretrained model 'pretrained_model_name_or_path' as parameters. The method returns an instance + of type 'PretrainedConfig', which represents the loaded pretrained configuration. 
+ + The 'pretrained_model_name_or_path' parameter can be either a string representing the name of the pretrained model + or a path-like object pointing to the location of the model. It is used to identify and locate the pretrained model + that needs to be loaded. + + Note: If the loaded configuration belongs to the 'siglip' model type, the 'config_dict' will be updated to use the + 'vision_config' sub-dictionary. Additionally, if the 'model_type' attribute is present in the 'cls' class and + the loaded configuration's 'model_type' is different from 'cls.model_type', a warning will be logged indicating + that instantiating a model of different types may lead to errors. + + Example: + ```python + >>> config = SiglipVisionConfig.from_pretrained("siglip_model") + ... + ``` + In the above example, the 'from_pretrained' method is called on the 'SiglipVisionConfig' class to load the pretrained + configuration of the 'siglip_model'. The resulting configuration is stored in the 'config' variable. + """ + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the vision config dict if we are loading from SiglipConfig + if config_dict.get("model_type") == "siglip": + config_dict = config_dict["vision_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + +class SiglipConfig(PretrainedConfig): + r""" + [`SiglipConfig`] is the configuration class to store the configuration of a [`SiglipModel`]. It is used to + instantiate a Siglip model according to the specified arguments, defining the text model and vision model configs. + Instantiating a configuration with the defaults will yield a similar configuration to that of the Siglip + [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + text_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`SiglipTextConfig`]. + vision_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`SiglipVisionConfig`]. + kwargs (*optional*): + Dictionary of keyword arguments. 
+ + Example: + + ```python + >>> from transformers import SiglipConfig, SiglipModel + + >>> # Initializing a SiglipConfig with google/siglip-base-patch16-224 style configuration + >>> configuration = SiglipConfig() + + >>> # Initializing a SiglipModel (with random weights) from the google/siglip-base-patch16-224 style configuration + >>> model = SiglipModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + + >>> # We can also initialize a SiglipConfig from a SiglipTextConfig and a SiglipVisionConfig + >>> from transformers import SiglipTextConfig, SiglipVisionConfig + + >>> # Initializing a SiglipText and SiglipVision configuration + >>> config_text = SiglipTextConfig() + >>> config_vision = SiglipVisionConfig() + + >>> config = SiglipConfig.from_text_vision_configs(config_text, config_vision) + ```""" + + model_type = "siglip" + sub_configs = {"text_config": SiglipTextConfig, "vision_config": SiglipVisionConfig} + + def __init__(self, text_config=None, vision_config=None, **kwargs): + super().__init__(**kwargs) + + if text_config is None: + text_config = {} + logger.info("`text_config` is `None`. Initializing the `SiglipTextConfig` with default values.") + + if vision_config is None: + vision_config = {} + logger.info("`vision_config` is `None`. initializing the `SiglipVisionConfig` with default values.") + + self.text_config = SiglipTextConfig(**text_config) + self.vision_config = SiglipVisionConfig(**vision_config) + + self.initializer_factor = 1.0 + + @classmethod + def from_text_vision_configs(cls, text_config: SiglipTextConfig, vision_config: SiglipVisionConfig, **kwargs): + r""" + Instantiate a [`SiglipConfig`] (or a derived class) from siglip text model configuration and siglip vision + model configuration. + + Returns: + [`SiglipConfig`]: An instance of a configuration object + """ + + return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs) + + +__all__ = [ + "SIGLIP_PRETRAINED_CONFIG_ARCHIVE_MAP", + "SiglipConfig", + "SiglipTextConfig", + "SiglipVisionConfig" +] diff --git a/mindnlp/transformers/models/siglip/image_processing_siglip.py b/mindnlp/transformers/models/siglip/image_processing_siglip.py new file mode 100644 index 000000000..92308c37b --- /dev/null +++ b/mindnlp/transformers/models/siglip/image_processing_siglip.py @@ -0,0 +1,244 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
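The image processor defined below resizes inputs to 224x224 with bicubic resampling, rescales by 1/255 and normalizes with a mean and std of 0.5, producing a channels-first `pixel_values` batch. A rough usage sketch under those defaults, assuming the class is re-exported from `mindnlp.transformers` (a random NumPy array stands in for a real image, and `return_tensors` is left unset so plain NumPy arrays come back):

```python
import numpy as np
from mindnlp.transformers import SiglipImageProcessor

processor = SiglipImageProcessor()  # defaults: 224x224, rescale 1/255, mean/std 0.5

# Dummy HWC uint8 image in place of a real photo.
image = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)

batch = processor.preprocess(images=image)
pixel_values = batch["pixel_values"][0]  # np.ndarray, channels-first
print(pixel_values.shape)                # (3, 224, 224)
```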
+"""Image processor class for SigLIP.""" + +from typing import Dict, List, Optional, Union + +from mindnlp.utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging +from mindnlp.configs import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import ( + convert_to_rgb, + resize, + to_channel_dimension_format, +) +from ...image_utils import ( + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) + + + +logger = logging.get_logger(__name__) + + +if is_vision_available(): + import PIL + + +class SiglipImageProcessor(BaseImageProcessor): + r""" + Constructs a SigLIP image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by + `do_resize` in the `preprocess` method. + size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`): + Size of the image after resizing. Can be overridden by `size` in the `preprocess` method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in + the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess` + method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image by the specified mean and standard deviation. Can be overridden by + `do_normalize` in the `preprocess` method. + image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. 
+ """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Dict[str, int] = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"height": 224, "width": 224} + image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN + image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD + + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + self.do_convert_rgb = do_convert_rgb + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + do_resize: bool = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_rescale: bool = None, + rescale_factor: float = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + do_convert_rgb: bool = None, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. 
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + size = get_size_dict(size, param_name="size", default_to_square=False) + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + + images = make_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. 
+ input_data_format = infer_channel_dimension_format(images[0]) + + if do_resize: + height, width = size["height"], size["width"] + images = [ + resize(image=image, size=(height, width), resample=resample, input_data_format=input_data_format) + for image in images + ] + + if do_rescale: + images = [ + self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + for image in images + ] + + if do_normalize: + images = [ + self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + for image in images + ] + + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images + ] + + data = {"pixel_values": images} + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["SiglipImageProcessor"] diff --git a/mindnlp/transformers/models/siglip/modeling_siglip.py b/mindnlp/transformers/models/siglip/modeling_siglip.py new file mode 100644 index 000000000..d0033dd42 --- /dev/null +++ b/mindnlp/transformers/models/siglip/modeling_siglip.py @@ -0,0 +1,1269 @@ +# coding=utf-8 +# Copyright 2024 Google AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MindNLP Siglip model.""" + +import math +import warnings +from dataclasses import dataclass +from typing import Any, Optional, Tuple, Union + +import mindspore.ops +import numpy as np +import mindspore +from mindnlp.core import nn, ops, Parameter +from mindnlp.core.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from mindnlp.core.nn.init import initializer, _calculate_fan_in_and_fan_out + +from ....common.activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput +from ...modeling_utils import PreTrainedModel +from mindnlp.utils import ( + ModelOutput, + logging +) +from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig + + +logger = logging.get_logger(__name__) + +def _trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from MindSpore official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. 
+ nn.init.uniform_(tensor, 2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor = ops.erfinv(tensor) + + # Transform to proper mean, std + tensor = tensor * std * math.sqrt(2.0) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor = ops.clamp(tensor, min=a, max=b) + + +def trunc_normal_tf_( + tensor: mindspore.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0 +) -> mindspore.Tensor: + """Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \\leq \text{mean} \\leq b`. + + NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the + bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 + and the result is subsequently scaled and shifted by the mean and std args. + + Args: + tensor: an n-dimensional `mindspore.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + """ + _trunc_normal_(tensor, 0, 1.0, a, b) + tensor = tensor * std + tensor.add_(mean) + + +def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + if mode == "fan_in": + denom = fan_in + elif mode == "fan_out": + denom = fan_out + elif mode == "fan_avg": + denom = (fan_in + fan_out) / 2 + + variance = scale / denom + + if distribution == "truncated_normal": + # constant is stddev of standard normal truncated to (-2, 2) + trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) + elif distribution == "normal": + nn.init.normal_(tensor, std=math.sqrt(variance)) + elif distribution == "uniform": + bound = math.sqrt(3 * variance) + nn.init.uniform_(tensor, -bound, bound) + else: + raise ValueError(f"invalid distribution {distribution}") + + +def lecun_normal_(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") + + +def default_flax_embed_init(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="normal") + + +@dataclass +# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip +class SiglipVisionModelOutput(ModelOutput): + """ + Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. + + Args: + image_embeds (`mindspore.Tensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`mindspore.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(mindspore.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `mindspore.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. 
+ + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(mindspore.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `mindspore.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + image_embeds: Optional[mindspore.Tensor] = None + last_hidden_state: mindspore.Tensor = None + hidden_states: Optional[Tuple[mindspore.Tensor, ...]] = None + attentions: Optional[Tuple[mindspore.Tensor, ...]] = None + + +@dataclass +# Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip +class SiglipTextModelOutput(ModelOutput): + """ + Base class for text model's outputs that also contains a pooling of the last hidden states. + + Args: + text_embeds (`mindspore.Tensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The text embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`mindspore.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(mindspore.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `mindspore.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(mindspore.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `mindspore.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + text_embeds: Optional[mindspore.Tensor] = None + last_hidden_state: mindspore.Tensor = None + hidden_states: Optional[Tuple[mindspore.Tensor, ...]] = None + attentions: Optional[Tuple[mindspore.Tensor, ...]] = None + + +@dataclass +# Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip +class SiglipOutput(ModelOutput): + """ + Args: + loss (`mindspore.Tensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): + Contrastive loss for image-text similarity. + logits_per_image (`mindspore.Tensor` of shape `(image_batch_size, text_batch_size)`): + The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text + similarity scores. + logits_per_text (`mindspore.Tensor` of shape `(text_batch_size, image_batch_size)`): + The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image + similarity scores. + text_embeds (`mindspore.Tensor` of shape `(batch_size, output_dim`): + The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`]. + image_embeds (`mindspore.Tensor` of shape `(batch_size, output_dim`): + The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`]. 
+ text_model_output (`BaseModelOutputWithPooling`): + The output of the [`SiglipTextModel`]. + vision_model_output (`BaseModelOutputWithPooling`): + The output of the [`SiglipVisionModel`]. + """ + + loss: Optional[mindspore.Tensor] = None + logits_per_image: mindspore.Tensor = None + logits_per_text: mindspore.Tensor = None + text_embeds: mindspore.Tensor = None + image_embeds: mindspore.Tensor = None + text_model_output: BaseModelOutputWithPooling = None + vision_model_output: BaseModelOutputWithPooling = None + + def to_tuple(self) -> Tuple[Any]: + return tuple( + self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +class SiglipVisionEmbeddings(nn.Module): + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer("position_ids", ops.arange(self.num_positions).expand((1, -1)), persistent=False) + + def interpolate_pos_encoding(self, embeddings: mindspore.Tensor, height: int, width: int) -> mindspore.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing and no class embeddings. + + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] + num_positions = self.position_embedding.weight.shape[0] + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if num_patches == num_positions and height == width: # and not mindspore.jit.is_tracing() + return self.position_embedding(self.position_ids) + + patch_pos_embed = self.position_embedding.weight.unsqueeze(0) + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return patch_pos_embed + + def forward(self, pixel_values: mindspore.Tensor, interpolate_pos_encoding=False) -> mindspore.Tensor: + _, _, height, width = pixel_values.shape + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] + embeddings = ops.transpose(ops.flatten(patch_embeds, 2), 1, 2) + + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + 
self.position_embedding(self.position_ids) + return embeddings + + +# Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip +class SiglipTextEmbeddings(nn.Module): + def __init__(self, config: SiglipTextConfig): + super().__init__() + embed_dim = config.hidden_size + + self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) + self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", ops.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + + def forward( + self, + input_ids: Optional[mindspore.Tensor] = None, + position_ids: Optional[mindspore.Tensor] = None, + inputs_embeds: Optional[mindspore.Tensor] = None, + ) -> mindspore.Tensor: + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] + max_position_embedding = self.position_embedding.weight.shape[0] + + if seq_length > max_position_embedding: + raise ValueError( + f"Sequence length must be less than max_position_embeddings (got `sequence length`: " + f"{seq_length} and max_position_embeddings: {max_position_embedding}" + ) + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + return embeddings + + +class SiglipAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__ + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." 
+ ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward( + self, + hidden_states: mindspore.Tensor, + attention_mask: Optional[mindspore.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[mindspore.Tensor, Optional[mindspore.Tensor]]: + """Input shape: Batch x Time x Channel""" + + batch_size, q_len, _ = hidden_states.shape + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).swapaxes(1, 2) + key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).swapaxes(1, 2) + value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).swapaxes(1, 2) + + k_v_seq_len = key_states.shape[-2] + attn_weights = ops.matmul(query_states, key_states.swapaxes(2, 3)) * self.scale + + if attn_weights.shape != (batch_size, self.num_heads, q_len, k_v_seq_len): + raise ValueError( + f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" + f" {attn_weights.shape}" + ) + + if attention_mask is not None: + if attention_mask.shape != (batch_size, 1, q_len, k_v_seq_len): + raise ValueError( + f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.shape}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = ops.softmax(attn_weights, dim=-1, dtype=mindspore.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = ops.matmul(attn_weights, value_states) + + if attn_output.shape != (batch_size, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.shape}" + ) + + attn_output = attn_output.swapaxes(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +SIGLIP_ATTENTION_CLASSES = { + "eager": SiglipAttention, +} + + +# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip +class SiglipMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: mindspore.Tensor) -> mindspore.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class SiglipEncoderLayer(nn.Module): + def __init__(self, config: SiglipConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = SIGLIP_ATTENTION_CLASSES[config._attn_implementation](config=config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = SiglipMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + # Ignore copy + def forward( + self, + hidden_states: 
mindspore.Tensor, + attention_mask: mindspore.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[mindspore.Tensor]: + """ + Args: + hidden_states (`mindspore.Tensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + attention_mask (`mindspore.Tensor`): + Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class SiglipPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = SiglipConfig + base_model_prefix = "siglip" + supports_gradient_checkpointing = True + + _no_split_modules = [ + "SiglipTextEmbeddings", + "SiglipEncoderLayer", + "SiglipVisionEmbeddings", + "SiglipEncoderLayer", + "SiglipMultiheadAttentionPoolingHead", + ] + _supports_flash_attn_2 = False + _supports_sdpa = False + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, SiglipVisionEmbeddings): + width = ( + self.config.vision_config.hidden_size + if isinstance(self.config, SiglipConfig) + else self.config.hidden_size + ) + nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) + elif isinstance(module, nn.Embedding): + default_flax_embed_init(module.weight) + elif isinstance(module, SiglipAttention): + nn.init.xavier_uniform_(module.q_proj.weight) + nn.init.xavier_uniform_(module.k_proj.weight) + nn.init.xavier_uniform_(module.v_proj.weight) + nn.init.xavier_uniform_(module.out_proj.weight) + nn.init.zeros_(module.q_proj.bias) + nn.init.zeros_(module.k_proj.bias) + nn.init.zeros_(module.v_proj.bias) + nn.init.zeros_(module.out_proj.bias) + elif isinstance(module, SiglipMLP): + nn.init.xavier_uniform_(module.fc1.weight) + nn.init.xavier_uniform_(module.fc2.weight) + nn.init.normal_(module.fc1.bias, std=1e-6) + nn.init.normal_(module.fc2.bias, std=1e-6) + elif isinstance(module, SiglipMultiheadAttentionPoolingHead): + nn.init.xavier_uniform_(module.probe) + nn.init.xavier_uniform_(module.attention.in_proj_weight) + nn.init.zeros_(module.attention.in_proj_bias) + elif isinstance(module, SiglipModel): + # logit_scale_init = ops.log(mindspore.Tensor(1.0)) + module.logit_scale.assign_value(ops.log(initializer('ones', module.logit_scale.shape, module.logit_scale.dtype))) + nn.init.zeros_(module.logit_bias) + elif isinstance(module, SiglipForImageClassification): + nn.init.normal_( + module.classifier.weight, + std=self.config.vision_config.hidden_size**-0.5 * self.config.initializer_factor, + ) + elif isinstance(module, (nn.Linear, nn.Conv2d)): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.LayerNorm): + nn.init.zeros_(module.bias) 
+ nn.init.ones_(module.weight) + +# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->Siglip +class SiglipEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`SiglipEncoderLayer`]. + + Args: + config: SiglipConfig + """ + + def __init__(self, config: SiglipConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + # Ignore copy + def forward( + self, + inputs_embeds, + attention_mask: Optional[mindspore.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`mindspore.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`mindspore.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class SiglipTextTransformer(nn.Module): + def __init__(self, config: SiglipTextConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + self.embeddings = SiglipTextEmbeddings(config) + self.encoder = SiglipEncoder(config) + self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + self.head = nn.Linear(embed_dim, embed_dim) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + + def forward( + self, + input_ids: Optional[mindspore.Tensor] = None, + attention_mask: Optional[mindspore.Tensor] = None, + position_ids: Optional[mindspore.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is None: + raise ValueError("You have to specify input_ids") + + input_shape = input_ids.shape + input_ids = input_ids.view(-1, input_shape[-1]) + + hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) + + # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model. + # expand attention_mask + if attention_mask is not None and not self._use_flash_attention_2: + # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.final_layer_norm(last_hidden_state) + + # Assuming "sticky" EOS tokenization, last token is always EOS. 
+ pooled_output = last_hidden_state[:, -1, :] + pooled_output = self.head(pooled_output) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class SiglipTextModel(SiglipPreTrainedModel): + config_class = SiglipTextConfig + + def __init__(self, config: SiglipTextConfig): + super().__init__(config) + self.text_model = SiglipTextTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.text_model.embeddings.token_embedding + + def set_input_embeddings(self, value): + self.text_model.embeddings.token_embedding = value + + def forward( + self, + input_ids: Optional[mindspore.Tensor] = None, + attention_mask: Optional[mindspore.Tensor] = None, + position_ids: Optional[mindspore.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + Examples: + + ```python + >>> from transformers import AutoTokenizer, SiglipTextModel + + >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224") + >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224") + + >>> # important: make sure to set padding="max_length" as that's how the model was trained + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled (EOS token) states + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + return self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class SiglipVisionTransformer(nn.Module): + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = SiglipVisionEmbeddings(config) + self.encoder = SiglipEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head + if self.use_head: + self.head = SiglipMultiheadAttentionPoolingHead(config) + + def forward( + self, + pixel_values, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = False, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + 
output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooler_output = self.head(last_hidden_state) if self.use_head else None + if not return_dict: + return (last_hidden_state, pooler_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooler_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class SiglipMultiheadAttentionPoolingHead(nn.Module): + """Multihead Attention Pooling.""" + + def __init__(self, config: SiglipVisionConfig): + super().__init__() + + self.probe = nn.Parameter(ops.randn(1, 1, config.hidden_size)) + self.attention = nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = SiglipMLP(config) + + def forward(self, hidden_state): + batch_size = hidden_state.shape[0] + probe = self.probe.repeat(batch_size, axis=0) + + hidden_state = self.attention(probe, hidden_state, hidden_state)[0] + + residual = hidden_state + hidden_state = self.layernorm(hidden_state) + hidden_state = residual + self.mlp(hidden_state) + + return hidden_state[:, 0] + + +class SiglipVisionModel(SiglipPreTrainedModel): + config_class = SiglipVisionConfig + main_input_name = "pixel_values" + + def __init__(self, config: SiglipVisionConfig): + super().__init__(config) + + self.vision_model = SiglipVisionTransformer(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + def forward( + self, + pixel_values, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, SiglipVisionModel + + >>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled features + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + return self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + + +class SiglipModel(SiglipPreTrainedModel): + config_class = SiglipConfig + + def __init__(self, config: SiglipConfig): + super().__init__(config) + + if not isinstance(config.text_config, SiglipTextConfig): + raise TypeError( + "config.text_config is expected to be of type SiglipTextConfig but is of type" + f" {type(config.text_config)}." 
+ ) + + if not isinstance(config.vision_config, SiglipVisionConfig): + raise TypeError( + "config.vision_config is expected to be of type SiglipVisionConfig but is of type" + f" {type(config.vision_config)}." + ) + + text_config = config.text_config + vision_config = config.vision_config + + # First, initialize the text and vision models with proper attention implementation + text_model = SiglipTextModel._from_config(text_config) + vision_model = SiglipVisionModel._from_config(vision_config) + + # Second, get the text and vision submodules (for backward compatibility) + self.text_model = text_model.text_model + self.vision_model = vision_model.vision_model + + self.logit_scale = Parameter(ops.randn(1)) + self.logit_bias = Parameter(ops.randn(1)) + + # Initialize weights and apply final processing + self.post_init() + + def get_text_features( + self, + input_ids: Optional[mindspore.Tensor] = None, + attention_mask: Optional[mindspore.Tensor] = None, + position_ids: Optional[mindspore.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> mindspore.Tensor: + r""" + Returns: + text_features (`mindspore.Tensor` of shape `(batch_size, output_dim`): The text embeddings obtained by + applying the projection layer to the pooled output of [`SiglipTextModel`]. + + Examples: + + ```python + >>> from transformers import AutoTokenizer, AutoModel + >>> import torch + + >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") + >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224") + + >>> # important: make sure to set padding="max_length" as that's how the model was trained + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt") + >>> with torch.no_grad(): + ... text_features = model.get_text_features(**inputs) + ```""" + # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = text_outputs[1] + + return pooled_output + + def get_image_features( + self, + pixel_values: Optional[mindspore.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + ) -> mindspore.Tensor: + r""" + Returns: + image_features (`mindspore.Tensor` of shape `(batch_size, output_dim`): The image embeddings obtained by + applying the projection layer to the pooled output of [`SiglipVisionModel`]. 
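The pooled outputs of `get_text_features` and `get_image_features` are combined the same way `SiglipModel.forward` does further down: L2-normalize, matrix-multiply, scale by `exp(logit_scale)`, add `logit_bias`, then apply a sigmoid per pair. A NumPy sketch with made-up feature matrices and scalar values (not the trained parameters):

```python
# NumPy sketch of SigLIP scoring from pooled features; `text_feats`/`img_feats` stand in
# for the outputs of get_text_features / get_image_features, and the scalar values for
# logit_scale / logit_bias are made up rather than the trained parameters.
import numpy as np

def siglip_scores(text_feats, img_feats, logit_scale, logit_bias):
    t = text_feats / np.linalg.norm(text_feats, axis=-1, keepdims=True)
    v = img_feats / np.linalg.norm(img_feats, axis=-1, keepdims=True)
    logits_per_text = t @ v.T * np.exp(logit_scale) + logit_bias
    return 1.0 / (1.0 + np.exp(-logits_per_text))  # one independent sigmoid per pair

rng = np.random.default_rng(0)
probs = siglip_scores(rng.normal(size=(2, 768)), rng.normal(size=(3, 768)),
                      logit_scale=1.0, logit_bias=-10.0)
print(probs.shape)  # (2, 3): a probability per (text, image) pair, no softmax over rows
```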
+ + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, AutoModel + >>> import torch + + >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> with torch.no_grad(): + ... image_features = model.get_image_features(**inputs) + ```""" + # Use SiglipModel's config for some fields (if specified) instead of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + + pooled_output = vision_outputs[1] + + return pooled_output + + def forward( + self, + input_ids: Optional[mindspore.Tensor] = None, + pixel_values: Optional[mindspore.Tensor] = None, + attention_mask: Optional[mindspore.Tensor] = None, + position_ids: Optional[mindspore.Tensor] = None, + return_loss: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + ) -> Union[Tuple, SiglipOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, AutoModel + >>> import torch + + >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"] + >>> # important: we pass `padding=max_length` since the model was trained with this + >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt") + + >>> with torch.no_grad(): + ... outputs = model(**inputs) + + >>> logits_per_image = outputs.logits_per_image + >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities + >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'") + 31.9% that image 0 is 'a photo of 2 cats' + ```""" + # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components. 
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + image_embeds = vision_outputs[1] + text_embeds = text_outputs[1] + + # normalized features + image_embeds = image_embeds / mindspore.ops.norm(image_embeds, 2, dim=-1, keepdim=True) + text_embeds = text_embeds / mindspore.ops.norm(text_embeds, 2, dim=-1, keepdim=True) + + # cosine similarity as logits + logits_per_text = ( + ops.matmul(text_embeds, image_embeds.t()) * ops.exp(self.logit_scale) + + self.logit_bias + ) + logits_per_image = logits_per_text.t() + + loss = None + if return_loss: + # Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip.py#L287 + eye = ops.eye(logits_per_text.shape[0]) + m1_diag1 = -ops.ones_like(logits_per_text) + 2 * eye + loglik = nn.functional.logsigmoid(m1_diag1 * logits_per_text) + nll = -ops.sum(loglik, dim=-1) + loss = nll.mean() + + if not return_dict: + output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) + return ((loss,) + output) if loss is not None else output + + return SiglipOutput( + loss=loss, + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + image_embeds=image_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) + + +class SiglipForImageClassification(SiglipPreTrainedModel): + main_input_name = "pixel_values" + + def __init__(self, config: SiglipConfig) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + + # Create the vision model with proper attention + # and take only vision_model submodule (for backward compatibility) + vision_model = SiglipVisionModel._from_config(config.vision_config) + self.vision_model = vision_model.vision_model + + # Classifier head + self.classifier = ( + nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity() + ) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + pixel_values: Optional[mindspore.Tensor] = None, + labels: Optional[mindspore.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + ) -> Union[tuple, ImageClassifierOutput]: + r""" + labels (`mindspore.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
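The `return_loss` branch of `SiglipModel.forward` above implements the pairwise sigmoid loss from big_vision: matching text/image pairs (the diagonal) get label +1, every other pair gets -1, and the loss averages the row-wise sums of `-log sigmoid(label * logit)`. A NumPy sketch with a made-up 3x3 logit matrix:

```python
# NumPy sketch of the sigmoid contrastive loss computed when return_loss=True; the
# 3x3 logit matrix below is made up (batch of 3 matched text/image pairs).
import numpy as np

def siglip_loss(logits_per_text):
    n = logits_per_text.shape[0]
    labels = 2 * np.eye(n) - 1                              # +1 on the diagonal, -1 elsewhere
    loglik = -np.log1p(np.exp(-labels * logits_per_text))   # log sigmoid(label * logit)
    return (-loglik.sum(axis=-1)).mean()

logits = np.array([[ 5.0, -4.0, -6.0],
                   [-3.0,  4.0, -5.0],
                   [-4.0, -6.0,  6.0]])
print(round(float(siglip_loss(logits)), 4))  # ~0.0413: matching pairs are scored confidently
```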
+ + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, SiglipForImageClassification + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> torch.manual_seed(3) # doctest: +IGNORE_RESULT + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> # note: we are loading a `SiglipModel` from the hub here, + >>> # so the head will be randomly initialized, hence the predictions will be random if seed is not set above. + >>> image_processor = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224") + >>> model = SiglipForImageClassification.from_pretrained("google/siglip-base-patch16-224") + + >>> inputs = image_processor(images=image, return_tensors="pt") + >>> outputs = model(**inputs) + >>> logits = outputs.logits + >>> # model predicts one of the two classes + >>> predicted_class_idx = logits.argmax(-1).item() + >>> print("Predicted class:", model.config.id2label[predicted_class_idx]) + Predicted class: LABEL_1 + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.vision_model( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + + sequence_output = outputs[0] + + # average pool the patch tokens + sequence_output = ops.mean(sequence_output, dim=1) + # apply classifier + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == mindspore.int64 or labels.dtype == mindspore.int32): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "SiglipModel", + "SiglipPreTrainedModel", + "SiglipTextModel", + "SiglipVisionModel", + "SiglipForImageClassification", +] diff --git a/mindnlp/transformers/models/siglip/processing_siglip.py b/mindnlp/transformers/models/siglip/processing_siglip.py new file mode 100644 index 000000000..242b021c3 --- /dev/null +++ b/mindnlp/transformers/models/siglip/processing_siglip.py @@ -0,0 +1,145 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Image/Text processor class for SigLIP. +""" + +from typing import List, Optional, Union + +from mindnlp.utils import TensorType +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import ProcessorMixin +from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy + + +class SiglipProcessor(ProcessorMixin): + r""" + Constructs a Siglip processor which wraps a Siglip image processor and a Siglip tokenizer into a single processor. + + [`SiglipProcessor`] offers all the functionalities of [`SiglipImageProcessor`] and [`SiglipTokenizer`]. See the + [`~SiglipProcessor.__call__`] and [`~SiglipProcessor.decode`] for more information. + + Args: + image_processor ([`SiglipImageProcessor`]): + The image processor is a required input. + tokenizer ([`SiglipTokenizer`]): + The tokenizer is a required input. + """ + + attributes = ["image_processor", "tokenizer"] + image_processor_class = "SiglipImageProcessor" + tokenizer_class = "SiglipTokenizer" + + def __init__(self, image_processor, tokenizer): + super().__init__(image_processor, tokenizer) + + def __call__( + self, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + images: ImageInput = None, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: int = None, + return_tensors: Optional[Union[str, TensorType]] = TensorType.MINDSPORE, + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to SiglipTokenizer's [`~SiglipTokenizer.__call__`] if `text` is not `None` to encode + the text. To prepare the image(s), this method forwards the `images` argument to + SiglipImageProcessor's [`~SiglipImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring + of the above two methods for more information. + + Args: + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). 
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + truncation (`bool`, *optional*): + Activates truncation to cut input sequences longer than `max_length` to `max_length`. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + """ + + if text is None and images is None: + raise ValueError("You have to specify either text or images. Both cannot be none.") + + if text is not None: + encoding = self.tokenizer( + text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length + ) + + if images is not None: + image_features = self.image_processor(images, return_tensors=return_tensors) + + if text is not None and images is not None: + encoding["pixel_values"] = image_features.pixel_values + return encoding + elif text is not None: + return encoding + else: + return BatchFeature(data=dict(**image_features), tensor_type=return_tensors) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to SiglipTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to SiglipTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + @property + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->Siglip, T5->Siglip + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + +__all__ = ["SiglipProcessor"] diff --git a/mindnlp/transformers/models/siglip/tokenization_siglip.py b/mindnlp/transformers/models/siglip/tokenization_siglip.py new file mode 100644 index 000000000..e57aa92a2 --- /dev/null +++ b/mindnlp/transformers/models/siglip/tokenization_siglip.py @@ -0,0 +1,393 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization class for SigLIP model.""" + +import os +import re +import string +import warnings +from shutil import copyfile +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple + +import sentencepiece as spm + +from mindnlp.utils import logging, requires_backends +from ...convert_slow_tokenizer import import_protobuf +from ...tokenization_utils import PreTrainedTokenizer +from ...tokenization_utils_base import AddedToken + + +if TYPE_CHECKING: + from ...tokenization_utils_base import TextInput + + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"} + +SPIECE_UNDERLINE = "▁" + +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "google/siglip-base-patch16-224": "" + }, +} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "google/siglip-base-patch16-224": 77, +} + +PRETRAINED_INIT_CONFIGURATION = { + "google/siglip-base-patch16-224": {}, +} + +class SiglipTokenizer(PreTrainedTokenizer): + """ + Construct a Siglip tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece). + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that + contains the vocabulary necessary to instantiate a tokenizer. + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + additional_special_tokens (`List[str]`, *optional*): + Additional special tokens used by the tokenizer. + sp_model_kwargs (`dict`, *optional*): + Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for + SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, + to set: + + - `enable_sampling`: Enable subword regularization. + - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. + + - `nbest_size = {0,1}`: No sampling is performed. + - `nbest_size > 1`: samples from the nbest_size results. + - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice) + using forward-filtering-and-backward-sampling algorithm. + + - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for + BPE-dropout. + model_max_length (`int`, *optional*, defaults to 64): + The maximum length (in number of tokens) for model inputs. + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. 
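A usage sketch for the tokenizer documented above, untested here: it assumes `AutoTokenizer` is re-exported from `mindnlp.transformers`, that the checkpoint can be fetched, and that `padding="max_length"` falls back to the default `model_max_length` of 64, which is how the SigLIP checkpoints were trained:

```python
# Untested usage sketch: assumes AutoTokenizer is re-exported from mindnlp.transformers
# and that the checkpoint is reachable. padding="max_length" pads to the default
# model_max_length of 64.
from mindnlp.transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
enc = tokenizer(["A photo of a cat!!!"], padding="max_length")
print(len(enc["input_ids"][0]))  # 64: punctuation is stripped and the sequence is padded
```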
+ """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + eos_token="", + unk_token="", + pad_token="", + additional_special_tokens=None, + sp_model_kwargs: Optional[Dict[str, Any]] = None, + model_max_length=64, + do_lower_case=True, + **kwargs, + ) -> None: + requires_backends(self, "protobuf") + + pad_token = ( + AddedToken(pad_token, rstrip=True, lstrip=True, normalized=False, special=True) + if isinstance(pad_token, str) + else pad_token + ) + unk_token = ( + AddedToken(unk_token, rstrip=True, lstrip=True, normalized=False, special=True) + if isinstance(unk_token, str) + else unk_token + ) + eos_token = ( + AddedToken(eos_token, rstrip=True, lstrip=True, normalized=False, special=True) + if isinstance(eos_token, str) + else eos_token + ) + + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + + self.do_lower_case = do_lower_case + self.vocab_file = vocab_file + + self.sp_model = self.get_spm_processor() + self.vocab_file = vocab_file + + super().__init__( + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + additional_special_tokens=additional_special_tokens, + sp_model_kwargs=self.sp_model_kwargs, + model_max_length=model_max_length, + do_lower_case=do_lower_case, + **kwargs, + ) + + def get_spm_processor(self): + tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs) + with open(self.vocab_file, "rb") as f: + sp_model = f.read() + model_pb2 = import_protobuf() + model = model_pb2.ModelProto.FromString(sp_model) + normalizer_spec = model_pb2.NormalizerSpec() + normalizer_spec.add_dummy_prefix = False + model.normalizer_spec.MergeFrom(normalizer_spec) + sp_model = model.SerializeToString() + tokenizer.LoadFromSerializedProto(sp_model) + return tokenizer + + @property + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.vocab_size + def vocab_size(self): + return self.sp_model.get_piece_size() + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_vocab + def get_vocab(self): + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_special_tokens_mask + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. 
+ """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + # normal case: some special tokens + if token_ids_1 is None: + return ([0] * len(token_ids_0)) + [1] + return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._add_eos_if_not_present + def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]: + """Do not add eos again if user already added it.""" + if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id: + warnings.warn( + f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated" + " eos tokens being added." + ) + return token_ids + else: + return token_ids + [self.eos_token_id] + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.create_token_type_ids_from_sequences + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make + use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of zeros. + """ + eos = [self.eos_token_id] + + if token_ids_1 is None: + return len(token_ids_0 + eos) * [0] + return len(token_ids_0 + eos + token_ids_1 + eos) * [0] + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.build_inputs_with_special_tokens + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A sequence has the following format: + + - single sequence: `X ` + - pair of sequences: `A B ` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + token_ids_0 = self._add_eos_if_not_present(token_ids_0) + if token_ids_1 is None: + return token_ids_0 + else: + token_ids_1 = self._add_eos_if_not_present(token_ids_1) + return token_ids_0 + token_ids_1 + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.__getstate__ + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + return state + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.__setstate__ + def __setstate__(self, d): + self.__dict__ = d + + # for backward compatibility + if not hasattr(self, "sp_model_kwargs"): + self.sp_model_kwargs = {} + + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(self.vocab_file) + + def remove_punctuation(self, text: str) -> str: + return text.translate(str.maketrans("", "", string.punctuation)) + + # source: https://github.com/google-research/big_vision/blob/3b8e5ab6ad4f96e32b32826f9e1b8fd277914f9c/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94 + def canonicalize_text(self, text, *, keep_punctuation_exact_string=None): + """Returns canonicalized `text` (puncuation removed). 
+ + Args: + text (`str`): + String to be canonicalized. + keep_punctuation_exact_string (`str`, *optional*): + If provided, then this exact string is kept. For example providing '{}' will keep any occurrences of '{}' + (but will still remove '{' and '}' that appear separately). + """ + if keep_punctuation_exact_string: + text = keep_punctuation_exact_string.join( + self.remove_punctuation(part) for part in text.split(keep_punctuation_exact_string) + ) + else: + text = self.remove_punctuation(text) + text = re.sub(r"\s+", " ", text) + text = text.strip() + + return text + + def tokenize(self, text: "TextInput", add_special_tokens=False, **kwargs) -> List[str]: + """ + Converts a string to a list of tokens. + """ + tokens = super().tokenize(SPIECE_UNDERLINE + text.replace(SPIECE_UNDERLINE, " "), **kwargs) + + if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens: + tokens = tokens[1:] + return tokens + + @property + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.unk_token_length + def unk_token_length(self): + return len(self.sp_model.encode(str(self.unk_token))) + + def _tokenize(self, text, **kwargs): + """ + Returns a tokenized string. + + We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any + SPIECE_UNDERLINE. + + For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give `['H', 'e', 'y']` instead of `['▁He', 'y']`. + + Thus we always encode `f"{unk_token}text"` and strip the `unk_token`. Here is an example with `unk_token = ""` and `unk_token_length = 4`. + `self.tokenizer.sp_model.encode(" Hey", out_type = str)[4:]`. + """ + text = self.canonicalize_text(text, keep_punctuation_exact_string=None) + tokens = self.sp_model.encode(text, out_type=str) + + # 1. Encode string + prefix ex: " Hey" + tokens = self.sp_model.encode(self.unk_token + text, out_type=str) + # 2. 
Remove self.unk_token from ['<','unk','>', '▁Hey'] + return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._convert_token_to_id + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.sp_model.piece_to_id(token) + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._convert_id_to_token + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + token = self.sp_model.IdToPiece(index) + return token + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + current_sub_tokens = [] + out_string = "" + prev_is_special = False + for token in tokens: + # make sure that special tokens are not decoded using sentencepiece model + if token in self.all_special_tokens: + if not prev_is_special: + out_string += " " + out_string += self.sp_model.decode(current_sub_tokens) + token + prev_is_special = True + current_sub_tokens = [] + else: + current_sub_tokens.append(token) + prev_is_special = False + out_string += self.sp_model.decode(current_sub_tokens) + return out_string.strip() + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.save_vocabulary + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + +__all__ = ["SiglipTokenizer"] diff --git a/mindnlp/utils/generic.py b/mindnlp/utils/generic.py index 2d6c8422a..11cec4248 100644 --- a/mindnlp/utils/generic.py +++ b/mindnlp/utils/generic.py @@ -17,10 +17,13 @@ Generic utils. """ import inspect +import warnings +from typing import Optional from enum import Enum +from functools import wraps from collections import OrderedDict, UserDict from dataclasses import fields -from typing import Any, Tuple, ContextManager, List +from typing import Any, Tuple, ContextManager, Optional, List from contextlib import ExitStack import numpy as np @@ -566,3 +569,81 @@ def can_return_loss(model_class): return True return False + +def filter_out_non_signature_kwargs(extra: Optional[list] = None): + """ + Decorator to filter out named arguments that are not in the function signature. + + This decorator ensures that only the keyword arguments that match the function's signature, or are specified in the + `extra` list, are passed to the function. Any additional keyword arguments are filtered out and a warning is issued. + + Parameters: + extra (`Optional[list]`, *optional*): + A list of extra keyword argument names that are allowed even if they are not in the function's signature. + + Returns: + Callable: + A decorator that wraps the function and filters out invalid keyword arguments. 
+ + Example usage: + + ```python + @filter_out_non_signature_kwargs(extra=["allowed_extra_arg"]) + def my_function(arg1, arg2, **kwargs): + print(arg1, arg2, kwargs) + + my_function(arg1=1, arg2=2, allowed_extra_arg=3, invalid_arg=4) + # This will print: 1 2 {"allowed_extra_arg": 3} + # And issue a warning: "The following named arguments are not valid for `my_function` and were ignored: 'invalid_arg'" + ``` + """ + extra = extra or [] + extra_params_to_pass = set(extra) + + def decorator(func): + sig = inspect.signature(func) + function_named_args = set(sig.parameters.keys()) + valid_kwargs_to_pass = function_named_args.union(extra_params_to_pass) + + # Required for better warning message + is_instance_method = "self" in function_named_args + is_class_method = "cls" in function_named_args + + # Mark function as decorated + func._filter_out_non_signature_kwargs = True + + @wraps(func) + def wrapper(*args, **kwargs): + valid_kwargs = {} + invalid_kwargs = {} + + for k, v in kwargs.items(): + if k in valid_kwargs_to_pass: + valid_kwargs[k] = v + else: + invalid_kwargs[k] = v + + if invalid_kwargs: + invalid_kwargs_names = [f"'{k}'" for k in invalid_kwargs.keys()] + invalid_kwargs_names = ", ".join(invalid_kwargs_names) + + # Get the class name for better warning message + if is_instance_method: + cls_prefix = args[0].__class__.__name__ + "." + elif is_class_method: + cls_prefix = args[0].__name__ + "." + else: + cls_prefix = "" + + warnings.warn( + f"The following named arguments are not valid for `{cls_prefix}{func.__name__}`" + f" and were ignored: {invalid_kwargs_names}", + UserWarning, + stacklevel=2, + ) + + return func(*args, **valid_kwargs) + + return wrapper + + return decorator \ No newline at end of file diff --git a/tests/transformers/models/siglip/__init__.py b/tests/transformers/models/siglip/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/transformers/models/siglip/test_image_processing_siglip.py b/tests/transformers/models/siglip/test_image_processing_siglip.py new file mode 100644 index 000000000..8aa67d546 --- /dev/null +++ b/tests/transformers/models/siglip/test_image_processing_siglip.py @@ -0,0 +1,127 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
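A usage sketch for the `filter_out_non_signature_kwargs` decorator added to `mindnlp/utils/generic.py` above, here applied to an instance method so the warning is prefixed with the class name; the import path is assumed from the diff hunk:

```python
# Usage sketch for filter_out_non_signature_kwargs applied to an instance method;
# the import path is assumed from the diff hunk above (mindnlp/utils/generic.py).
import warnings
from mindnlp.utils.generic import filter_out_non_signature_kwargs

class Demo:
    @filter_out_non_signature_kwargs()
    def resize(self, image, size=None):
        return image, size

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    print(Demo().resize("img", size=224, bogus_flag=True))  # ('img', 224)
    print(caught[0].message)  # names Demo.resize and the ignored 'bogus_flag'
```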
+ + +import unittest + +from mindnlp.utils import is_vision_available +from mindnlp.utils.testing_utils import require_mindspore, require_vision + +from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs + + +if is_vision_available(): + from mindnlp.transformers.models.siglip import SiglipImageProcessor + + +class SiglipImageProcessingTester(unittest.TestCase): + def __init__( + self, + parent, + batch_size=7, + num_channels=3, + image_size=18, + min_resolution=30, + max_resolution=400, + do_resize=True, + size=None, + do_rescale=True, + rescale_factor=1 / 255, + do_normalize=True, + image_mean=[0.5, 0.5, 0.5], + image_std=[0.5, 0.5, 0.5], + ): + super().__init__() + size = size if size is not None else {"height": 18, "width": 18} + self.parent = parent + self.batch_size = batch_size + self.num_channels = num_channels + self.image_size = image_size + self.min_resolution = min_resolution + self.max_resolution = max_resolution + self.do_resize = do_resize + self.size = size + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + + def prepare_image_processor_dict(self): + return { + "do_resize": self.do_resize, + "size": self.size, + "do_rescale": self.do_rescale, + "rescale_factor": self.rescale_factor, + "do_normalize": self.do_normalize, + "image_mean": self.image_mean, + "image_std": self.image_std, + } + + def expected_output_image_shape(self, images): + return self.num_channels, self.size["height"], self.size["width"] + + def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False): + return prepare_image_inputs( + batch_size=self.batch_size, + num_channels=self.num_channels, + min_resolution=self.min_resolution, + max_resolution=self.max_resolution, + equal_resolution=equal_resolution, + numpify=numpify, + torchify=torchify, + ) + + +@require_mindspore +@require_vision +# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest with CLIP->Siglip +class SiglipImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): + image_processing_class = SiglipImageProcessor if is_vision_available() else None + + def setUp(self): + super().setUp() + self.image_processor_tester = SiglipImageProcessingTester(self) + + @property + def image_processor_dict(self): + return self.image_processor_tester.prepare_image_processor_dict() + + # Ignore copy + def test_image_processor_properties(self): + image_processing = self.image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processing, "do_resize")) + self.assertTrue(hasattr(image_processing, "size")) + self.assertTrue(hasattr(image_processing, "resample")) + self.assertTrue(hasattr(image_processing, "do_rescale")) + self.assertTrue(hasattr(image_processing, "rescale_factor")) + self.assertTrue(hasattr(image_processing, "do_normalize")) + self.assertTrue(hasattr(image_processing, "image_mean")) + self.assertTrue(hasattr(image_processing, "image_std")) + + # Ignore copy + def test_image_processor_from_dict_with_kwargs(self): + image_processor = self.image_processing_class.from_dict(self.image_processor_dict) + self.assertEqual(image_processor.size, {"height": 18, "width": 18}) + + image_processor = self.image_processing_class.from_dict( + self.image_processor_dict, size={"height": 84, "width": 84} + ) + self.assertEqual(image_processor.size, {"height": 84, "width": 84}) + + @unittest.skip(reason="not supported") + # Ignore copy 
+ def test_call_numpy_4_channels(self): + pass diff --git a/tests/transformers/models/siglip/test_modeling_siglip.py b/tests/transformers/models/siglip/test_modeling_siglip.py new file mode 100644 index 000000000..337154691 --- /dev/null +++ b/tests/transformers/models/siglip/test_modeling_siglip.py @@ -0,0 +1,989 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch SigLIP model.""" + +import inspect +import os +import tempfile +import unittest +from typing import Tuple + +import numpy as np +import requests +from parameterized import parameterized +from pytest import mark + +from mindnlp.utils import ( + is_mindspore_available, + is_vision_available, +) +from mindnlp.utils.testing_utils import ( + require_mindspore, + require_mindspore_gpu, + require_vision, + slow, +) +from mindnlp.transformers.models.siglip.configuration_siglip import ( + SiglipConfig, + SiglipTextConfig, + SiglipVisionConfig +) + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ( + ModelTesterMixin, + _config_zero_init, + floats_tensor, + ids_tensor, + is_flaky, + random_attention_mask, +) +# from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_mindspore_available(): + import mindspore + from mindnlp.core import nn, ops + + from mindnlp.transformers.models.siglip.modeling_siglip import ( + SiglipForImageClassification, + SiglipModel, + SiglipTextModel, + SiglipVisionModel + ) + +if is_vision_available(): + from PIL import Image + + from mindnlp.transformers.models.siglip.processing_siglip import SiglipProcessor + + +class SiglipModelTesterMixin(ModelTesterMixin): + def test_sdpa_can_dispatch_composite_models(self): + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + # Load the model with eager attention + model_eager = model_class.from_pretrained( + tmpdirname, + attn_implementation="eager", + ) + model_eager = model_eager.set_train(False) + + if hasattr(model_eager, "vision_model") and hasattr(model_eager, "text_model"): + self.assertTrue(model_eager.vision_model.config._attn_implementation == "eager") + self.assertTrue(model_eager.text_model.config._attn_implementation == "eager") + + self.assertTrue(model_eager.config._attn_implementation == "eager") + + for name, submodule in model_eager.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + raise ValueError("The eager model should not have SDPA attention layers") + + # def test_sdpa_can_dispatch_composite_models(self): + # for model_class in self.all_model_classes: + # config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + # model = model_class(config) + + # with tempfile.TemporaryDirectory() as 
tmpdirname: + # model.save_pretrained(tmpdirname) + + # # Load the model with SDPA + # model_sdpa = model_class.from_pretrained(tmpdirname,attn_implementation="sdpa",) + # model_sdpa = model_sdpa.set_train(False) + + # # Load model with eager attention + # model_eager = model_class.from_pretrained( + # tmpdirname, + # attn_implementation="eager", + # ) + # model_eager = model_eager.set_train(False) + + # # SigLip has one shared cls attr for all models, so we assign both submodels heer + # vision_attn = text_attn = "sdpa" if model._supports_sdpa else "eager" + + # if hasattr(model_sdpa, "vision_model") and hasattr(model_sdpa, "text_model"): + # self.assertTrue(model_sdpa.vision_model.config._attn_implementation == vision_attn) + # self.assertTrue(model_sdpa.text_model.config._attn_implementation == text_attn) + # self.assertTrue(model_eager.vision_model.config._attn_implementation == "eager") + # self.assertTrue(model_eager.text_model.config._attn_implementation == "eager") + + # self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") + # self.assertTrue(model_eager.config._attn_implementation == "eager") + + # for name, submodule in model_eager.named_modules(): + # class_name = submodule.__class__.__name__ + # if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + # raise ValueError("The eager model should not have SDPA attention layers") + + # has_sdpa = False + # for name, submodule in model_sdpa.named_modules(): + # class_name = submodule.__class__.__name__ + # if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + # has_sdpa = True + # break + # if not has_sdpa and model_sdpa.config.model_type != "falcon": + # raise ValueError("The SDPA model should have SDPA attention layers") + + def test_eager_matches_sdpa_inference( + self, + ms_dtype: str, + use_attention_mask_options: Tuple[bool,...] = (True, False), + logit_keys: Tuple[str,...] 
= ("logits_per_image", "logits_per_text", "image_embeds", "text_embeds"), + ): + # Convert to torch dtype + dtypes = { + "float16": mindspore.float16, + "float32": mindspore.float32, + } + ms_dtype = dtypes[ms_dtype] + + atols = { + mindspore.float32: 1e-5, + mindspore.float16: 5e-3, + } + rtols = { + mindspore.float32: 1e-4, + mindspore.float16: 5e-3, + } + + atol = atols[ms_dtype] + rtol = rtols[ms_dtype] + + def get_mean_reldiff(msg, current_case, x, ref, atol, rtol): + return f"{msg} {current_case}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}" + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + # Load the model with eager attention + model_eager = model_class.from_pretrained( + tmpdirname, + ms_dtype=ms_dtype, + attn_implementation="eager", + ) + model_eager = model_eager.set_train(False) + + cases = [ + (use_mask, output_attentions, batch_size) + for use_mask in use_attention_mask_options + for output_attentions in [True, False] + for batch_size in [1, 5] + ] + fail_cases = [] + + for use_mask, output_attentions, batch_size in cases: + processed_inputs = inputs_dict.copy() + + # convert to ms_dtype + if "pixel_values" in processed_inputs: + processed_inputs["pixel_values"] = processed_inputs["pixel_values"].to(dtype=ms_dtype) + + # slice for different batch sizes + for key in ["pixel_values", "input_ids", "attention_mask"]: + if key in processed_inputs: + processed_inputs[key] = processed_inputs[key][:batch_size] + + # set attention mask with left padding + if not use_mask: + processed_inputs.pop("attention_mask", None) + else: + dummy_attention_mask = processed_inputs["attention_mask"] + dummy_attention_mask[:] = 1 + processed_inputs["attention_mask"] = dummy_attention_mask + + processed_inputs["output_attentions"] = output_attentions + processed_inputs["output_hidden_states"] = True + + current_case = ( + f"padding_side=left, use_mask={use_mask}, batch_size={batch_size}" + ) + + prepared_inputs = self._prepare_for_class(processed_inputs, model_class) + + try: + outputs_eager = model_eager(**prepared_inputs) + except Exception as e: + fail_cases.append(f"{current_case}: {e}") + continue + + for key in logit_keys: + eager_logits = outputs_eager[key] + + if use_mask: + eager_logits = eager_logits[:, 1:] + + self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases)) + + # def test_eager_matches_sdpa_inference( + # self, + # ms_dtype: str, + # use_attention_mask_options: Tuple[bool, ...] = (True, False), + # logit_keys: Tuple[str, ...] 
= ("logits_per_image", "logits_per_text", "image_embeds", "text_embeds"), + # ): + # if not self.all_model_classes[0]._supports_sdpa: + # self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA") + + # # Convert to torch dtype + # dtypes = { + # "float16": mindspore.float16, + # "bfloat16": mindspore.bfloat16, + # "float32": mindspore.float32, + # } + # ms_dtype = dtypes[ms_dtype] + + # atols = { + # mindspore.float32: 1e-5, + # mindspore.bfloat16: 3e-2, + # mindspore.float16: 5e-3, + # } + # rtols = { + # mindspore.float32: 1e-4, + # mindspore.bfloat16: 3e-2, + # mindspore.float16: 5e-3, + # } + + # atol = atols[ms_dtype] + # rtol = rtols[ms_dtype] + + # def get_mean_reldiff(msg, current_case, x, ref, atol, rtol): + # return f"{msg} {current_case}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}" + + # for model_class in self.all_model_classes: + # config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + # model = model_class(config) + + # with tempfile.TemporaryDirectory() as tmpdirname: + # model.save_pretrained(tmpdirname) + + # # Load the model with SDPA + # model_sdpa = model_class.from_pretrained(tmpdirname, ms_dtype=ms_dtype,attn_implementation="sdpa") + # model_sdpa = model_sdpa.set_train(False) + + # # Load model with eager attention + # model_eager = model_class.from_pretrained( + # tmpdirname, + # ms_dtype=ms_dtype, + # attn_implementation="eager", + # ) + # model_eager = model_eager.set_train(False) + + # # We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving the model each time, + # # but it would be nicer to have an efficient way to use parameterized.expand + # cases = [ + # (use_mask, output_attentions, sdpa_backend, batch_size) + # for use_mask in use_attention_mask_options + # for output_attentions in [True, False] + # for sdpa_backend in [ + # SDPBackend.MATH, + # [SDPBackend.FLASH_ATTENTION, SDPBackend.MATH], + # [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH], + # [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH], + # ] + # for batch_size in [1, 5] + # ] + # fail_cases = [] + + # for use_mask, output_attentions, sdpa_backend, batch_size in cases: + # processed_inputs = inputs_dict.copy() + + # # convert to ms_dtype + # if "pixel_values" in processed_inputs: + # processed_inputs["pixel_values"] = processed_inputs["pixel_values"].to(dtype=ms_dtype) + + # # slice for different batch sizes + # for key in ["pixel_values", "input_ids", "attention_mask"]: + # if key in processed_inputs: + # processed_inputs[key] = processed_inputs[key][:batch_size] + + # # set attention mask with left padding + # if not use_mask: + # processed_inputs.pop("attention_mask", None) + # else: + # dummy_attention_mask = processed_inputs["attention_mask"] + # dummy_attention_mask[:] = 1 + # dummy_attention_mask[:, :1] = 0 + # processed_inputs["attention_mask"] = dummy_attention_mask + + # processed_inputs["output_attentions"] = output_attentions + # processed_inputs["output_hidden_states"] = True + + # current_case = ( + # f"padding_side=left, use_mask={use_mask}, batch_size={batch_size}, sdpa_backend={sdpa_backend}" + # ) + + # prepared_inputs = self._prepare_for_class(processed_inputs, model_class) + + # try: + # with sdpa_kernel(sdpa_backend): + # outputs_eager = model_eager(**prepared_inputs) + # outputs_sdpa = model_sdpa(**prepared_inputs) + # except Exception as e: + # 
fail_cases.append(f"{current_case}: {e}") + # continue + + # for key in logit_keys: + # eager_logits = outputs_eager[key] + # sdpa_logits = outputs_sdpa[key] + + # if use_mask: + # eager_logits = eager_logits[:, 1:] + # sdpa_logits = sdpa_logits[:, 1:] + + # is_close = np.allclose(eager_logits.asnumpy(), sdpa_logits.asnumpy(), atol=atol, rtol=rtol) + # if not is_close: + # fail_cases.append(get_mean_reldiff(key, current_case, sdpa_logits, eager_logits, atol, rtol)) + + # self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases)) + + +class SiglipVisionModelTester: + def __init__( + self, + parent, + batch_size=12, + image_size=30, + patch_size=2, + num_channels=3, + is_training=True, + hidden_size=32, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=37, + dropout=0.1, + attention_dropout=0.1, + initializer_range=0.02, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.is_training = is_training + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.dropout = dropout + self.attention_dropout = attention_dropout + self.initializer_range = initializer_range + self.scope = scope + + # in ViT, the seq length equals the number of patches + num_patches = (image_size // patch_size) ** 2 + self.seq_length = num_patches + + # Copied from tests.models.clip.test_modeling_clip.CLIPVisionModelTester.prepare_config_and_inputs + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + config = self.get_config() + + return config, pixel_values + + def get_config(self): + return SiglipVisionConfig( + image_size=self.image_size, + patch_size=self.patch_size, + num_channels=self.num_channels, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + dropout=self.dropout, + attention_dropout=self.attention_dropout, + initializer_range=self.initializer_range, + ) + + def create_and_check_model(self, config, pixel_values): + model = SiglipVisionModel(config=config) + model.set_train(False) + result = model(pixel_values) + # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token) + image_size = (self.image_size, self.image_size) + patch_size = (self.patch_size, self.patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches, self.hidden_size)) + self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size)) + + # Copied from tests.models.clip.test_modeling_clip.CLIPVisionModelTester.prepare_config_and_inputs_for_common + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_mindspore +class SiglipVisionModelTest(SiglipModelTesterMixin, unittest.TestCase): + """ + Here we also overwrite some of the tests of test_modeling_common.py, as SIGLIP does not use input_ids, inputs_embeds, + attention_mask and seq_length. 
+ """ + + all_model_classes = (SiglipVisionModel,) if is_mindspore_available() else () + fx_compatible = False + test_pruning = False + test_resize_embeddings = False + test_head_masking = False + # MP works but offload doesn't work when the MultiheadAttention is offloaded + # TODO: One potential solution would be to add to set preload_module_classes = ["SiglipMultiheadAttentionPoolingHead"] + # in the dispatch_model function + test_cpu_offload = False + test_disk_offload_safetensors = False + test_disk_offload_bin = False + + def setUp(self): + self.model_tester = SiglipVisionModelTester(self) + self.config_tester = ConfigTester( + self, config_class=SiglipVisionConfig, has_text_modality=False, hidden_size=37 + ) + + def test_config(self): + self.config_tester.run_common_tests() + + @unittest.skip(reason="SIGLIP does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + def test_model_get_set_embeddings(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + self.assertIsInstance(model.get_input_embeddings(), (nn.Module)) + x = model.get_output_embeddings() + self.assertTrue(x is None or isinstance(x, nn.Linear)) + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + @unittest.skip(reason="SiglipVisionModel does not support standalone training") + def test_training(self): + pass + + @unittest.skip(reason="SiglipVisionModel does not support standalone training") + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip(reason="SiglipVisionModel does not support standalone training") + def test_training_gradient_checkpointing_use_reentrant(self): + pass + + @unittest.skip(reason="SiglipVisionModel does not support standalone training") + def test_training_gradient_checkpointing_use_reentrant_false(self): + pass + + @unittest.skip(reason="SiglipVisionModel has no base class and is not available in MODEL_MAPPING") + def test_save_load_fast_init_from_base(self): + pass + + @unittest.skip(reason="SiglipVisionModel has no base class and is not available in MODEL_MAPPING") + def test_save_load_fast_init_to_base(self): + pass + + @unittest.skip(reason="Siglip uses the same initialization scheme as the Flax original implementation") + def test_initialization(self): + pass + + @slow + def test_model_from_pretrained(self): + model_name = "google/siglip-base-patch16-224" + model = SiglipVisionModel.from_pretrained(model_name, from_pt=True) + self.assertIsNotNone(model) + + @parameterized.expand([("float16",), ("float32",)]) + @slow + @is_flaky() + def test_eager_matches_sdpa_inference(self, ms_dtype: str): + super().test_eager_matches_sdpa_inference( + ms_dtype=ms_dtype, + logit_keys=("pooler_output", "last_hidden_state"), + use_attention_mask_options=(False,), + ) + + def test_sdpa_can_dispatch_composite_models(self): + super().test_sdpa_can_dispatch_composite_models() + + +class SiglipTextModelTester: + def 
__init__( + self, + parent, + batch_size=12, + seq_length=7, + is_training=True, + use_input_mask=True, + use_labels=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=37, + dropout=0.1, + attention_dropout=0.1, + max_position_embeddings=512, + initializer_range=0.02, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.dropout = dropout + self.attention_dropout = attention_dropout + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.scope = scope + + # Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTester.prepare_config_and_inputs + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = random_attention_mask([self.batch_size, self.seq_length]) + + if input_mask is not None: + batch_size, seq_length = input_mask.shape + rnd_start_indices = np.random.randint(1, seq_length - 1, size=(batch_size,)) + for batch_idx, start_index in enumerate(rnd_start_indices): + input_mask[batch_idx, :int(start_index)] = 1 + input_mask[batch_idx, int(start_index):] = 0 + + config = self.get_config() + + return config, input_ids, input_mask + + def get_config(self): + return SiglipTextConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + dropout=self.dropout, + attention_dropout=self.attention_dropout, + max_position_embeddings=self.max_position_embeddings, + initializer_range=self.initializer_range, + ) + + def create_and_check_model(self, config, input_ids, input_mask): + model = SiglipTextModel(config=config) + model.set_train(False) + result = model(input_ids, attention_mask=input_mask) + result = model(input_ids) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size)) + + # Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTester.prepare_config_and_inputs_for_common + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, input_ids, input_mask = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} + return config, inputs_dict + + +@require_mindspore +class SiglipTextModelTest(SiglipModelTesterMixin, unittest.TestCase): + all_model_classes = (SiglipTextModel,) if is_mindspore_available() else () + fx_compatible = False + test_pruning = False + test_head_masking = False + model_split_percents = [0.5, 0.8, 0.9] + + # Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.setUp with CLIP->Siglip + def setUp(self): + self.model_tester = SiglipTextModelTester(self) + self.config_tester = ConfigTester(self, config_class=SiglipTextConfig, hidden_size=37) + + # Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.test_config + def test_config(self): + 
self.config_tester.run_common_tests() + + # Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.test_model + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + @unittest.skip(reason="SiglipTextModel does not support standalone training") + def test_training(self): + pass + + @unittest.skip(reason="SiglipTextModel does not support standalone training") + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip(reason="SiglipTextModel does not support standalone training") + def test_training_gradient_checkpointing_use_reentrant(self): + pass + + @unittest.skip(reason="SiglipTextModel does not support standalone training") + def test_training_gradient_checkpointing_use_reentrant_false(self): + pass + + @unittest.skip(reason="Siglip does not use inputs_embeds") + # Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.test_inputs_embeds + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="SiglipTextModel has no base class and is not available in MODEL_MAPPING") + # Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.test_save_load_fast_init_from_base + def test_save_load_fast_init_from_base(self): + pass + + @unittest.skip(reason="SiglipTextModel has no base class and is not available in MODEL_MAPPING") + # Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.test_save_load_fast_init_to_base + def test_save_load_fast_init_to_base(self): + pass + + @unittest.skip(reason="Siglip uses the same initialization scheme as the Flax original implementation") + def test_initialization(self): + pass + + @slow + def test_model_from_pretrained(self): + model_name = "google/siglip-base-patch16-224" + model = SiglipTextModel.from_pretrained(model_name, from_pt=True) + self.assertIsNotNone(model) + + @parameterized.expand([("float16",), ("float32",)]) + @slow + @is_flaky() + def test_eager_matches_sdpa_inference(self, ms_dtype: str): + super().test_eager_matches_sdpa_inference( + ms_dtype=ms_dtype, + logit_keys=("pooler_output", "last_hidden_state"), + use_attention_mask_options=(False, True), + ) + + def test_sdpa_can_dispatch_composite_models(self): + super().test_sdpa_can_dispatch_composite_models() + + +class SiglipModelTester: + def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True): + if text_kwargs is None: + text_kwargs = {} + if vision_kwargs is None: + vision_kwargs = {} + + self.parent = parent + self.text_model_tester = SiglipTextModelTester(parent, **text_kwargs) + self.vision_model_tester = SiglipVisionModelTester(parent, **vision_kwargs) + self.batch_size = self.text_model_tester.batch_size # need bs for batching_equivalence test + self.is_training = is_training + + # Copied from tests.models.clip.test_modeling_clip.CLIPModelTester.prepare_config_and_inputs + def prepare_config_and_inputs(self): + text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs() + vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs() + + config = self.get_config() + + return config, input_ids, attention_mask, pixel_values + + def get_config(self): + return SiglipConfig.from_text_vision_configs( + self.text_model_tester.get_config(), + self.vision_model_tester.get_config(), + ) + + def create_and_check_model(self, config, input_ids, attention_mask, pixel_values): + model = SiglipModel(config).set_train(False) + result = model(input_ids, 
pixel_values, attention_mask) + self.parent.assertEqual( + result.logits_per_image.shape, (self.vision_model_tester.batch_size, self.text_model_tester.batch_size) + ) + self.parent.assertEqual( + result.logits_per_text.shape, (self.text_model_tester.batch_size, self.vision_model_tester.batch_size) + ) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, input_ids, attention_mask, pixel_values = config_and_inputs + inputs_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "return_loss": False, + } + return config, inputs_dict + + +@require_mindspore +class SiglipModelTest(SiglipModelTesterMixin, unittest.TestCase): + all_model_classes = (SiglipModel,) if is_mindspore_available() else () + pipeline_model_mapping = {"feature-extraction": SiglipModel} if is_mindspore_available() else {} + fx_compatible = False + test_head_masking = False + test_pruning = False + test_resize_embeddings = False + test_attention_outputs = False + # MP works but offload doesn't work when the MultiheadAttention is offloaded + # TODO: One potential solution would be to add to set preload_module_classes = ["SiglipMultiheadAttentionPoolingHead"] + # in the dispatch_model function + test_cpu_offload = False + test_disk_offload_safetensors = False + test_disk_offload_bin = False + _is_composite = True + + def setUp(self): + self.model_tester = SiglipModelTester(self) + + # Copied from tests.models.clip.test_modeling_clip.CLIPModelTest.test_model + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + @unittest.skip(reason="Hidden_states is tested in individual model tests") + # Copied from tests.models.clip.test_modeling_clip.CLIPModelTest.test_hidden_states_output + def test_hidden_states_output(self): + pass + + @unittest.skip(reason="Inputs_embeds is tested in individual model tests") + # Copied from tests.models.clip.test_modeling_clip.CLIPModelTest.test_inputs_embeds + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="Retain_grad is tested in individual model tests") + # Copied from tests.models.clip.test_modeling_clip.CLIPModelTest.test_retain_grad_hidden_states_attentions + def test_retain_grad_hidden_states_attentions(self): + pass + + @unittest.skip(reason="SiglipModel does not have input/output embeddings") + # Copied from tests.models.clip.test_modeling_clip.CLIPModelTest.test_model_get_set_embeddings + def test_model_get_set_embeddings(self): + pass + + @unittest.skip(reason="Siglip uses the same initialization scheme as the Flax original implementation") + def test_initialization(self): + pass + + # Copied from tests.models.clip.test_modeling_clip.CLIPModelTest.test_load_vision_text_config with CLIP->Siglip + def test_load_vision_text_config(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + # Save SiglipConfig and check if we can load SiglipVisionConfig from it + with tempfile.TemporaryDirectory() as tmp_dir_name: + config.save_pretrained(tmp_dir_name) + vision_config = SiglipVisionConfig.from_pretrained(tmp_dir_name, force_download=True) + self.assertDictEqual( + config.vision_config.to_dict(), + vision_config.to_dict(), + msg=f"""SigLIPConfig.to_dict():{config.to_dict()}, + SigLIPConfig.vision_config.to_dict():{config.vision_config.to_dict()}, + PretrainedConfig.to_dict():{vision_config.to_dict()}, + tmp_dir_name: {tmp_dir_name}, + 
""" + ) + + # Save SiglipConfig and check if we can load SiglipTextConfig from it + with tempfile.TemporaryDirectory() as tmp_dir_name: + config.save_pretrained(tmp_dir_name) + text_config = SiglipTextConfig.from_pretrained(tmp_dir_name, force_download=True) + self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict()) + + @slow + def test_model_from_pretrained(self): + model_name = "google/siglip-base-patch16-224" + model = SiglipModel.from_pretrained(model_name, from_pt=True) + self.assertIsNotNone(model) + + @parameterized.expand([("float16",), ("float32",)]) + @slow + @is_flaky() + def test_eager_matches_sdpa_inference(self, ms_dtype: str): + super().test_eager_matches_sdpa_inference( + ms_dtype=ms_dtype, + logit_keys=("logits_per_image", "logits_per_text", "image_embeds", "text_embeds"), + use_attention_mask_options=(False, True), + ) + + def test_sdpa_can_dispatch_composite_models(self): + super().test_sdpa_can_dispatch_composite_models() + + +class SiglipForImageClassificationModelTester(SiglipModelTester): + def __init__(self, parent): + super().__init__(parent) + self.batch_size = self.vision_model_tester.batch_size + self.num_hidden_layers = self.vision_model_tester.num_hidden_layers + self.hidden_size = self.vision_model_tester.hidden_size + self.seq_length = self.vision_model_tester.seq_length + + def prepare_config_and_inputs(self): + _, pixel_values = self.vision_model_tester.prepare_config_and_inputs() + config = self.get_config() + + return config, pixel_values + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_mindspore +class SiglipForImageClassificationModelTest(SiglipModelTesterMixin, unittest.TestCase): + all_model_classes = (SiglipForImageClassification,) if is_mindspore_available() else () + pipeline_model_mapping = {"image-classification": SiglipForImageClassification} if is_mindspore_available() else {} + fx_compatible = False + test_head_masking = False + test_pruning = False + test_resize_embeddings = False + test_attention_outputs = False + # MP works but offload doesn't work when the MultiheadAttention is offloaded + # TODO: One potential solution would be to add to set preload_module_classes = ["SiglipMultiheadAttentionPoolingHead"] + # in the dispatch_model function + test_cpu_offload = False + test_disk_offload_safetensors = False + test_disk_offload_bin = False + _is_composite = True + + def setUp(self): + self.model_tester = SiglipForImageClassificationModelTester(self) + + @unittest.skip(reason="SiglipForImageClassification does not support inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="SiglipForImageClassification does not support inputs_embeds") + def test_model_get_set_embeddings(self): + pass + + @unittest.skip(reason="SiglipForImageClassification does not support gradient checkpointing yet") + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip(reason="SiglipForImageClassification does not support gradient checkpointing yet") + def test_training_gradient_checkpointing_use_reentrant(self): + pass + + @unittest.skip(reason="SiglipForImageClassification does not support gradient checkpointing yet") + def test_training_gradient_checkpointing_use_reentrant_false(self): + pass + + @unittest.skip(reason="Siglip uses the same initialization scheme as the Flax original implementation") + def 
test_initialization(self): + pass + + @parameterized.expand([("float16",), ("float32",)]) + @slow + @is_flaky() + def test_eager_matches_sdpa_inference(self, ms_dtype: str): + super().test_eager_matches_sdpa_inference( + ms_dtype=ms_dtype, logit_keys=("logits",), use_attention_mask_options=(False,) + ) + + def test_sdpa_can_dispatch_composite_models(self): + super().test_sdpa_can_dispatch_composite_models() + + +# We will verify our results on an image of cute cats +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + image = Image.open(requests.get(url, stream=True).raw) + return image + + +@require_vision +@require_mindspore +class SiglipModelIntegrationTest(unittest.TestCase): + @slow + def test_inference(self): + model_name = "google/siglip-base-patch16-224" + model = SiglipModel.from_pretrained(model_name, from_pt=True) + processor = SiglipProcessor.from_pretrained(model_name, from_pt=True) + + image = prepare_img() + inputs = processor( + text=["a photo of 2 cats", "a photo of 2 dogs"], images=image, padding="max_length", return_tensors="ms" + ) + + # forward pass + outputs = model(**inputs) + logits_per_image = outputs.logits_per_image + logits_per_text = outputs.logits_per_text + + # verify the logits + self.assertEqual( + logits_per_image.shape, + (inputs.pixel_values.shape[0], inputs.input_ids.shape[0]), + ) + self.assertEqual( + logits_per_text.shape, + (inputs.input_ids.shape[0], inputs.pixel_values.shape[0]), + ) + + expected_logits = mindspore.tensor([[-0.7567, -10.3354]]) + + self.assertTrue(np.allclose(outputs.logits_per_image.asnumpy(), expected_logits.asnumpy(), atol=1e-1)) + + # verify the probs + probs = ops.sigmoid(logits_per_image) # these are the probabilities + expected_probs = mindspore.tensor([[3.1937e-01, 3.2463e-05]]) + self.assertTrue(np.allclose(probs.asnumpy(), expected_probs.asnumpy(), atol=1e-1)) + + @slow + def test_inference_interpolate_pos_encoding(self): + model_name = "google/siglip-base-patch16-224" + model = SiglipModel.from_pretrained(model_name, from_pt=True) + + # 640 x 480 image + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + processor = SiglipProcessor.from_pretrained(model_name, do_resize=False, size={"height": 480, "width": 640}, from_pt=True) + + inputs = processor(text="what's in the image", images=image, return_tensors="ms") + + # forward pass + outputs = model(**inputs, interpolate_pos_encoding=True) + + # verify the shape + # patch size = 16 + # batch size 1, (640/16) * (480/16) = 1200 patches, 768 hidden size + expected_shape = (1, 1200, 768) + + self.assertEqual(outputs.vision_model_output.last_hidden_state.shape, expected_shape) diff --git a/tests/transformers/models/siglip/test_tokenization_siglip.py b/tests/transformers/models/siglip/test_tokenization_siglip.py new file mode 100644 index 000000000..08ea3e56a --- /dev/null +++ b/tests/transformers/models/siglip/test_tokenization_siglip.py @@ -0,0 +1,455 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import tempfile +import unittest + +from mindnlp.transformers import SPIECE_UNDERLINE +from mindnlp.transformers.tokenization_utils import AddedToken, BatchEncoding +from mindnlp.transformers.models.siglip.tokenization_siglip import SiglipTokenizer +from mindnlp.utils import cached_property, is_mindspore_available +from mindnlp.utils.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers, slow + +from ...test_tokenization_common import TokenizerTesterMixin + + +SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model") + +if is_mindspore_available(): + FRAMEWORK = "ms" +else: + FRAMEWORK = "jax" + + +@require_sentencepiece +@require_tokenizers +class SiglipTokenizationTest(TokenizerTesterMixin, unittest.TestCase): + from_pretrained_id = "google/siglip-base-patch16-224" + tokenizer_class = SiglipTokenizer + test_rust_tokenizer = False + test_sentencepiece = True + test_sentencepiece_ignore_case = True + + # Copied from tests.models.t5.test_tokenization_t5.T5TokenizationTest.setUp with T5->Siglip + def setUp(self): + super().setUp() + + # We have a SentencePiece fixture for testing + tokenizer = SiglipTokenizer(SAMPLE_VOCAB) + tokenizer.save_pretrained(self.tmpdirname) + + # Copied from tests.models.t5.test_tokenization_t5.T5TokenizationTest.test_convert_token_and_id with T5->Siglip + def test_convert_token_and_id(self): + """Test ``_convert_token_to_id`` and ``_convert_id_to_token``.""" + token = "" + token_id = 1 + + self.assertEqual(self.get_tokenizer()._convert_token_to_id(token), token_id) + self.assertEqual(self.get_tokenizer()._convert_id_to_token(token_id), token) + + def test_get_vocab(self): + vocab_keys = list(self.get_tokenizer().get_vocab().keys()) + + self.assertEqual(vocab_keys[0], "") + self.assertEqual(vocab_keys[1], "") + + def test_full_tokenizer(self): + tokenizer = SiglipTokenizer(SAMPLE_VOCAB) + + tokens = tokenizer.tokenize("This is a test") + self.assertListEqual(tokens, ["▁this", "▁is", "▁a", "▁t", "est"]) + + self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [66, 46, 10, 170, 382]) + + tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.") + self.assertListEqual( + tokens, + [ + SPIECE_UNDERLINE, + "i", + SPIECE_UNDERLINE + "was", + SPIECE_UNDERLINE + "b", + "or", + "n", + SPIECE_UNDERLINE + "in", + SPIECE_UNDERLINE + "", + "9", + "2", + "0", + "0", + "0", + SPIECE_UNDERLINE + "and", + SPIECE_UNDERLINE + "this", + SPIECE_UNDERLINE + "is", + SPIECE_UNDERLINE + "f", + "al", + "s", + "é", + ], + ) + ids = tokenizer.convert_tokens_to_ids(tokens) + self.assertListEqual(ids, [7, 23, 21, 84, 55, 24, 19, 7, 0, 602, 347, 347, 347, 12, 66, 46, 72, 80, 6, 0]) + + back_tokens = tokenizer.convert_ids_to_tokens(ids) + self.assertListEqual( + back_tokens, + [ + SPIECE_UNDERLINE, + "i", + SPIECE_UNDERLINE + "was", + SPIECE_UNDERLINE + "b", + "or", + "n", + SPIECE_UNDERLINE + "in", + SPIECE_UNDERLINE + "", + "", + "2", + "0", + "0", + "0", + SPIECE_UNDERLINE + "and", + SPIECE_UNDERLINE + "this", + SPIECE_UNDERLINE + "is", + SPIECE_UNDERLINE + "f", + "al", + "s", + "", + ], + ) + + @cached_property + def siglip_tokenizer(self): + return SiglipTokenizer.from_pretrained("google/siglip-base-patch16-224") + + # Copied from tests.models.t5.test_tokenization_t5.T5TokenizationTest.get_tokenizer with T5->Siglip + def get_tokenizer(self, **kwargs) -> SiglipTokenizer: + return 
self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs) + + # Copied from tests.models.t5.test_tokenization_t5.T5TokenizationTest.test_rust_and_python_full_tokenizers with T5->Siglip + def test_rust_and_python_full_tokenizers(self): + if not self.test_rust_tokenizer: + self.skipTest(reason="test_rust_tokenizer is set to False") + + tokenizer = self.get_tokenizer() + rust_tokenizer = self.get_rust_tokenizer() + + sequence = "I was born in 92000, and this is falsé." + + tokens = tokenizer.tokenize(sequence) + rust_tokens = rust_tokenizer.tokenize(sequence) + self.assertListEqual(tokens, rust_tokens) + + ids = tokenizer.encode(sequence, add_special_tokens=False) + rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False) + self.assertListEqual(ids, rust_ids) + + rust_tokenizer = self.get_rust_tokenizer() + ids = tokenizer.encode(sequence) + rust_ids = rust_tokenizer.encode(sequence) + self.assertListEqual(ids, rust_ids) + + def test_eos_treatment(self): + tokenizer = self.siglip_tokenizer + batch_with_eos_added = tokenizer(["hi", "I went to the gym", ""]) + batch_without_eos_added = tokenizer(["hi", "I went to the gym", ""]) + self.assertListEqual(batch_with_eos_added["input_ids"], batch_without_eos_added["input_ids"]) + + def test_prepare_batch(self): + tokenizer = self.siglip_tokenizer + src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."] + expected_src_tokens = [262, 266, 476, 8532, 270, 4460, 3949, 1682, tokenizer.eos_token_id] + batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK) + self.assertIsInstance(batch, BatchEncoding) + + if FRAMEWORK != "jax": + result = list(batch.input_ids.numpy()[0]) + else: + result = list(batch.input_ids.tolist()[0]) + + self.assertListEqual(expected_src_tokens, result) + + self.assertEqual((2, 9), batch.input_ids.shape) + + def test_empty_target_text(self): + tokenizer = self.siglip_tokenizer + src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."] + batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK) + # check if input_ids are returned and no decoder_input_ids + self.assertIn("input_ids", batch) + self.assertNotIn("decoder_input_ids", batch) + self.assertNotIn("decoder_attention_mask", batch) + + def test_max_length(self): + tokenizer = self.siglip_tokenizer + tgt_text = ["Summary of the text.", "Another summary."] + targets = tokenizer( + text_target=tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors=FRAMEWORK + ) + self.assertEqual(32, targets["input_ids"].shape[1]) + + def test_eos_in_input(self): + tokenizer = self.siglip_tokenizer + src_text = ["A long paragraph for summarization. "] + tgt_text = ["Summary of the text. 
"] + expected_src_tokens = [262, 266, 476, 8532, 270, 4460, 3949, 1682, 1] + expected_tgt_tokens = [6254, 267, 260, 1443, 1] + + batch = tokenizer(src_text, text_target=tgt_text) + + self.assertEqual(expected_src_tokens, batch["input_ids"][0]) + self.assertEqual(expected_tgt_tokens, batch["labels"][0]) + + @unittest.skip(reason="SiglipTokenizer strips the punctuation") + def test_subword_regularization_tokenizer(self): + pass + + @unittest.skip(reason="SiglipTokenizer strips the punctuation") + def test_pickle_subword_regularization_tokenizer(self): + pass + + # Copied from tests.models.t5.test_tokenization_t5.T5TokenizationTest.test_special_tokens_initialization with T5->Siglip + def test_special_tokens_initialization(self): + for tokenizer, pretrained_name, kwargs in self.tokenizers_list: + with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): + added_tokens = [f"" for i in range(100)] + [AddedToken("", lstrip=True)] + + tokenizer_r = self.rust_tokenizer_class.from_pretrained( + pretrained_name, additional_special_tokens=added_tokens, **kwargs + ) + tokenizer_cr = self.rust_tokenizer_class.from_pretrained( + pretrained_name, additional_special_tokens=added_tokens, **kwargs, from_slow=True + ) + tokenizer_p = self.tokenizer_class.from_pretrained( + pretrained_name, additional_special_tokens=added_tokens, **kwargs + ) + + p_output = tokenizer_p.encode("Hey this is a token") + r_output = tokenizer_r.encode("Hey this is a token") + cr_output = tokenizer_cr.encode("Hey this is a token") + + special_token_id = tokenizer_r.encode("", add_special_tokens=False)[0] + + self.assertEqual(p_output, r_output) + self.assertEqual(cr_output, r_output) + self.assertTrue(special_token_id in p_output) + self.assertTrue(special_token_id in r_output) + self.assertTrue(special_token_id in cr_output) + + # Copied from tests.models.t5.test_tokenization_t5.T5TokenizationTest.test_special_tokens_initialization_with_non_empty_additional_special_tokens with T5->Siglip + def test_special_tokens_initialization_with_non_empty_additional_special_tokens(self): + tokenizer_list = [] + if self.test_slow_tokenizer: + tokenizer_list.append((self.tokenizer_class, self.get_tokenizer())) + + if self.test_rust_tokenizer: + tokenizer_list.append((self.rust_tokenizer_class, self.get_rust_tokenizer())) + + for tokenizer_class, tokenizer_utils in tokenizer_list: + with tempfile.TemporaryDirectory() as tmp_dir: + tokenizer_utils.save_pretrained(tmp_dir) + + with open(os.path.join(tmp_dir, "special_tokens_map.json"), encoding="utf-8") as json_file: + special_tokens_map = json.load(json_file) + + with open(os.path.join(tmp_dir, "tokenizer_config.json"), encoding="utf-8") as json_file: + tokenizer_config = json.load(json_file) + + added_tokens_extra_ids = [f"" for i in range(100)] + + special_tokens_map["additional_special_tokens"] = added_tokens_extra_ids + [ + "an_additional_special_token" + ] + tokenizer_config["additional_special_tokens"] = added_tokens_extra_ids + [ + "an_additional_special_token" + ] + + with open(os.path.join(tmp_dir, "special_tokens_map.json"), "w", encoding="utf-8") as outfile: + json.dump(special_tokens_map, outfile) + with open(os.path.join(tmp_dir, "tokenizer_config.json"), "w", encoding="utf-8") as outfile: + json.dump(tokenizer_config, outfile) + + # the following checks allow us to verify that our test works as expected, i.e. 
that the tokenizer takes + # into account the new value of additional_special_tokens given in the "tokenizer_config.json" and + # "special_tokens_map.json" files + tokenizer_without_change_in_init = tokenizer_class.from_pretrained( + tmp_dir, + ) + self.assertIn( + "an_additional_special_token", tokenizer_without_change_in_init.additional_special_tokens + ) + # self.assertIn("an_additional_special_token",tokenizer_without_change_in_init.get_vocab()) # BySiglipTokenization no vocab + self.assertEqual( + ["an_additional_special_token"], + tokenizer_without_change_in_init.convert_ids_to_tokens( + tokenizer_without_change_in_init.convert_tokens_to_ids(["an_additional_special_token"]) + ), + ) + + # Now we test that we can change the value of additional_special_tokens in the from_pretrained + new_added_tokens = added_tokens_extra_ids + [AddedToken("a_new_additional_special_token", lstrip=True)] + tokenizer = tokenizer_class.from_pretrained( + tmp_dir, + additional_special_tokens=new_added_tokens, + ) + + self.assertIn("a_new_additional_special_token", tokenizer.additional_special_tokens) + self.assertEqual( + ["a_new_additional_special_token"], + tokenizer.convert_ids_to_tokens( + tokenizer.convert_tokens_to_ids(["a_new_additional_special_token"]) + ), + ) + + def test_sentencepiece_tokenize_and_convert_tokens_to_string(self): + """Test ``_tokenize`` and ``convert_tokens_to_string``.""" + if not self.test_sentencepiece: + self.skipTest(reason="test_sentencepiece is set to False") + + tokenizer = self.get_tokenizer() + text = "This is text to test the tokenizer." + + if self.test_sentencepiece_ignore_case: + text = text.lower() + + tokens = tokenizer.tokenize(text) + + self.assertTrue(len(tokens) > 0) + + # check if converting back to original text works + reverse_text = tokenizer.convert_tokens_to_string(tokens) + + if self.test_sentencepiece_ignore_case: + reverse_text = reverse_text.lower() + + expected_text = "this is text to test the tokenizer" + self.assertEqual(reverse_text, expected_text) + + special_tokens = tokenizer.all_special_tokens + special_tokens_string = tokenizer.convert_tokens_to_string(special_tokens) + for special_token in special_tokens: + self.assertIn(special_token, special_tokens_string) + + if self.test_rust_tokenizer: + rust_tokenizer = self.get_rust_tokenizer() + special_tokens_string_rust = rust_tokenizer.convert_tokens_to_string(special_tokens) + self.assertEqual(special_tokens_string, special_tokens_string_rust) + + @slow + def test_tokenizer_integration(self): + tokenizer = SiglipTokenizer.from_pretrained("google/siglip-base-patch16-224") + + # fmt: off + texts = [ + 'the real mountain view', + 'Zürich', + 'San Francisco', + 'a picture of a laptop with the lockscreen on, a cup of cappucino, salt and pepper grinders. 
The view through the window reveals lake Zürich and the Alps in the background of the city.', + ] + + expected_input_ids = [ + [260, 638, 3293, 870, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [262, 761, 5879, 5345, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [262, 264, 452, 20563, 15949, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [262, 266, 1357, 267, 262, 266, 4429, 275, 260, 3940, 6360, 277, 262, 266, 3064, 267, 3549, 388, 16538, 296, 298, 2617, 263, 4869, 14998, 264, 260, 870, 393, 260, 1710, 7958, 4324, 262, 761, 5879, 5345, 263, 260, 1518, 388, 264, 268, 260, 1970, 267, 260, 741, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + ] + # fmt: on + + for text, expected in zip(texts, expected_input_ids): + input_ids = tokenizer(text, padding="max_length").input_ids + self.assertListEqual(input_ids, expected) + + def test_some_edge_cases(self): + tokenizer = SiglipTokenizer.from_pretrained("google/siglip-base-patch16-224", legacy=False) + + sp_tokens = tokenizer.sp_model.encode(">", out_type=str) + self.assertEqual(sp_tokens, ["", ">"]) + tokens = tokenizer.tokenize(">") + self.assertNotEqual(sp_tokens, tokens) + self.assertEqual(tokens, [""]) + + tokens = tokenizer.tokenize("") + self.assertEqual(tokens, []) + self.assertEqual(tokens, tokenizer.sp_model.encode("", out_type=str)) + + tokens = tokenizer.tokenize(" ") + self.assertEqual(tokens, []) + self.assertEqual(tokens, tokenizer.sp_model.encode(" ", out_type=str)) + + tokens = tokenizer.tokenize("▁") + self.assertEqual(tokens, []) + self.assertEqual(tokens, tokenizer.sp_model.encode("▁", out_type=str)) + + tokens = tokenizer.tokenize(" ▁") + self.assertEqual(tokens, []) + self.assertEqual(tokens, tokenizer.sp_model.encode("▁", out_type=str)) + + +@require_sentencepiece +@require_tokenizers +class CommonSpmIntegrationTests(unittest.TestCase): + """ + A class that regroups important test to make sure that we properly handle the special tokens. + """ + + @classmethod + def setUpClass(cls): + tokenizer = SiglipTokenizer(SAMPLE_VOCAB, extra_ids=0, legacy=False) + tokenizer.add_special_tokens( + {"additional_special_tokens": [AddedToken("", rstrip=False, lstrip=False)]} + ) + cls.tokenizer = tokenizer + + def test_add_dummy_prefix(self): + # make sure `'▁'` is prepended, and outputs match sp_model's + # `sentencepiece.NormalizerSpec.add_dummy_prefix` attribute + input_ids = self.tokenizer.encode(". Hello", add_special_tokens=False) + self.assertEqual(input_ids, [37, 86, 20]) + self.assertEqual(input_ids, [37, 86, 20]) + tokens = self.tokenizer.tokenize(". 
Hello") + self.assertEqual(tokens, ["▁he", "ll", "o"]) + + tokens = self.tokenizer.tokenize("") + self.assertEqual(tokens, []) + self.assertEqual(tokens, self.tokenizer.sp_model.encode("", out_type=str)) + + tokens = self.tokenizer.tokenize(" ") + self.assertEqual(tokens, []) + self.assertEqual(tokens, self.tokenizer.sp_model.encode(" ", out_type=str)) + + tokens = self.tokenizer.tokenize("▁") + self.assertEqual(tokens, []) + self.assertEqual(tokens, self.tokenizer.sp_model.encode("▁", out_type=str)) + + def test_remove_extra_whitespaces(self): + # make sure the extra spaces are eaten + # sentencepiece.NormalizerSpec.remove_extra_whitespaces attribute + input_ids = self.tokenizer.encode(" . Hello", add_special_tokens=False) + self.assertEqual(input_ids, [37, 86, 20]) + self.assertEqual(input_ids, [37, 86, 20]) + tokens = self.tokenizer.tokenize(" . Hello") + self.assertEqual(tokens, ["▁he", "ll", "o"]) + + # `'▁'` is also a whitespace + input_ids = self.tokenizer.encode("▁He is not") + self.assertEqual(input_ids, [37, 46, 44, 2]) + tokens = self.tokenizer.tokenize("▁He is not") + self.assertEqual(tokens, ["▁he", "▁is", "▁not"]) # no extra space added + + input_ids = self.tokenizer.encode("▁He is not ▁He") + self.assertEqual(input_ids, [37, 46, 44, 37, 2]) + tokens = self.tokenizer.tokenize("▁He is not ▁He") + self.assertEqual(tokens, ["▁he", "▁is", "▁not", "▁he"]) # spaces are eaten by spm even if not start