
Commit

Merge branch 'main' into sam-rle-bug
MSt-10 authored Jan 13, 2025
2 parents 53ddcab + 2fa876d commit bbed9cf
Showing 30 changed files with 225 additions and 151 deletions.
7 changes: 5 additions & 2 deletions src/transformers/configuration_utils.py
@@ -994,8 +994,11 @@ def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None:
converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"*
string, which can then be stored in the json format.
"""
if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str):
d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
if d.get("torch_dtype", None) is not None:
if isinstance(d["torch_dtype"], dict):
d["torch_dtype"] = {k: str(v).split(".")[-1] for k, v in d["torch_dtype"].items()}
elif not isinstance(d["torch_dtype"], str):
d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
for value in d.values():
if isinstance(value, dict):
self.dict_torch_dtype_to_str(value)
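
The hunk above teaches `dict_torch_dtype_to_str` to also handle a `torch_dtype` given as a dict, i.e. one dtype per sub-config of a composite model. A minimal standalone sketch of the same logic (the config dict below is made up for illustration):

```python
import torch
from typing import Any, Dict

def dict_torch_dtype_to_str(d: Dict[str, Any]) -> None:
    # Mirrors the change above: a dict-valued "torch_dtype" (composite configs) is
    # converted entry by entry, a plain torch.dtype becomes its short name, and
    # nested dicts are walked recursively.
    if d.get("torch_dtype", None) is not None:
        if isinstance(d["torch_dtype"], dict):
            d["torch_dtype"] = {k: str(v).split(".")[-1] for k, v in d["torch_dtype"].items()}
        elif not isinstance(d["torch_dtype"], str):
            d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
    for value in d.values():
        if isinstance(value, dict):
            dict_torch_dtype_to_str(value)

cfg = {"torch_dtype": {"text_config": torch.float16, "vision_config": torch.bfloat16}}
dict_torch_dtype_to_str(cfg)
print(cfg)  # {'torch_dtype': {'text_config': 'float16', 'vision_config': 'bfloat16'}}
```
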
46 changes: 37 additions & 9 deletions src/transformers/modeling_utils.py
@@ -1312,11 +1312,10 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
"`PretrainedConfig`. To create a model from a pretrained model use "
f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
# Save config and origin of the pretrained weights if given in model
if not getattr(config, "_attn_implementation_autoset", False):
config = self._autoset_attn_implementation(
config, torch_dtype=torch.get_default_dtype(), check_device_map=False
)
# config usually has a `torch_dtype` but we need the next line for the `no_super_init` tests
dtype = config.torch_dtype if hasattr(config, "torch_dtype") else torch.get_default_dtype()
config = self._autoset_attn_implementation(config, torch_dtype=dtype, check_device_map=False)
self.config = config

# for initialization of the loss
@@ -1411,7 +1410,10 @@ def _from_config(cls, config, **kwargs):
# when we init a model from within another model (e.g. VLMs) and dispatch on FA2
# a warning is raised that dtype should be fp16. Since we never pass dtype from within
# modeling code, we can try to infer it here same way as done in `from_pretrained`
torch_dtype = kwargs.pop("torch_dtype", torch.get_default_dtype())
torch_dtype = kwargs.pop("torch_dtype", config.torch_dtype)
if isinstance(torch_dtype, str):
torch_dtype = getattr(torch, torch_dtype)

use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)

# override default dtype if needed
@@ -4020,11 +4022,37 @@ def from_pretrained(
)
elif hasattr(torch, torch_dtype):
torch_dtype = getattr(torch, torch_dtype)
else:
raise ValueError(
f'`torch_dtype` can be one of: `torch.dtype`, `"auto"` or a string of a valid `torch.dtype`, but received {torch_dtype}'
)
for sub_config_key in config.sub_configs.keys():
sub_config = getattr(config, sub_config_key)
sub_config.torch_dtype = torch_dtype
elif isinstance(torch_dtype, torch.dtype):
pass
elif isinstance(torch_dtype, dict):
for key, curr_dtype in torch_dtype.items():
if hasattr(config, key):
value = getattr(config, key)
value.torch_dtype = curr_dtype
# main torch dtype for modules that aren't part of any sub-config
torch_dtype = torch_dtype.get("")
config.torch_dtype = torch_dtype
if isinstance(torch_dtype, str) and hasattr(torch, torch_dtype):
torch_dtype = getattr(torch, torch_dtype)
elif torch_dtype is None:
torch_dtype = torch.float32
else:
raise ValueError(
f"`torch_dtype` can be one of: `torch.dtype`, `'auto'`, a string of a valid `torch.dtype` or a `dict` with valid `torch_dtype` "
f"for each sub-config in composite configs, but received {torch_dtype}"
)

dtype_orig = cls._set_default_torch_dtype(torch_dtype)
else:
# set fp32 as the default dtype for BC
default_dtype = str(torch.get_default_dtype()).split(".")[-1]
config.torch_dtype = default_dtype
for key in config.sub_configs.keys():
value = getattr(config, key)
value.torch_dtype = default_dtype

# Check if `_keep_in_fp32_modules` is not None
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
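
The `from_pretrained` hunk above adds support for `torch_dtype` passed as a dict on composite models: each key names a sub-config attribute of the config, and the empty-string key sets the dtype for modules that do not belong to any sub-config (a missing entry falls back to float32). A hedged usage sketch; the checkpoint and the exact sub-config keys are illustrative, not taken from this diff:

```python
import torch
from transformers import LlavaForConditionalGeneration  # any composite (multi-sub-config) model

model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",            # example composite checkpoint
    torch_dtype={
        "text_config": torch.bfloat16,     # dtype for the language sub-model
        "vision_config": torch.float16,    # dtype for the vision tower
        "": torch.float32,                 # dtype for modules outside any sub-config
    },
)
```
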
114 changes: 57 additions & 57 deletions src/transformers/models/chameleon/modeling_chameleon.py
@@ -967,62 +967,6 @@ def forward(self, pixel_values: torch.LongTensor):
return last_hidden_state


CHAMELEON_VQ_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`ChameleonVQVAEConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings(
"""The VQ-VAE model used in Chameleon for encoding/decoding images into discrete tokens.
This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from
[ Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv Taigman](https://arxiv.org/abs/2203.13131).
""",
CHAMELEON_VQ_START_DOCSTRING,
)
class ChameleonVQVAE(PreTrainedModel):
config_class = ChameleonVQVAEConfig
_no_split_modules = ["ChameleonVQVAEVectorQuantizer"]

def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
elif isinstance(module, nn.GroupNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()

def __init__(self, config: ChameleonVQVAEConfig):
super().__init__(config)

self.encoder = ChameleonVQVAEEncoder(config)
self.quantize = ChameleonVQVAEVectorQuantizer(config)
self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, config.latent_channels, 1)
self.eval() # Chameleon's VQ model is frozen

def encode(self, pixel_values: torch.LongTensor):
hidden_states = self.encoder(pixel_values)
hidden_states = self.quant_conv(hidden_states)
quant, emb_loss, indices = self.quantize(hidden_states)
return quant, emb_loss, indices


class ChameleonImageVocabularyMapping:
"""
A class for mapping discrete image tokens from VQGAN to BPE tokens.
@@ -1118,6 +1062,62 @@ def _init_weights(self, module):
module.weight.data[module.padding_idx].zero_()


CHAMELEON_VQ_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`ChameleonVQVAEConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings(
"""The VQ-VAE model used in Chameleon for encoding/decoding images into discrete tokens.
This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from
[ Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv Taigman](https://arxiv.org/abs/2203.13131).
""",
CHAMELEON_VQ_START_DOCSTRING,
)
class ChameleonVQVAE(ChameleonPreTrainedModel):
config_class = ChameleonVQVAEConfig
_no_split_modules = ["ChameleonVQVAEVectorQuantizer"]

def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
elif isinstance(module, nn.GroupNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()

def __init__(self, config: ChameleonVQVAEConfig):
super().__init__(config)

self.encoder = ChameleonVQVAEEncoder(config)
self.quantize = ChameleonVQVAEVectorQuantizer(config)
self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, config.latent_channels, 1)
self.eval() # Chameleon's VQ model is frozen

def encode(self, pixel_values: torch.LongTensor):
hidden_states = self.encoder(pixel_values)
hidden_states = self.quant_conv(hidden_states)
quant, emb_loss, indices = self.quantize(hidden_states)
return quant, emb_loss, indices


CHAMELEON_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1211,7 +1211,7 @@ def __init__(self, config: ChameleonConfig):
[decoder_layer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.vqmodel = ChameleonVQVAE(config.vq_config)
self.vqmodel = ChameleonVQVAE._from_config(config.vq_config)
self.gradient_checkpointing = False

# Initialize weights and apply final processing
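
Besides moving `ChameleonVQVAE` below `ChameleonPreTrainedModel` and re-basing it on that class, the hunk switches `ChameleonModel` to build the VQ sub-model through `_from_config` rather than the bare constructor, so the frozen VQ model goes through the usual `PreTrainedModel._from_config` dtype and attention-implementation handling (also touched earlier in this commit). A rough sketch of that call, assuming default `ChameleonVQVAEConfig` values:

```python
from transformers.models.chameleon.configuration_chameleon import ChameleonVQVAEConfig
from transformers.models.chameleon.modeling_chameleon import ChameleonVQVAE

vq_config = ChameleonVQVAEConfig()  # default values, for illustration only
# _from_config resolves torch_dtype and the attention implementation before
# building the module, which a direct ChameleonVQVAE(vq_config) call would skip.
vq_model = ChameleonVQVAE._from_config(vq_config)
print(next(vq_model.parameters()).dtype)
```
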
2 changes: 1 addition & 1 deletion src/transformers/models/phi/modeling_phi.py
@@ -727,7 +727,7 @@ def __init__(self, config):
super().__init__(config)
self.model = PhiModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)

# Initialize weights and apply final processing
self.post_init()
4 changes: 3 additions & 1 deletion src/transformers/models/phi/modular_phi.py
@@ -284,7 +284,9 @@ def forward(


class PhiForCausalLM(LlamaForCausalLM):
pass
def __init__(self, config):
super().__init__(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)


class PhiForSequenceClassification(LlamaForSequenceClassification):
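
In the modular file, `PhiForCausalLM` previously inherited `LlamaForCausalLM` unchanged, which generates a bias-free `lm_head`; overriding `__init__` with `bias=True` keeps the generated `modeling_phi.py` consistent with Phi's biased LM head (the companion change above). A tiny sketch of the difference, with illustrative dimensions:

```python
import torch.nn as nn

hidden_size, vocab_size = 2048, 51200  # Phi-like sizes, for illustration only
llama_style_head = nn.Linear(hidden_size, vocab_size, bias=False)  # what plain inheritance produced
phi_head = nn.Linear(hidden_size, vocab_size, bias=True)           # what the override produces
print(llama_style_head.bias, phi_head.bias.shape)                  # None vs. torch.Size([51200])
```
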
3 changes: 1 addition & 2 deletions tests/fsdp/test_fsdp.py
@@ -32,7 +32,6 @@
require_accelerate,
require_fsdp,
require_torch_accelerator,
require_torch_gpu,
require_torch_multi_accelerator,
slow,
torch_device,
@@ -288,7 +287,7 @@ def test_training_and_can_resume_normally(self, state_dict_type):

@require_torch_multi_accelerator
@slow
@require_torch_gpu
@require_torch_accelerator
@require_fsdp
def test_fsdp_cpu_offloading(self):
try:
34 changes: 10 additions & 24 deletions tests/generation/test_utils.py
@@ -33,6 +33,7 @@
require_flash_attn,
require_optimum_quanto,
require_torch,
require_torch_accelerator,
require_torch_gpu,
require_torch_multi_accelerator,
require_torch_multi_gpu,
@@ -2042,16 +2043,10 @@ def test_generate_with_quant_cache(self):
with self.assertRaises(ValueError):
model.generate(**generation_kwargs, **inputs_dict)

@parameterized.expand(
[
("forward_only", False), # TODO (@joao): a few models failing. After fixed, this should not be "@slow"
("end_to_end", True), # TODO (@joao): end-to-end compilation is broken with torch 2.5+, explore and fix
]
)
@pytest.mark.generate
@require_torch_gpu
@require_torch_accelerator
@slow
def test_generate_compile(self, _, end_to_end):
def test_generate_compile_model_forward(self):
"""
Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results. Tests
end-to-end compilation and forward pass compilation only.
@@ -2061,14 +2056,7 @@ def test_generate_compile(self, _, end_to_end):
if not model_class._supports_static_cache:
self.skipTest("This model doesn't support static cache")

# TODO (joao) -- fix and enable me :)
if end_to_end and any(model_name in model_class.__name__.lower() for model_name in ["whisper"]):
self.skipTest("whisper model end-to-end generate compile not yet supported")

config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# TODO (joao) -- fix and enable me :)
if end_to_end and config.is_encoder_decoder:
self.skipTest("Encoder-decoder model end-to-end generate compile not yet supported")

model = model_class(config).to(torch_device)
model.eval() # otherwise `self.training` is `True` -- this flag is used at attn mask creation time
@@ -2084,10 +2072,8 @@ def test_generate_compile(self, _, end_to_end):
"max_new_tokens": 10,
"return_dict_in_generate": True,
"output_scores": True,
"cache_implementation": "static",
}
# end-to-end works best with dynamic cache, forward compilation works best with static cache
if not end_to_end:
generation_kwargs["cache_implementation"] = "static"

# get eager + dynamic cache results for future comparison
dynamic_outputs = []
@@ -2098,10 +2084,8 @@ def test_generate_compile(self, _, end_to_end):
generation_config = copy.deepcopy(model.generation_config)
generation_config.update(**generation_kwargs)
torch.compiler.reset()
if end_to_end:
model.generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
else:
model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead")

model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead")

compiled_outputs = []
for model_inputs in input_ids_sets:
@@ -3808,10 +3792,12 @@ def test_assisted_decoding_in_different_gpu(self):
self.assertTrue(input_length <= out.shape[-1] <= input_length + 20)

@slow
@require_torch_gpu
@require_torch_accelerator
def test_assisted_decoding_model_in_gpu_assistant_in_cpu(self):
# PT-only test: TF doesn't support assisted decoding yet.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda")
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(
torch_device
)
assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(
"cpu"
)
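
The reworked test drops the parameterized end-to-end compilation path and keeps only forward-pass compilation with a static cache. A condensed sketch of the remaining path, using a small public checkpoint as a stand-in for the tester's model classes:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "hf-internal-testing/tiny-random-MistralForCausalLM"  # stand-in model
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)
model.eval()  # `self.training` must be False when the attention mask is created

inputs = tokenizer("Compile only the forward pass", return_tensors="pt")

torch.compiler.reset()
# Only `forward` is compiled; generation stays eager, and the static cache keeps
# shapes fixed so the compiled graph can be reused across decoding steps.
model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead")

out = model.generate(
    **inputs,
    max_new_tokens=10,
    cache_implementation="static",
    return_dict_in_generate=True,
    output_scores=True,
)
print(out.sequences.shape)
```
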
11 changes: 6 additions & 5 deletions tests/models/blip_2/test_modeling_blip_2.py
@@ -27,6 +27,7 @@
from transformers import CONFIG_MAPPING, Blip2Config, Blip2QFormerConfig, Blip2VisionConfig
from transformers.testing_utils import (
require_torch,
require_torch_accelerator,
require_torch_fp16,
require_torch_gpu,
require_torch_multi_accelerator,
@@ -1565,7 +1566,7 @@ def test_forward_signature(self):
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)

@slow
@require_torch_gpu
@require_torch_accelerator
def test_model_from_pretrained(self):
model_name = "Salesforce/blip2-itm-vit-g"
model = Blip2TextModelWithProjection.from_pretrained(model_name)
@@ -2191,7 +2192,7 @@ def test_expansion_in_processing(self):

self.assertTrue(generated_text_expanded == generated_text)

@require_torch_gpu
@require_torch_accelerator
def test_inference_itm(self):
model_name = "Salesforce/blip2-itm-vit-g"
processor = Blip2Processor.from_pretrained(model_name)
@@ -2210,7 +2211,7 @@ def test_inference_itm(self):
self.assertTrue(torch.allclose(torch.nn.Softmax()(out_itm[0].cpu()), expected_scores, rtol=1e-3, atol=1e-3))
self.assertTrue(torch.allclose(out[0].cpu(), torch.Tensor([[0.4406]]), rtol=1e-3, atol=1e-3))

@require_torch_gpu
@require_torch_accelerator
@require_torch_fp16
def test_inference_itm_fp16(self):
model_name = "Salesforce/blip2-itm-vit-g"
@@ -2232,7 +2233,7 @@ def test_inference_itm_fp16(self):
)
self.assertTrue(torch.allclose(out[0].cpu().float(), torch.Tensor([[0.4406]]), rtol=1e-3, atol=1e-3))

@require_torch_gpu
@require_torch_accelerator
@require_torch_fp16
def test_inference_vision_with_projection_fp16(self):
model_name = "Salesforce/blip2-itm-vit-g"
@@ -2256,7 +2257,7 @@ def test_inference_vision_with_projection_fp16(self):
]
self.assertTrue(np.allclose(out.image_embeds[0][0][:6].tolist(), expected_image_embeds, atol=1e-3))

@require_torch_gpu
@require_torch_accelerator
@require_torch_fp16
def test_inference_text_with_projection_fp16(self):
model_name = "Salesforce/blip2-itm-vit-g"
2 changes: 1 addition & 1 deletion tests/models/chameleon/test_modeling_chameleon.py
@@ -333,7 +333,7 @@ def test_batching_equivalence(self):

# TODO (joao, raushan): fix me -- the problem is in `cache_position[0] == 0`, i.e. dynamic control flow
@unittest.skip("Chameleon is not compatible with end-to-end generation compilation")
def test_generate_compile_fullgraph(self):
def test_generate_compile_model_forward(self):
pass

