Merged · 45 commits
- da26896 everything untilo informer (Cyrilvallez, Nov 12, 2025)
- d561c6f everything until perceiver (Cyrilvallez, Nov 12, 2025)
- ceea305 all of them finally (Cyrilvallez, Nov 13, 2025)
- 187bb8e style (Cyrilvallez, Nov 13, 2025)
- 2cd2add replace by transformers init everywhere (Cyrilvallez, Nov 13, 2025)
- 6bdffed use relative import instead (Cyrilvallez, Nov 13, 2025)
- d25fe72 deprecated models (Cyrilvallez, Nov 13, 2025)
- 82899ac style (Cyrilvallez, Nov 13, 2025)
- a4ab598 start contexts (Cyrilvallez, Nov 13, 2025)
- 192151e small fixes (Cyrilvallez, Nov 13, 2025)
- 5efa9a8 fix modular (Cyrilvallez, Nov 13, 2025)
- c882d60 remove class switch (Cyrilvallez, Nov 13, 2025)
- 22a55a3 do not initialize tied weights (Cyrilvallez, Nov 14, 2025)
- 694440b typo (Cyrilvallez, Nov 14, 2025)
- 5a0174e fix (Cyrilvallez, Nov 14, 2025)
- 5423e06 improve (Cyrilvallez, Nov 14, 2025)
- 9b7ace5 improve comments (Cyrilvallez, Nov 14, 2025)
- 4acef54 improve (Cyrilvallez, Nov 14, 2025)
- c58d243 improve (Cyrilvallez, Nov 14, 2025)
- 2edc8c1 fix zamba (Cyrilvallez, Nov 14, 2025)
- 2f40139 fix import (Cyrilvallez, Nov 14, 2025)
- 2dd4e00 add the post_init (Cyrilvallez, Nov 14, 2025)
- 3ede287 more post_init (Cyrilvallez, Nov 14, 2025)
- 86f7169 fix (Cyrilvallez, Nov 14, 2025)
- 706799e protect (Cyrilvallez, Nov 14, 2025)
- 1da2d27 more post_init (Cyrilvallez, Nov 14, 2025)
- 83e0ada fix (Cyrilvallez, Nov 14, 2025)
- 50187a9 fixes (Cyrilvallez, Nov 14, 2025)
- 16173f0 fix (Cyrilvallez, Nov 14, 2025)
- bae372a fix (Cyrilvallez, Nov 14, 2025)
- 8500bcf switch flag name (Cyrilvallez, Nov 14, 2025)
- cdada86 more fixes (Cyrilvallez, Nov 14, 2025)
- 99961fc fixes (Cyrilvallez, Nov 14, 2025)
- 557ef75 fixes (Cyrilvallez, Nov 14, 2025)
- 2dd0817 Merge branch 'main' into better-init-2 (Cyrilvallez, Nov 14, 2025)
- 912440b copies (Cyrilvallez, Nov 14, 2025)
- acdaf9e fix (Cyrilvallez, Nov 14, 2025)
- cc10ea4 finally find the culprit (Cyrilvallez, Nov 14, 2025)
- 627e77b style (Cyrilvallez, Nov 14, 2025)
- db42923 last small (Cyrilvallez, Nov 14, 2025)
- 17115a2 big bird (Cyrilvallez, Nov 14, 2025)
- bbdc5a5 better (Cyrilvallez, Nov 14, 2025)
- 3a12aec update init check (Cyrilvallez, Nov 14, 2025)
- 9beb88c final touch (Cyrilvallez, Nov 14, 2025)
- 6092804 do it everywhere (Cyrilvallez, Nov 14, 2025)
116 changes: 0 additions & 116 deletions src/transformers/core_model_loading.py
@@ -26,7 +26,6 @@
from contextlib import contextmanager
from dataclasses import dataclass, field
from functools import partial
from types import MethodType
from typing import TYPE_CHECKING, Any, Optional, Union

import torch
@@ -313,120 +312,6 @@ class ConversionEntry:
GLOBAL_WORKERS = min(16, (os.cpu_count() or 8) * 2) # NVMe: 8-16; HDD/NFS: 2-4


# Factory function to create LoadedParameter subclasses dynamically
def get_loaded_parameter_class(base_cls):
"""
base_cls: an nn.Parameter subclass (or nn.Parameter) or a Tensor
Returns a new class that combines the base_cls with LoadedParameterMixin

"""

class LoadedParam(base_cls):
_inplace_methods = [
"add_",
"mul_",
"clamp_",
"zero_",
"fill_",
"normal_",
"uniform_",
"copy_",
"erfinv_",
"log_",
"__getitem__",
"neg_",
"exp_",
"sub_",
]

def __new__(cls, from_existing, **kwargs):
if isinstance(from_existing, torch.nn.Parameter):
inst = super().__new__(cls, from_existing.data, from_existing.requires_grad, **from_existing.__dict__)
else:
inst = super().__new__(cls, from_existing)
# we store the original object to get it back later on
inst._original = from_existing
# Explicitly override all in-place methods per instance
for method_name in inst._inplace_methods:
setattr(inst, method_name, MethodType(inst._skip, inst))

return inst

def _skip(self, *args, **kwargs):
"""Helper to skip in-place operations."""
return self

def __repr__(self):
return f"LoadedParameter(data={self.data})"

@property
def data(self):
return super().data

@data.setter
def data(self, new):
pass

def __lt__(self, other):
return torch.Tensor.__lt__(self, other)

def __le__(self, other):
return torch.Tensor.__le__(self, other)

def __gt__(self, other):
return torch.Tensor.__gt__(self, other)

def __ge__(self, other):
return torch.Tensor.__ge__(self, other)

def __eq__(self, other):
return torch.Tensor.__eq__(self, other)

def __ne__(self, other):
return torch.Tensor.__ne__(self, other)

def __iadd__(self, *args, **kwargs):
return self

def __isub__(self, *args, **kwargs):
return self

def __imul__(self, *args, **kwargs):
return self

def __imatmul__(self, *args, **kwargs):
return self

def __itruediv__(self, *args, **kwargs):
return self

def __ifloordiv__(self, *args, **kwargs):
return self

def __imod__(self, *args, **kwargs):
return self

def __ipow__(self, *args, **kwargs):
return self

def __iand__(self, *args, **kwargs):
return self

def __ior__(self, *args, **kwargs):
return self

def __ixor__(self, *args, **kwargs):
return self

def __ilshift__(self, *args, **kwargs):
return self

def __irshift__(self, *args, **kwargs):
return self

return LoadedParam


def _materialize_copy(tensor, dtype=None):
tensor = tensor[...]
if dtype is not None:
@@ -527,7 +412,6 @@ def set_param_for_module(
param_value = param_value.to_local()
if param_name not in module_obj._buffers:
param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point())
param_value = get_loaded_parameter_class(param_value.__class__)(from_existing=param_value)

# Remove from missing keys (it's either mismatched, or all good)
missing_keys.discard(layer_name)
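For orientation: the deleted `LoadedParam` wrapper above suppressed in-place re-initialization by overriding the mutating methods of loaded parameters. The replacement approach relies on the `_is_hf_initialized` flag checked by the new helpers in `src/transformers/initialization.py` further down. A minimal sketch of that behavior, assuming the loading path sets the flag on parameters it has filled (the example itself is not part of this diff):

```python
# Hypothetical illustration, not part of this diff: the guarded helpers in
# `transformers.initialization` skip any tensor carrying the `_is_hf_initialized` flag.
import torch
from torch import nn

from transformers import initialization as init

linear = nn.Linear(4, 4)
linear.weight._is_hf_initialized = True  # assumed to be set by the loading code for loaded params

before = linear.weight.detach().clone()
init.normal_(linear.weight, mean=0.0, std=0.02)  # skipped: the flag short-circuits the call
init.zeros_(linear.bias)                         # bias is unflagged, so it is zeroed as usual

assert torch.equal(before, linear.weight)
assert torch.all(linear.bias == 0)
```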
3 changes: 2 additions & 1 deletion src/transformers/generation/watermarking.py
@@ -23,6 +23,7 @@
from torch import nn
from torch.nn import BCELoss

from .. import initialization as init
from ..modeling_utils import PreTrainedModel
from ..utils import ModelOutput, logging
from .configuration_utils import PreTrainedConfig, WatermarkingConfig
@@ -387,7 +388,7 @@ def __init__(self, config):
def _init_weights(self, module):
"""Initialize the weights."""
if isinstance(module, nn.Parameter):
module.weight.normal_(mean=0.0, std=0.02)
init.normal_(module.weight, mean=0.0, std=0.02)

def _compute_posterior(
self,
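The watermarking detector's `_init_weights` above is representative of the broader change in this PR: direct in-place tensor calls are swapped for the guarded `init.*` helpers. As a hedged illustration only (module types and the std value are hypothetical, not taken from this diff), a typical `_init_weights` written against the new helpers might look like:

```python
from torch import nn

from transformers import initialization as init

def _init_weights(self, module):
    # Illustrative value; real models usually read the std from their config.
    std = 0.02
    if isinstance(module, nn.Linear):
        init.normal_(module.weight, mean=0.0, std=std)
        if module.bias is not None:
            init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        init.normal_(module.weight, mean=0.0, std=std)
    elif isinstance(module, nn.LayerNorm):
        init.ones_(module.weight)
        init.zeros_(module.bias)
```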
191 changes: 191 additions & 0 deletions src/transformers/initialization.py
@@ -0,0 +1,191 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from collections import defaultdict
from contextlib import contextmanager

import torch


# Record all the torch primitives in advance, so that we can use them without them being modified when we patch torch
# in context managers
TORCH_INIT_FUNCTIONS = {
"uniform_": torch.nn.init.uniform_,
"normal_": torch.nn.init.normal_,
"constant_": torch.nn.init.constant_,
"ones_": torch.nn.init.ones_,
"zeros_": torch.nn.init.zeros_,
"eye_": torch.nn.init.eye_,
"dirac_": torch.nn.init.dirac_,
"xavier_uniform_": torch.nn.init.xavier_uniform_,
"xavier_normal_": torch.nn.init.xavier_normal_,
"kaiming_uniform_": torch.nn.init.kaiming_uniform_,
"kaiming_normal_": torch.nn.init.kaiming_normal_,
"trunc_normal_": torch.nn.init.trunc_normal_,
"orthogonal_": torch.nn.init.orthogonal_,
"sparse_": torch.nn.init.sparse_,
}


def uniform_(
tensor: torch.Tensor, a: float = 0.0, b: float = 1.0, generator: torch.Generator | None = None
) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["uniform_"](tensor, a=a, b=b, generator=generator)
return tensor


def normal_(
tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, generator: torch.Generator | None = None
) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["normal_"](tensor, mean=mean, std=std, generator=generator)
return tensor


def constant_(tensor: torch.Tensor, val: float) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["constant_"](tensor, val=val)
return tensor


def ones_(tensor: torch.Tensor) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["ones_"](tensor)
return tensor


def zeros_(tensor: torch.Tensor) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["zeros_"](tensor)
return tensor


def eye_(tensor: torch.Tensor) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["eye_"](tensor)
return tensor


def dirac_(tensor: torch.Tensor, groups: int = 1) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["dirac_"](tensor, groups=groups)
return tensor


def xavier_uniform_(tensor: torch.Tensor, gain: float = 1.0, generator: torch.Generator | None = None) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["xavier_uniform_"](tensor, gain=gain, generator=generator)
return tensor


def xavier_normal_(tensor: torch.Tensor, gain: float = 1.0, generator: torch.Generator | None = None) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["xavier_normal_"](tensor, gain=gain, generator=generator)
return tensor


def kaiming_uniform_(
tensor: torch.Tensor,
a: float = 0,
mode: str = "fan_in",
nonlinearity: str = "leaky_relu",
generator: torch.Generator | None = None,
) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["kaiming_uniform_"](
tensor, a=a, mode=mode, nonlinearity=nonlinearity, generator=generator
)
return tensor


def kaiming_normal_(
tensor: torch.Tensor,
a: float = 0,
mode: str = "fan_in",
nonlinearity: str = "leaky_relu",
generator: torch.Generator | None = None,
) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["kaiming_normal_"](
tensor, a=a, mode=mode, nonlinearity=nonlinearity, generator=generator
)
return tensor


def trunc_normal_(
tensor: torch.Tensor,
mean: float = 0.0,
std: float = 1.0,
a: float = -2.0,
b: float = 2.0,
generator: torch.Generator | None = None,
) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["trunc_normal_"](tensor, mean=mean, std=std, a=a, b=b, generator=generator)
return tensor


def orthogonal_(
tensor: torch.Tensor,
gain: float = 1,
generator: torch.Generator | None = None,
) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["orthogonal_"](tensor, gain=gain, generator=generator)
return tensor


def sparse_(
tensor: torch.Tensor, sparsity: float, std: float = 0.01, generator: torch.Generator | None = None
) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["sparse_"](tensor, sparsity=sparsity, std=std, generator=generator)
return tensor


def copy_(tensor: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
with torch.no_grad():
return tensor.copy_(other)
return tensor


@contextmanager
def guard_torch_init_functions():
"""
Guard the `torch.nn.init` primitive functions to behave exactly like the functions in this file, i.e. be
protected against the `_is_hf_initialized` flag to avoid re-init if the param was already loaded.

Usually, all models are using the init from `transformers` which are already guarded, but just to make extra sure
and for remote code, we also use this context manager.
"""
"""

[Review thread on this hunk]
Collaborator: this won't work for any tensor manipulation for any remote code / code outside our scope, but its fine
Member (PR author): Yes I know, this is very unfortunate but we cannot really make it work for remote code 🥲

originals = defaultdict(dict)
try:
# Replace all torch funcs by the ones in this file
for name in TORCH_INIT_FUNCTIONS.keys():
# Here, we need to check all modules imported, and hot patch all of them, as usually torch does
# something like `from torch.nn.init import xavier_uniform_` in their internals (e.g in torch.nn.modules,
# where MultiHeadAttention lives), so the function name is bound at import time and just doing
# `setattr(torch.nn.init, name, globals()[name])` is thus not enough
for module in sys.modules.values():
if module and hasattr(module, name):
originals[module][name] = getattr(module, name)
setattr(module, name, globals()[name])
yield
finally:
# Set back the original functions on all modules
for module, functions in originals.items():
for name, func in functions.items():
setattr(module, name, func)
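To close the loop on the remote-code concern raised in the review thread above: inside the context manager, even direct calls to `torch.nn.init` are routed through the guarded versions. A minimal usage sketch, assuming the loading code has already set the `_is_hf_initialized` flag (the example is illustrative, not taken from this diff):

```python
import torch
from torch import nn

from transformers.initialization import guard_torch_init_functions

layer = nn.Linear(8, 8)
layer.weight._is_hf_initialized = True  # assume this weight was already loaded from a checkpoint

with guard_torch_init_functions():
    # torch.nn.init.xavier_uniform_ is hot-patched here, so the flagged weight is left untouched...
    torch.nn.init.xavier_uniform_(layer.weight)
    # ...while unflagged tensors still get initialized through the original primitive.
    torch.nn.init.zeros_(layer.bias)
```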