
Commit 08bf7f1

Add kernelize to transformers (#38205)
* fix
* fix
* fix flow
* remove non compiling path
* change
* style
* fix
* update
* update pin
* revert
1 parent be10d4d · commit 08bf7f1

File tree

4 files changed (+13, -43 lines)


setup.py

Lines changed: 1 addition & 1 deletion
@@ -128,7 +128,7 @@
     # Keras pin - this is to make sure Keras 3 doesn't destroy us. Remove or change when we have proper support.
     "keras>2.9,<2.16",
     "keras-nlp>=0.3.1,<0.14.0",  # keras-nlp 0.14 doesn't support keras 2, see pin on keras.
-    "kernels>=0.4.4,<0.5",
+    "kernels>=0.6.1,<0.7",
     "librosa",
     "natten>=0.14.6,<0.15.0",
     "nltk<=3.8.1",

src/transformers/dependency_versions_table.py

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@
     "kenlm": "kenlm",
     "keras": "keras>2.9,<2.16",
     "keras-nlp": "keras-nlp>=0.3.1,<0.14.0",
-    "kernels": "kernels>=0.4.4,<0.5",
+    "kernels": "kernels>=0.6.1,<0.7",
     "librosa": "librosa",
     "natten": "natten>=0.14.6,<0.15.0",
     "nltk": "nltk<=3.8.1",

src/transformers/integrations/hub_kernels.py

Lines changed: 4 additions & 41 deletions
@@ -13,18 +13,14 @@
 # limitations under the License.
 from typing import Union

-from ..utils import is_torchdynamo_compiling
-

 try:
     from kernels import (
         Device,
         LayerRepository,
         register_kernel_mapping,
         replace_kernel_forward_from_hub,
-    )
-    from kernels import (
-        use_kernel_forward_from_hub as original_use_kernel_forward_from_hub,
+        use_kernel_forward_from_hub,
     )

     _hub_kernels_available = True
@@ -45,9 +41,9 @@
     },
     "RMSNorm": {
         "cuda": LayerRepository(
-            repo_id="kernels-community/triton-layer-norm",
-            layer_name="LlamaRMSNorm",
-            revision="pure-layer-test",
+            repo_id="kernels-community/liger_kernels",
+            layer_name="LigerRMSNorm",
+            # revision="pure-layer-test",
         )
     },
     "MLP": {
@@ -60,39 +56,6 @@

     register_kernel_mapping(_KERNEL_MAPPING)

-    def use_kernel_forward_from_hub(*args, **kwargs):
-        """
-        Expands `kernels`' `use_kernel_forward_from_hub` to NOT use a kernel at compile time. This should be removed
-        when `kernels` supports `torch.compile`.
-
-        If the layer has a `config` attribute, we can also set `config.disable_custom_kernels = True` to disable the
-        kernel.
-        """
-
-        def decorator_with_compile_path(cls):
-            # Keeps a reference to the original forward method
-            original_forward = cls.forward
-
-            # Applies the original decorator
-            decorator = original_use_kernel_forward_from_hub(*args, **kwargs)
-            cls = decorator(cls)
-
-            # Replaces the kernel forward with a compile-friendly version
-            kernel_forward = cls.forward
-
-            def forward_with_compile_path(*forward_args, **forward_kwargs):
-                disable_custom_kernels = hasattr(cls, "config") and getattr(cls.config, "disable_custom_kernels", None)
-                if is_torchdynamo_compiling() or disable_custom_kernels:
-                    return original_forward(*forward_args, **forward_kwargs)
-                else:
-                    return kernel_forward(*forward_args, **forward_kwargs)
-
-            cls.forward = forward_with_compile_path
-
-            return cls
-
-        return decorator_with_compile_path
-

 except ImportError:
     # Stub to make decorators int transformers work when `kernels`
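The removed wrapper special-cased `torch.compile` by falling back to the original forward while tracing. With kernels>=0.6.1 that wrapper is no longer needed: kernel dispatch is applied explicitly via `kernelize()`, which is what the `modeling_utils.py` change below does. A minimal sketch, reusing `MyRMSNorm` from the sketch above and assuming the mapping has been registered (importing `transformers.integrations.hub_kernels` runs `register_kernel_mapping`):

import torch

from kernels import Device, kernelize

model = MyRMSNorm(hidden_size=128).to("cuda")
# Walks the module tree and swaps opted-in forwards for hub kernels.
kernelize(model, device=Device(type="cuda"))
out = model(torch.randn(2, 16, 128, device="cuda"))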

src/transformers/modeling_utils.py

Lines changed: 7 additions & 0 deletions
@@ -4281,6 +4281,7 @@ def from_pretrained(
         tp_size = kwargs.pop("tp_size", None)
         device_mesh = kwargs.pop("device_mesh", None)
         trust_remote_code = kwargs.pop("trust_remote_code", None)
+        use_kernels = kwargs.pop("use_kernels", False)

         key_mapping = kwargs.pop("key_mapping", None)
         # Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model
@@ -4733,6 +4734,12 @@ def _assign_original_dtype(module):
         # Set model in evaluation mode to deactivate DropOut modules by default
         model.eval()

+        # check if using kernels
+        if use_kernels:
+            from kernels import Device, kernelize
+
+            kernelize(model, device=Device(type=model.device.type))
+
         # If it is a model with generation capabilities, attempt to load generation files (generation config,
         # custom generate function)
         if model.can_generate() and generation_config is not None:
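End-to-end, the new kwarg added above is used like this (checkpoint name is illustrative):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",  # any supported checkpoint
    device_map="cuda",
    use_kernels=True,  # popped above; triggers kernelize() after model.eval()
)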
