
Conversation

@liangel-02 (Contributor)

Context

This PR is a follow-up to #40735 and #41138. Previously, we enabled safetensors support in torchao for a single shard file. This PR fixes some errors introduced in #41138 and handles the case where a checkpoint is sharded across more than one file, including the edge case where a single quantized tensor (e.g. a Float8Tensor) is split across two different files (e.g. qdata in one and scale in another).
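As a minimal sketch of that edge case (the key names _weight_qdata/_weight_scale and the file names below are hypothetical, not the exact torchao flattening scheme), the two components of one quantized weight can land in different shard files, so neither file alone can rebuild the tensor:

from safetensors.torch import load_file, save_file
import torch

# Hypothetical sharded layout: qdata and scale for the same weight are
# saved to two different safetensors files.
save_file({"layer.0._weight_qdata": torch.zeros(4, 4, dtype=torch.int8)}, "model-00001-of-00002.safetensors")
save_file({"layer.0._weight_scale": torch.ones(4)}, "model-00002-of-00002.safetensors")

# Each shard only sees one component of the quantized tensor.
print(load_file("model-00001-of-00002.safetensors").keys())  # qdata only
print(load_file("model-00002-of-00002.safetensors").keys())  # scale only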

Summary

If we are loading a component of a tensor subclass in create_quantized_param(), called by _load_state_dict_into_meta_model(), we add it to the model as a new parameter. Then, after all parameters are loaded, we unflatten the state_dict and reassign the model parameters.
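A minimal sketch of this two-phase flow, with made-up component names (torchao would rebuild a Float8Tensor in phase two; this sketch uses a plain dequantized tensor so it stays runnable):

import torch
import torch.nn as nn

model = nn.Linear(4, 4, bias=False)

# Phase 1: each flattened component arriving from a shard is registered
# as an ordinary parameter on the module.
model.register_parameter("_weight_qdata", nn.Parameter(torch.zeros(4, 4), requires_grad=False))
model.register_parameter("_weight_scale", nn.Parameter(torch.ones(4), requires_grad=False))

# Phase 2: once all shards are loaded, rebuild the real weight, drop the
# temporary component parameters, and reassign the module's weight.
rebuilt = model._weight_qdata.data * model._weight_scale.data
for name in ("_weight_qdata", "_weight_scale"):
    delattr(model, name)
model.weight = nn.Parameter(rebuilt, requires_grad=False)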

Testing

Modified the unit tests to cover all tensor subclasses:
python tests/quantization/torchao_integration/test_torchao.py -k TorchAoSafeSerializationTest

@liangel-02 marked this pull request as draft November 3, 2025 17:43
@liangel-02 force-pushed the torchao-safetensors-sharding branch from 8b6b802 to eeb8451 November 3, 2025 17:54
@liangel-02 marked this pull request as ready for review November 3, 2025 18:32
@github-actions bot requested review from MekkCyber and SunMarc November 3, 2025 18:33
@liangel-02 force-pushed the torchao-safetensors-sharding branch 2 times, most recently from a431b9a to 5a62843 November 3, 2025 21:12
@jerryzh168 (Contributor) left a comment


Thanks, mostly looks good; had one more inline comment.

@liangel-02 force-pushed the torchao-safetensors-sharding branch from 5a62843 to 1a020ed November 3, 2025 21:19
@SunMarc (Member) left a comment


Thanks for your work! Left a couple of comments. By the way, we will soon refactor how quantization is applied as we move to dynamic weight loading like vLLM. This should help with supporting features like TP.

Comment on lines 245 to 268
if TORCHAO_VERSION >= version.parse("0.14.0") and is_metadata_torchao(self.metadata):
    updated_state_dict = unflatten_tensor_state_dict(model.state_dict(), metadata)

    weights_to_register = set(updated_state_dict.keys())

    for name, param in list(model.named_parameters()):
        module_fqn, weight_name = name.rsplit(".", 1)
        module = model.get_submodule(module_fqn)
        weight = getattr(module, weight_name)

        device = weight.device
        requires_grad = weight.requires_grad

        if "_weight_" in weight_name:
            delattr(module, weight_name)

        if name in weights_to_register:
            new_param_value = updated_state_dict[name]
            new_param = torch.nn.Parameter(new_param_value.to(device), requires_grad=requires_grad)
            module.register_parameter(weight_name, new_param)

            weights_to_register.remove(name)

    model.load_state_dict(updated_state_dict, strict=False)
Member

So instead of performing unflatten_tensor_state_dict in create_quantized_param, we do it here at the very end, and we just store the flattened weights in the module?

Contributor Author

Yeah, we don't want to do it in create_quantized_param since we'd have access to at most one shard file there, and we want to handle the case where tensor subclass attributes are split across multiple files.

We call unflatten_tensor_state_dict at the very end to get the recovered state dict, then iterate through the model and replace the weights that represent tensor attributes with the full tensor subclass.
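A toy version of that regrouping step, in the spirit of unflatten_tensor_state_dict (the real function uses the safetensors metadata and reconstructs the tensor subclass; this sketch just groups components under an assumed "<fqn>._weight_<suffix>" key pattern):

import re
import torch

def unflatten_by_suffix(flat_state_dict, suffixes=("qdata", "scale")):
    # Collect "<fqn>._weight_<suffix>" entries back into one entry per
    # weight; anything else passes through untouched. The real torchao
    # code would turn each complete group into e.g. a Float8Tensor.
    pattern = re.compile(r"^(.*)\._weight_(" + "|".join(suffixes) + r")$")
    grouped, passthrough = {}, {}
    for key, tensor in flat_state_dict.items():
        match = pattern.match(key)
        if match:
            fqn, suffix = match.groups()
            grouped.setdefault(fqn + ".weight", {})[suffix] = tensor
        else:
            passthrough[key] = tensor
    return {**passthrough, **grouped}

flat = {
    "model.layers.0.q_proj._weight_qdata": torch.zeros(4, 4, dtype=torch.int8),
    "model.layers.0.q_proj._weight_scale": torch.ones(4),
}
print(unflatten_by_suffix(flat).keys())  # dict_keys(['model.layers.0.q_proj.weight'])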

@liangel-02 force-pushed the torchao-safetensors-sharding branch from 1a020ed to 7cdb0c6 November 4, 2025 15:58
@liangel-02 requested a review from SunMarc November 4, 2025 15:59
@liangel-02 force-pushed the torchao-safetensors-sharding branch 6 times, most recently from 9a309b0 to f1369bd November 5, 2025 23:27
@liangel-02 force-pushed the torchao-safetensors-sharding branch 3 times, most recently from 5ed0aad to 9219962 November 13, 2025 21:23
@github-actions (bot)

[For maintainers] Suggested jobs to run (before merge)

run-slow: torchao_integration

@liangel-02 force-pushed the torchao-safetensors-sharding branch from 9219962 to 1b19193 November 13, 2025 21:26
@liangel-02 (Contributor, Author)

@SunMarc I rebased my PR and am now seeing this error due to #41580:

File "/home/liangel/local/transformers/src/transformers/modeling_utils.py", line 4122, in from_pretrained
    model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs = cls._load_pretrained_model(
                                                                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/liangel/local/transformers/src/transformers/modeling_utils.py", line 4275, in _load_pretrained_model
    missing_keys, unexpected_keys, mismatched_keys, misc = convert_and_load_state_dict_in_model(
                                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/liangel/local/transformers/src/transformers/core_model_loading.py", line 621, in convert_and_load_state_dict_in_model
    raise ValueError("This quantization method is gonna be supported SOOOON")
ValueError: This quantization method is gonna be supported SOOOON

Will there be follow-up changes to support torchao? What changes would be needed? cc @jerryzh168

@ArthurZucker (Collaborator)

Happy to help if you want to make the changes here. I think @SunMarc and @MekkCyber are going to help as well with making sure torchao is supported!

@SunMarc (Member) commented Nov 14, 2025

We just merged a big PR and all quantization methods are impacted; we will add back support for those methods ASAP!
