diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 648877c8dce9..dfb64fcd0869 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -994,8 +994,11 @@ def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None: converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"* string, which can then be stored in the json format. """ - if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str): - d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1] + if d.get("torch_dtype", None) is not None: + if isinstance(d["torch_dtype"], dict): + d["torch_dtype"] = {k: str(v).split(".")[-1] for k, v in d["torch_dtype"].items()} + elif not isinstance(d["torch_dtype"], str): + d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1] for value in d.values(): if isinstance(value, dict): self.dict_torch_dtype_to_str(value) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 8eb2d7439ef3..c09c11050041 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1312,11 +1312,10 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs): "`PretrainedConfig`. To create a model from a pretrained model use " f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`" ) - # Save config and origin of the pretrained weights if given in model if not getattr(config, "_attn_implementation_autoset", False): - config = self._autoset_attn_implementation( - config, torch_dtype=torch.get_default_dtype(), check_device_map=False - ) + # config usually has a `torch_dtype` but we need the next line for the `no_super_init` tests + dtype = config.torch_dtype if hasattr(config, "torch_dtype") else torch.get_default_dtype() + config = self._autoset_attn_implementation(config, torch_dtype=dtype, check_device_map=False) self.config = config # for initialization of the loss @@ -1411,7 +1410,10 @@ def _from_config(cls, config, **kwargs): # when we init a model from within another model (e.g. VLMs) and dispatch on FA2 # a warning is raised that dtype should be fp16. 
Since we never pass dtype from within # modeling code, we can try to infer it here same way as done in `from_pretrained` - torch_dtype = kwargs.pop("torch_dtype", torch.get_default_dtype()) + torch_dtype = kwargs.pop("torch_dtype", config.torch_dtype) + if isinstance(torch_dtype, str): + torch_dtype = getattr(torch, torch_dtype) + use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False) # override default dtype if needed @@ -4020,11 +4022,37 @@ def from_pretrained( ) elif hasattr(torch, torch_dtype): torch_dtype = getattr(torch, torch_dtype) - else: - raise ValueError( - f'`torch_dtype` can be one of: `torch.dtype`, `"auto"` or a string of a valid `torch.dtype`, but received {torch_dtype}' - ) + for sub_config_key in config.sub_configs.keys(): + sub_config = getattr(config, sub_config_key) + sub_config.torch_dtype = torch_dtype + elif isinstance(torch_dtype, torch.dtype): + pass + elif isinstance(torch_dtype, dict): + for key, curr_dtype in torch_dtype.items(): + if hasattr(config, key): + value = getattr(config, key) + value.torch_dtype = curr_dtype + # main torch dtype for modules that aren't part of any sub-config + torch_dtype = torch_dtype.get("") + config.torch_dtype = torch_dtype + if isinstance(torch_dtype, str) and hasattr(torch, torch_dtype): + torch_dtype = getattr(torch, torch_dtype) + elif torch_dtype is None: + torch_dtype = torch.float32 + else: + raise ValueError( + f"`torch_dtype` can be one of: `torch.dtype`, `'auto'`, a string of a valid `torch.dtype` or a `dict` with valid `torch_dtype` " + f"for each sub-config in composite configs, but received {torch_dtype}" + ) + dtype_orig = cls._set_default_torch_dtype(torch_dtype) + else: + # set fp32 as the default dtype for BC + default_dtype = str(torch.get_default_dtype()).split(".")[-1] + config.torch_dtype = default_dtype + for key in config.sub_configs.keys(): + value = getattr(config, key) + value.torch_dtype = default_dtype # Check if `_keep_in_fp32_modules` is not None use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and ( diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index 90a02dd5bb9f..edbac91bb060 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -967,62 +967,6 @@ def forward(self, pixel_values: torch.LongTensor): return last_hidden_state -CHAMELEON_VQ_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`ChameleonVQVAEConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -@add_start_docstrings( - """The VQ-VAE model used in Chameleon for encoding/decoding images into discrete tokens. 
- This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from - [ Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv Taigman](https://arxiv.org/abs/2203.13131). - """, - CHAMELEON_VQ_START_DOCSTRING, -) -class ChameleonVQVAE(PreTrainedModel): - config_class = ChameleonVQVAEConfig - _no_split_modules = ["ChameleonVQVAEVectorQuantizer"] - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - elif isinstance(module, nn.GroupNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - elif isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - - def __init__(self, config: ChameleonVQVAEConfig): - super().__init__(config) - - self.encoder = ChameleonVQVAEEncoder(config) - self.quantize = ChameleonVQVAEVectorQuantizer(config) - self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1) - self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, config.latent_channels, 1) - self.eval() # Chameleon's VQ model is frozen - - def encode(self, pixel_values: torch.LongTensor): - hidden_states = self.encoder(pixel_values) - hidden_states = self.quant_conv(hidden_states) - quant, emb_loss, indices = self.quantize(hidden_states) - return quant, emb_loss, indices - - class ChameleonImageVocabularyMapping: """ A class for mapping discrete image tokens from VQGAN to BPE tokens. @@ -1118,6 +1062,62 @@ def _init_weights(self, module): module.weight.data[module.padding_idx].zero_() +CHAMELEON_VQ_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`ChameleonVQVAEConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + """The VQ-VAE model used in Chameleon for encoding/decoding images into discrete tokens. + This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from + [ Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv Taigman](https://arxiv.org/abs/2203.13131). 
+ """, + CHAMELEON_VQ_START_DOCSTRING, +) +class ChameleonVQVAE(ChameleonPreTrainedModel): + config_class = ChameleonVQVAEConfig + _no_split_modules = ["ChameleonVQVAEVectorQuantizer"] + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + elif isinstance(module, nn.GroupNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + + def __init__(self, config: ChameleonVQVAEConfig): + super().__init__(config) + + self.encoder = ChameleonVQVAEEncoder(config) + self.quantize = ChameleonVQVAEVectorQuantizer(config) + self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, config.latent_channels, 1) + self.eval() # Chameleon's VQ model is frozen + + def encode(self, pixel_values: torch.LongTensor): + hidden_states = self.encoder(pixel_values) + hidden_states = self.quant_conv(hidden_states) + quant, emb_loss, indices = self.quantize(hidden_states) + return quant, emb_loss, indices + + CHAMELEON_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): @@ -1211,7 +1211,7 @@ def __init__(self, config: ChameleonConfig): [decoder_layer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.vqmodel = ChameleonVQVAE(config.vq_config) + self.vqmodel = ChameleonVQVAE._from_config(config.vq_config) self.gradient_checkpointing = False # Initialize weights and apply final processing diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 08d9eddd9e2f..c81fbbd013a8 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -727,7 +727,7 @@ def __init__(self, config): super().__init__(config) self.model = PhiModel(config) self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True) # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/phi/modular_phi.py b/src/transformers/models/phi/modular_phi.py index 0faa4629f1a7..d8480a7ad61a 100644 --- a/src/transformers/models/phi/modular_phi.py +++ b/src/transformers/models/phi/modular_phi.py @@ -284,7 +284,9 @@ def forward( class PhiForCausalLM(LlamaForCausalLM): - pass + def __init__(self, config): + super().__init__(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True) class PhiForSequenceClassification(LlamaForSequenceClassification): diff --git a/tests/fsdp/test_fsdp.py b/tests/fsdp/test_fsdp.py index 74a3bfe04b75..f5af373f49bc 100644 --- a/tests/fsdp/test_fsdp.py +++ b/tests/fsdp/test_fsdp.py @@ -32,7 +32,6 @@ require_accelerate, require_fsdp, require_torch_accelerator, - require_torch_gpu, require_torch_multi_accelerator, slow, torch_device, @@ -288,7 +287,7 @@ def test_training_and_can_resume_normally(self, state_dict_type): @require_torch_multi_accelerator @slow - @require_torch_gpu + @require_torch_accelerator @require_fsdp def test_fsdp_cpu_offloading(self): try: diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 
510f3fe1a926..7499a5599b7c 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -33,6 +33,7 @@ require_flash_attn, require_optimum_quanto, require_torch, + require_torch_accelerator, require_torch_gpu, require_torch_multi_accelerator, require_torch_multi_gpu, @@ -2042,16 +2043,10 @@ def test_generate_with_quant_cache(self): with self.assertRaises(ValueError): model.generate(**generation_kwargs, **inputs_dict) - @parameterized.expand( - [ - ("forward_only", False), # TODO (@joao): a few models failing. After fixed, this should not be "@slow" - ("end_to_end", True), # TODO (@joao): end-to-end compilation is broken with torch 2.5+, explore and fix - ] - ) @pytest.mark.generate - @require_torch_gpu + @require_torch_accelerator @slow - def test_generate_compile(self, _, end_to_end): + def test_generate_compile_model_forward(self): """ Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results. Tests end-to-end compilation and forward pass compilation only. @@ -2061,14 +2056,7 @@ def test_generate_compile(self, _, end_to_end): if not model_class._supports_static_cache: self.skipTest("This model doesn't support static cache") - # TODO (joao) -- fix and enable me :) - if end_to_end and any(model_name in model_class.__name__.lower() for model_name in ["whisper"]): - self.skipTest("whisper model end-to-end generate compile not yet supported") - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - # TODO (joao) -- fix and enable me :) - if end_to_end and config.is_encoder_decoder: - self.skipTest("Encoder-decoder model end-to-end generate compile not yet supported") model = model_class(config).to(torch_device) model.eval() # otherwise `self.training` is `True` -- this flag is used at attn mask creation time @@ -2084,10 +2072,8 @@ def test_generate_compile(self, _, end_to_end): "max_new_tokens": 10, "return_dict_in_generate": True, "output_scores": True, + "cache_implementation": "static", } - # end-to-end works best with dynamic cache, forward compilation works best with static cache - if not end_to_end: - generation_kwargs["cache_implementation"] = "static" # get eager + dynamic cache results for future comparison dynamic_outputs = [] @@ -2098,10 +2084,8 @@ def test_generate_compile(self, _, end_to_end): generation_config = copy.deepcopy(model.generation_config) generation_config.update(**generation_kwargs) torch.compiler.reset() - if end_to_end: - model.generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead") - else: - model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead") + + model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead") compiled_outputs = [] for model_inputs in input_ids_sets: @@ -3808,10 +3792,12 @@ def test_assisted_decoding_in_different_gpu(self): self.assertTrue(input_length <= out.shape[-1] <= input_length + 20) @slow - @require_torch_gpu + @require_torch_accelerator def test_assisted_decoding_model_in_gpu_assistant_in_cpu(self): # PT-only test: TF doesn't support assisted decoding yet. 
- model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda") + model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to( + torch_device + ) assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to( "cpu" ) diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py index a1ea708efd66..5e18b006a5d8 100644 --- a/tests/models/blip_2/test_modeling_blip_2.py +++ b/tests/models/blip_2/test_modeling_blip_2.py @@ -27,6 +27,7 @@ from transformers import CONFIG_MAPPING, Blip2Config, Blip2QFormerConfig, Blip2VisionConfig from transformers.testing_utils import ( require_torch, + require_torch_accelerator, require_torch_fp16, require_torch_gpu, require_torch_multi_accelerator, @@ -1565,7 +1566,7 @@ def test_forward_signature(self): self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) @slow - @require_torch_gpu + @require_torch_accelerator def test_model_from_pretrained(self): model_name = "Salesforce/blip2-itm-vit-g" model = Blip2TextModelWithProjection.from_pretrained(model_name) @@ -2191,7 +2192,7 @@ def test_expansion_in_processing(self): self.assertTrue(generated_text_expanded == generated_text) - @require_torch_gpu + @require_torch_accelerator def test_inference_itm(self): model_name = "Salesforce/blip2-itm-vit-g" processor = Blip2Processor.from_pretrained(model_name) @@ -2210,7 +2211,7 @@ def test_inference_itm(self): self.assertTrue(torch.allclose(torch.nn.Softmax()(out_itm[0].cpu()), expected_scores, rtol=1e-3, atol=1e-3)) self.assertTrue(torch.allclose(out[0].cpu(), torch.Tensor([[0.4406]]), rtol=1e-3, atol=1e-3)) - @require_torch_gpu + @require_torch_accelerator @require_torch_fp16 def test_inference_itm_fp16(self): model_name = "Salesforce/blip2-itm-vit-g" @@ -2232,7 +2233,7 @@ def test_inference_itm_fp16(self): ) self.assertTrue(torch.allclose(out[0].cpu().float(), torch.Tensor([[0.4406]]), rtol=1e-3, atol=1e-3)) - @require_torch_gpu + @require_torch_accelerator @require_torch_fp16 def test_inference_vision_with_projection_fp16(self): model_name = "Salesforce/blip2-itm-vit-g" @@ -2256,7 +2257,7 @@ def test_inference_vision_with_projection_fp16(self): ] self.assertTrue(np.allclose(out.image_embeds[0][0][:6].tolist(), expected_image_embeds, atol=1e-3)) - @require_torch_gpu + @require_torch_accelerator @require_torch_fp16 def test_inference_text_with_projection_fp16(self): model_name = "Salesforce/blip2-itm-vit-g" diff --git a/tests/models/chameleon/test_modeling_chameleon.py b/tests/models/chameleon/test_modeling_chameleon.py index bb2ba8b34281..01d4ef720e57 100644 --- a/tests/models/chameleon/test_modeling_chameleon.py +++ b/tests/models/chameleon/test_modeling_chameleon.py @@ -333,7 +333,7 @@ def test_batching_equivalence(self): # TODO (joao, raushan): fix me -- the problem is in `cache_position[0] == 0`, i.e. 
dynamic control flow @unittest.skip("Chameleon is not compatible with end-to-end generation compilation") - def test_generate_compile_fullgraph(self): + def test_generate_compile_model_forward(self): pass diff --git a/tests/models/dbrx/test_modeling_dbrx.py b/tests/models/dbrx/test_modeling_dbrx.py index d38a479ab36e..dee93109da24 100644 --- a/tests/models/dbrx/test_modeling_dbrx.py +++ b/tests/models/dbrx/test_modeling_dbrx.py @@ -369,7 +369,7 @@ def test_disk_offload_bin(self): pass @unittest.skip("Dbrx does not support `torch.compile` with `fullgraph=True`.") - def test_generate_compile_fullgraph(self): + def test_generate_compile_model_forward(self): pass diff --git a/tests/models/diffllama/test_modeling_diffllama.py b/tests/models/diffllama/test_modeling_diffllama.py index 9e2f71174865..64dfb5b64955 100644 --- a/tests/models/diffllama/test_modeling_diffllama.py +++ b/tests/models/diffllama/test_modeling_diffllama.py @@ -676,7 +676,7 @@ def test_eager_matches_sdpa_generate(self): ) -@require_torch_gpu +@require_torch_accelerator class DiffLlamaIntegrationTest(unittest.TestCase): # This variable is used to determine which CUDA device are we using for our runners (A10 or T4) # Depending on the hardware we get different logits / generations @@ -689,7 +689,7 @@ def setUpClass(cls): cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0] @slow - @require_torch_gpu + @require_torch_accelerator @require_read_token def test_compile_static_cache(self): # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2 diff --git a/tests/models/emu3/test_modeling_emu3.py b/tests/models/emu3/test_modeling_emu3.py index d1c4501c5e8b..007207c06942 100644 --- a/tests/models/emu3/test_modeling_emu3.py +++ b/tests/models/emu3/test_modeling_emu3.py @@ -176,10 +176,6 @@ def test_model_rope_scaling(self, scaling_type): def test_custom_4d_attention_mask(self): pass - @unittest.skip("Fails with unknown error only on end-to-end compile") # TODO raushan fixme - def test_generate_compile_1_end_to_end(self): - pass - class Emu3Vision2TextModelTester: def __init__( @@ -398,10 +394,6 @@ def test_custom_4d_attention_mask(self): def test_initialization(self): pass - @unittest.skip("End-to-end compilation is not supported due to dynamic control in `prepare_inputs_for_generation`") - def test_generate_compile_1_end_to_end(self): - pass - @require_torch class Emu3IntegrationTest(unittest.TestCase): diff --git a/tests/models/falcon_mamba/test_modeling_falcon_mamba.py b/tests/models/falcon_mamba/test_modeling_falcon_mamba.py index f02e8f167636..eb1205db9cc1 100644 --- a/tests/models/falcon_mamba/test_modeling_falcon_mamba.py +++ b/tests/models/falcon_mamba/test_modeling_falcon_mamba.py @@ -23,7 +23,7 @@ from transformers.testing_utils import ( require_bitsandbytes, require_torch, - require_torch_gpu, + require_torch_accelerator, require_torch_multi_gpu, slow, torch_device, @@ -426,7 +426,7 @@ def recursive_check(tuple_object, dict_object): @require_torch -@require_torch_gpu +@require_torch_accelerator @slow class FalconMambaIntegrationTests(unittest.TestCase): def setUp(self): diff --git a/tests/models/fuyu/test_modeling_fuyu.py b/tests/models/fuyu/test_modeling_fuyu.py index bcac135be721..0444ad14f269 100644 --- a/tests/models/fuyu/test_modeling_fuyu.py +++ b/tests/models/fuyu/test_modeling_fuyu.py @@ -22,7 +22,7 @@ from parameterized import parameterized from transformers import FuyuConfig, is_torch_available, is_vision_available -from 
transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device +from transformers.testing_utils import require_torch, require_torch_accelerator, slow, torch_device from transformers.utils import cached_property from ...generation.test_utils import GenerationTesterMixin @@ -327,7 +327,7 @@ def test_model_parallelism(self): @slow -@require_torch_gpu +@require_torch_accelerator class FuyuModelIntegrationTest(unittest.TestCase): @cached_property def default_processor(self): diff --git a/tests/models/idefics/test_modeling_idefics.py b/tests/models/idefics/test_modeling_idefics.py index 94229b13d2cb..a8f1304b6fc7 100644 --- a/tests/models/idefics/test_modeling_idefics.py +++ b/tests/models/idefics/test_modeling_idefics.py @@ -781,7 +781,7 @@ def test_custom_4d_attention_mask(self): pass @unittest.skip(reason="IDEFICS cannot compile due to dynamic control flow when checking inputs") - def test_generate_compile_fullgraph(self): + def test_generate_compile_model_forward(self): pass @unittest.skip(reason="We only test the model that takes in multiple images") diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index feca640bb4a1..664616306d88 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -26,7 +26,6 @@ require_read_token, require_torch, require_torch_accelerator, - require_torch_gpu, slow, torch_device, ) @@ -541,7 +540,7 @@ def _reinitialize_config(base_config, new_kwargs): config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear"}}) # missing "factor" -@require_torch_gpu +@require_torch_accelerator class LlamaIntegrationTest(unittest.TestCase): # This variable is used to determine which CUDA device are we using for our runners (A10 or T4) # Depending on the hardware we get different logits / generations @@ -695,7 +694,7 @@ def test_model_7b_dola_generation(self): self.assertEqual(EXPECTED_TEXT_COMPLETION, text) @slow - @require_torch_gpu + @require_torch_accelerator @require_read_token def test_compile_static_cache(self): # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2 diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index d9e6b9d7bfe7..70de4d9cf1ed 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -424,7 +424,7 @@ def test_flash_attn_2_inference_equivalence_right_padding(self): self.skipTest(reason="Mistral flash attention does not support right padding") -@require_torch_gpu +@require_torch_accelerator class MistralIntegrationTest(unittest.TestCase): # This variable is used to determine which CUDA device are we using for our runners (A10 or T4) # Depending on the hardware we get different logits / generations diff --git a/tests/models/mixtral/test_modeling_mixtral.py b/tests/models/mixtral/test_modeling_mixtral.py index 9abbf444d0b0..cf192b8bd79e 100644 --- a/tests/models/mixtral/test_modeling_mixtral.py +++ b/tests/models/mixtral/test_modeling_mixtral.py @@ -22,6 +22,7 @@ from transformers.testing_utils import ( require_flash_attn, require_torch, + require_torch_accelerator, require_torch_gpu, slow, torch_device, @@ -471,7 +472,7 @@ def setUpClass(cls): cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0] @slow - @require_torch_gpu + @require_torch_accelerator def test_small_model_logits(self): model_id = 
"hf-internal-testing/Mixtral-tiny" dummy_input = torch.LongTensor([[0, 1, 0], [0, 1, 0]]).to(torch_device) @@ -507,7 +508,7 @@ def test_small_model_logits(self): ) @slow - @require_torch_gpu + @require_torch_accelerator def test_small_model_logits_batched(self): model_id = "hf-internal-testing/Mixtral-tiny" dummy_input = torch.LongTensor([[0, 0, 0, 0, 0, 0, 1, 2, 3], [1, 1, 2, 3, 4, 5, 6, 7, 8]]).to(torch_device) diff --git a/tests/models/nemotron/test_modeling_nemotron.py b/tests/models/nemotron/test_modeling_nemotron.py index fd62c74d3d6e..249706c1c470 100644 --- a/tests/models/nemotron/test_modeling_nemotron.py +++ b/tests/models/nemotron/test_modeling_nemotron.py @@ -26,6 +26,7 @@ require_flash_attn, require_read_token, require_torch, + require_torch_accelerator, require_torch_gpu, require_torch_sdpa, slow, @@ -103,7 +104,7 @@ def test_model_outputs_equivalence(self, **kwargs): pass @require_torch_sdpa - @require_torch_gpu + @require_torch_accelerator @slow def test_sdpa_equivalence(self): for model_class in self.all_model_classes: diff --git a/tests/models/omdet_turbo/test_modeling_omdet_turbo.py b/tests/models/omdet_turbo/test_modeling_omdet_turbo.py index ed85c4c00078..75c0e6f1c78d 100644 --- a/tests/models/omdet_turbo/test_modeling_omdet_turbo.py +++ b/tests/models/omdet_turbo/test_modeling_omdet_turbo.py @@ -26,7 +26,7 @@ from transformers.testing_utils import ( require_timm, require_torch, - require_torch_gpu, + require_torch_accelerator, require_vision, slow, torch_device, @@ -865,7 +865,7 @@ def test_inference_object_detection_head_batched(self): ] self.assertListEqual([result["classes"] for result in results], expected_classes) - @require_torch_gpu + @require_torch_accelerator def test_inference_object_detection_head_equivalence_cpu_gpu(self): processor = self.default_processor image = prepare_img() @@ -878,8 +878,8 @@ def test_inference_object_detection_head_equivalence_cpu_gpu(self): cpu_outputs = model(**encoding) # 2. run model on GPU - model.to("cuda") - encoding = encoding.to("cuda") + model.to(torch_device) + encoding = encoding.to(torch_device) with torch.no_grad(): gpu_outputs = model(**encoding) diff --git a/tests/models/paligemma/test_modeling_paligemma.py b/tests/models/paligemma/test_modeling_paligemma.py index f973e1211dc0..587d4606493b 100644 --- a/tests/models/paligemma/test_modeling_paligemma.py +++ b/tests/models/paligemma/test_modeling_paligemma.py @@ -348,7 +348,7 @@ def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): # TODO (joao, raushan): fix me -- the problem is in `cache_position[0] == 0`, i.e. 
dynamic control flow @unittest.skip("PaliGemma is not compatible with end-to-end generation compilation") - def test_generate_compile_fullgraph(self): + def test_generate_compile_model_forward(self): pass diff --git a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py index aedd37992632..fc9adcebf5f7 100644 --- a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py +++ b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py @@ -227,6 +227,7 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas pipeline_model_mapping = {"image-text-to-text": Qwen2VLForConditionalGeneration} test_pruning = False test_head_masking = False + _is_composite = True def setUp(self): self.model_tester = Qwen2VLVisionText2TextModelTester(self) @@ -332,7 +333,7 @@ def test_generate_from_inputs_embeds_with_static_cache(self): pass @unittest.skip(reason="Can't compile fullgraph due to dynamic control flow in `prepare_inputs_for_generate`") - def test_generate_compile_fullgraph(self): + def test_generate_compile_model_forward(self): pass diff --git a/tests/models/rt_detr/test_modeling_rt_detr.py b/tests/models/rt_detr/test_modeling_rt_detr.py index 65a417fe56f6..368e2dd140f3 100644 --- a/tests/models/rt_detr/test_modeling_rt_detr.py +++ b/tests/models/rt_detr/test_modeling_rt_detr.py @@ -28,7 +28,13 @@ is_torch_available, is_vision_available, ) -from transformers.testing_utils import require_torch, require_torch_gpu, require_vision, slow, torch_device +from transformers.testing_utils import ( + require_torch, + require_torch_accelerator, + require_vision, + slow, + torch_device, +) from transformers.utils import cached_property from ...test_configuration_common import ConfigTester @@ -631,7 +637,7 @@ def test_initialization(self): self.assertTrue(not failed_cases, message) @parameterized.expand(["float32", "float16", "bfloat16"]) - @require_torch_gpu + @require_torch_accelerator @slow def test_inference_with_different_dtypes(self, torch_dtype_str): torch_dtype = { @@ -653,7 +659,7 @@ def test_inference_with_different_dtypes(self, torch_dtype_str): _ = model(**self._prepare_for_class(inputs_dict, model_class)) @parameterized.expand(["float32", "float16", "bfloat16"]) - @require_torch_gpu + @require_torch_accelerator @slow def test_inference_equivalence_for_static_and_dynamic_anchors(self, torch_dtype_str): torch_dtype = { diff --git a/tests/models/starcoder2/test_modeling_starcoder2.py b/tests/models/starcoder2/test_modeling_starcoder2.py index df743f132c11..d6993469e043 100644 --- a/tests/models/starcoder2/test_modeling_starcoder2.py +++ b/tests/models/starcoder2/test_modeling_starcoder2.py @@ -23,6 +23,7 @@ require_bitsandbytes, require_flash_attn, require_torch, + require_torch_accelerator, require_torch_gpu, slow, torch_device, @@ -412,7 +413,7 @@ def test_flash_attn_2_inference_equivalence_right_padding(self): @slow -@require_torch_gpu +@require_torch_accelerator class Starcoder2IntegrationTest(unittest.TestCase): def test_starcoder2_batched_generation_sdpa(self): EXPECTED_TEXT = [ diff --git a/tests/models/t5/test_modeling_t5.py b/tests/models/t5/test_modeling_t5.py index b03416390766..52fec78d1e89 100644 --- a/tests/models/t5/test_modeling_t5.py +++ b/tests/models/t5/test_modeling_t5.py @@ -27,7 +27,7 @@ require_sentencepiece, require_tokenizers, require_torch, - require_torch_gpu, + require_torch_accelerator, slow, torch_device, ) @@ -1646,7 +1646,7 @@ def test_contrastive_search_t5(self): ) @slow - @require_torch_gpu + @require_torch_accelerator 
def test_compile_static_cache(self): NUM_TOKENS_TO_GENERATE = 40 EXPECTED_TEXT_COMPLETION = [ @@ -1686,7 +1686,7 @@ def test_compile_static_cache(self): self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text) @slow - @require_torch_gpu + @require_torch_accelerator def test_compile_static_cache_encoder(self): prompts = [ "summarize: Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial " diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py index d5014586b331..7504ae009d05 100644 --- a/tests/pipelines/test_pipelines_text_generation.py +++ b/tests/pipelines/test_pipelines_text_generation.py @@ -28,7 +28,6 @@ require_tf, require_torch, require_torch_accelerator, - require_torch_gpu, require_torch_or_tf, torch_device, ) @@ -553,7 +552,7 @@ def run_pipeline_test(self, text_generator, _): @require_torch @require_accelerate - @require_torch_gpu + @require_torch_accelerator def test_small_model_pt_bloom_accelerate(self): import torch diff --git a/tests/quantization/quanto_integration/test_quanto.py b/tests/quantization/quanto_integration/test_quanto.py index 08cc48d0cccd..2022c3366576 100644 --- a/tests/quantization/quanto_integration/test_quanto.py +++ b/tests/quantization/quanto_integration/test_quanto.py @@ -21,6 +21,7 @@ require_accelerate, require_optimum_quanto, require_read_token, + require_torch_accelerator, require_torch_gpu, slow, torch_device, @@ -123,7 +124,7 @@ def test_conversion_with_modules_to_not_convert(self): @slow -@require_torch_gpu +@require_torch_accelerator @require_optimum_quanto @require_accelerate class QuantoQuantizationTest(unittest.TestCase): @@ -268,7 +269,7 @@ def test_compare_with_quanto(self): quantize(model.transformer, weights=w_mapping[self.weights]) freeze(model.transformer) self.check_same_model(model, self.quantized_model) - self.check_inference_correctness(model, device="cuda") + self.check_inference_correctness(model, device=torch_device) @unittest.skip def test_load_from_quanto_saved(self): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 6a9b8523f9e4..0d12bf77d861 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -1862,7 +1862,6 @@ def test_resize_position_vector_embeddings(self): def test_resize_tokens_embeddings(self): if not self.test_resize_embeddings: self.skipTest(reason="test_resize_embeddings is set to `False`") - ( original_config, inputs_dict, @@ -2017,7 +2016,7 @@ def test_resize_tokens_embeddings(self): torch.testing.assert_close(old_embeddings_mean, new_embeddings_mean, atol=1e-3, rtol=1e-1) @require_deepspeed - @require_torch_gpu + @require_torch_accelerator def test_resize_tokens_embeddings_with_deepspeed(self): ds_config = { "zero_optimization": { @@ -2123,7 +2122,7 @@ def test_resize_embeddings_untied(self): model(**self._prepare_for_class(inputs_dict, model_class)) @require_deepspeed - @require_torch_gpu + @require_torch_accelerator def test_resize_embeddings_untied_with_deepspeed(self): ds_config = { "zero_optimization": { @@ -3202,7 +3201,7 @@ def check_device_map_is_respected(self, model, device_map): @require_accelerate @mark.accelerate_tests - @require_torch_gpu + @require_torch_accelerator def test_disk_offload_bin(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -3243,7 +3242,7 @@ def test_disk_offload_bin(self): @require_accelerate @mark.accelerate_tests - @require_torch_gpu + @require_torch_accelerator def 
test_disk_offload_safetensors(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -3278,7 +3277,7 @@ def test_disk_offload_safetensors(self): @require_accelerate @mark.accelerate_tests - @require_torch_gpu + @require_torch_accelerator def test_cpu_offload(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -4746,7 +4745,7 @@ def test_custom_4d_attention_mask(self): torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4) @slow - @require_torch_gpu + @require_torch_accelerator def test_torch_compile_for_training(self): if version.parse(torch.__version__) < version.parse("2.3"): self.skipTest(reason="This test requires torch >= 2.3 to run.") diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index d89c4aa80302..6e90b3d7e405 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1831,7 +1831,7 @@ def test_adalomo(self): _ = trainer.train() @require_grokadamw - @require_torch_gpu + @require_torch_accelerator def test_grokadamw(self): config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4) tiny_llama = LlamaForCausalLM(config) @@ -1852,7 +1852,7 @@ def test_grokadamw(self): _ = trainer.train() @require_schedulefree - @require_torch_gpu + @require_torch_accelerator def test_schedulefree_adam(self): config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4) tiny_llama = LlamaForCausalLM(config) diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 383f0cbe60e1..b8e10ff8ad4d 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -37,6 +37,7 @@ AutoModel, AutoModelForImageClassification, AutoModelForSequenceClassification, + LlavaForConditionalGeneration, OwlViTForObjectDetection, PretrainedConfig, is_torch_available, @@ -300,6 +301,7 @@ def test_local_files_only(self): TINY_BERT_FOR_TOKEN_CLASSIFICATION = "hf-internal-testing/tiny-bert-for-token-classification" TINY_MISTRAL = "hf-internal-testing/tiny-random-MistralForCausalLM" TINY_IMAGE_CLASSIF = "hf-internal-testing/tiny-random-SiglipForImageClassification" +TINY_LLAVA = "hf-internal-testing/tiny-random-LlavaForConditionalGeneration" LOG = logging.get_logger(__name__) @@ -460,6 +462,59 @@ def test_model_from_config_torch_dtype_str(self): with self.assertRaises(ValueError): model = AutoModel.from_pretrained(TINY_T5, torch_dtype="int64") + def test_model_from_config_torch_dtype_composite(self): + """ + Test that from_pretrained works with torch_dtype being as a dict per each sub-config in composite config + """ + # should be able to set torch_dtype as a simple string and the model loads it correctly + model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="float32") + self.assertEqual(model.language_model.dtype, torch.float32) + self.assertEqual(model.vision_tower.dtype, torch.float32) + + model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="float16") + self.assertEqual(model.language_model.dtype, torch.float16) + self.assertEqual(model.vision_tower.dtype, torch.float16) + + # should be able to set torch_dtype as a dict for each sub-config + model = LlavaForConditionalGeneration.from_pretrained( + TINY_LLAVA, torch_dtype={"text_config": "float32", "vision_config": "float16", "": "bfloat16"} + ) + self.assertEqual(model.language_model.dtype, torch.float32) + self.assertEqual(model.vision_tower.dtype, 
torch.float16) + self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16) + + # should be able to set the values as torch.dtype (not str) + model = LlavaForConditionalGeneration.from_pretrained( + TINY_LLAVA, torch_dtype={"text_config": torch.float32, "vision_config": torch.float16, "": torch.bfloat16} + ) + self.assertEqual(model.language_model.dtype, torch.float32) + self.assertEqual(model.vision_tower.dtype, torch.float16) + self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16) + + # should be able to set the values in configs directly and pass it to `from_pretrained` + config = copy.deepcopy(model.config) + config.text_config.torch_dtype = torch.float32 + config.vision_config.torch_dtype = torch.bfloat16 + config.torch_dtype = torch.float16 + model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, torch_dtype="auto") + self.assertEqual(model.language_model.dtype, torch.float32) + self.assertEqual(model.vision_tower.dtype, torch.bfloat16) + self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float16) + + # but if the model has `_keep_in_fp32_modules` then those modules should be in fp32 no matter what + LlavaForConditionalGeneration._keep_in_fp32_modules = ["multi_modal_projector"] + model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, torch_dtype="auto") + self.assertEqual(model.language_model.dtype, torch.float32) + self.assertEqual(model.vision_tower.dtype, torch.bfloat16) + self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float32) + + # torch.set_default_dtype() supports only float dtypes, so will fail with non-float type + with self.assertRaises(ValueError): + model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="int64") + model = LlavaForConditionalGeneration.from_pretrained( + TINY_LLAVA, torch_dtype={"text_config": "float32", "vision_config": "int64", "": "float16"} + ) + @require_torch def test_model_from_pretrained_meta_device(self): def is_on_meta(model_id, dtype):
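Usage sketch (not part of the diff): the snippet below restates what the new `test_model_from_config_torch_dtype_composite` test above asserts about the per-sub-config `torch_dtype` handling added in `modeling_utils.py`; it uses the same tiny Llava checkpoint as the test, and only the `repo` variable name is introduced here for readability.

    import torch
    from transformers import LlavaForConditionalGeneration

    repo = "hf-internal-testing/tiny-random-LlavaForConditionalGeneration"

    # A single string (or torch.dtype) still applies to every sub-model, as before.
    model = LlavaForConditionalGeneration.from_pretrained(repo, torch_dtype="float16")
    assert model.language_model.dtype == torch.float16
    assert model.vision_tower.dtype == torch.float16

    # A dict maps sub-config names to dtypes; the "" key covers modules that do not
    # belong to any sub-config (here, the multi-modal projector).
    model = LlavaForConditionalGeneration.from_pretrained(
        repo,
        torch_dtype={"text_config": torch.float32, "vision_config": torch.float16, "": torch.bfloat16},
    )
    assert model.language_model.dtype == torch.float32
    assert model.vision_tower.dtype == torch.float16
    assert model.multi_modal_projector.linear_1.weight.dtype == torch.bfloat16

Per the test, modules listed in `_keep_in_fp32_modules` stay in fp32 regardless of the requested dtypes, and non-float entries (e.g. "int64") raise a `ValueError`.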