
Commit f17efde

frozenleaves authored
[v1] support automatic discovery of registered kernels. (#9509)
Co-authored-by: frozenleaves <[email protected]>
1 parent 591fc9e commit f17efde

File tree

5 files changed: +228, -36 lines changed

src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_swiglu.py

Lines changed: 67 additions & 7 deletions
@@ -20,7 +20,7 @@
 from .....extras.types import HFModel
 from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
 from ..constants import DeviceType, KernelType
-from ..registry import KERNEL_REGISTRY, MetaSwiGluKernel
+from ..registry import MetaSwiGluKernel
 
 
 def _npu_swiglu_forward(self, hidden_state):
@@ -31,25 +31,85 @@ def _npu_swiglu_forward(self, hidden_state):
     )
 
 
+def _npu_swiglu_glm4_forward(self, hidden_states):
+    import torch_npu
+
+    up_states = self.gate_up_proj(hidden_states)
+    gate, up_states = up_states.chunk(2, dim=-1)
+    return self.down_proj(torch_npu.npu_swiglu(torch.cat((gate, up_states), dim=-1), dim=-1))
+
+
+def _npu_swiglu_gemma3ntext_forward(self, hidden_states):
+    import torch_npu
+
+    gate_proj = self.gate_proj(hidden_states)
+    if self.activation_sparsity > 0.0:
+        gate_proj = self._gaussian_topk(gate_proj)
+    down_proj = self.down_proj(
+        torch_npu.npu_swiglu(torch.cat((gate_proj, self.up_proj(hidden_states)), dim=-1), dim=-1)
+    )
+    return down_proj
+
+
 class NpuSwiGluKernel(MetaSwiGluKernel):
+    type = KernelType.SWIGLU
     device = DeviceType.NPU
     kernel = _npu_swiglu_forward
 
-    @classmethod
-    def register_kernel(cls, kernel_type=KernelType.SWIGLU, device_type=DeviceType.NPU):
-        KERNEL_REGISTRY.register(kernel_type, device_type, cls)
+    # Only apply the kernel to the following modules
+    expect_modules = frozenset(
+        {
+            "Qwen3VLMoeTextMLP",
+            "Qwen3VLTextMLP",
+            "Qwen3OmniMoeThinkerTextMLP",
+            "Qwen3OmniMoeMLP",
+            "Qwen3OmniMoeTalkerTextMLP",
+            "Qwen3OmniMoeCode2WavMlp",
+            "Qwen3NextMLP",
+            "Qwen3MoeMLP",
+            "Qwen3MLP",
+            "Qwen2MLP",
+            "Qwen2MoeMLP",
+            "Qwen2_5_VLMLP",
+            "Qwen2_5OmniMLP",
+            "Llama4TextMLP",
+            "LlamaMLP",
+            "Glm4MLP",
+            "Glm4MoeMLP",
+            "Glm4vMoeTextMLP",
+            "Gemma3MLP",
+            "Gemma2MLP",
+            "Gemma3nTextMLP",
+            "Phi3MLP",
+            "DeepseekV2MLP",
+            "DeepseekV3MLP",
+            "SeedOssMLP",
+        }
+    )
 
     @classmethod
     def apply(cls, model, **kwargs) -> "HFModel":
         if not is_torch_npu_available():
             return model
 
+        # Mapping of specific mlp modules to their corresponding kernel implementations
+        kernel_mapping = {
+            "Glm4MLP": _npu_swiglu_glm4_forward,
+            "Glm4vTextMLP": _npu_swiglu_glm4_forward,
+            "Phi3MLP": _npu_swiglu_glm4_forward,
+            "Gemma3nTextMLP": _npu_swiglu_gemma3ntext_forward,
+        }
+
         swiglu_pattern = re.compile("MLP", re.IGNORECASE)
         for name, module in model.named_modules():
-            # Match any module whose class name contains "RMSNorm"
-            if re.search(swiglu_pattern, module.__class__.__name__):
+            # Match any module whose class name contains "MLP"
+            if (
+                re.search(swiglu_pattern, module.__class__.__name__)
+                and module.__class__.__name__ in cls.expect_modules
+            ):
                 # Bind function as an instance method to preserve `self` semantics
                 # and replace the original forward
-                module.forward = types.MethodType(cls.kernel, module)
+                kernel_func = kernel_mapping.get(module.__class__.__name__, _npu_swiglu_forward)
+                module.forward = types.MethodType(kernel_func, module)
 
         return model
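As an aside on the mechanism above: the kernel functions are plain functions whose first parameter is `self`, and `types.MethodType` binds them to each matched module instance. Below is a minimal, self-contained sketch of that allowlist-plus-mapping dispatch; the toy `Qwen2MLP` class and `_patched_swiglu_forward` are illustrative stand-ins, and the real kernel's `torch_npu.npu_swiglu` call is replaced with eager ops so the snippet runs anywhere.

```python
import re
import types

import torch
from torch import nn


class Qwen2MLP(nn.Module):  # stand-in for a transformers MLP block
    def __init__(self, dim: int = 8):
        super().__init__()
        self.gate_proj = nn.Linear(dim, dim)
        self.up_proj = nn.Linear(dim, dim)
        self.down_proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.down_proj(nn.functional.silu(self.gate_proj(x)) * self.up_proj(x))


def _patched_swiglu_forward(self, x):
    # placeholder for the torch_npu.npu_swiglu call; same math in eager ops
    return self.down_proj(nn.functional.silu(self.gate_proj(x)) * self.up_proj(x))


model = nn.Sequential(Qwen2MLP())
allowlist = {"Qwen2MLP"}  # mirrors expect_modules
pattern = re.compile("MLP", re.IGNORECASE)

for _, module in model.named_modules():
    name = type(module).__name__
    if pattern.search(name) and name in allowlist:
        # bind as an instance method so `self` still refers to the module
        module.forward = types.MethodType(_patched_swiglu_forward, module)

print(model(torch.randn(2, 8)).shape)  # torch.Size([2, 8])
```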

src/llamafactory/v1/plugins/model_plugins/kernels/registry.py

Lines changed: 129 additions & 14 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from abc import ABC, abstractmethod
+from abc import ABC, ABCMeta, abstractmethod
 from typing import Any, Callable, Optional
 
 from ....extras.types import HFModel
@@ -61,18 +61,67 @@ def get_kernel(self, kernel_type: KernelType, device_type: DeviceType) -> Option
 KERNEL_REGISTRY = KernelRegistry()
 
 
-class MetaKernel(ABC):
+class AutoRegisterKernelMeta(ABCMeta):
+    """Metaclass that automatically registers kernel classes upon creation.
+
+    This metaclass checks if a newly created class has both `type` and `device`
+    attributes defined. If so, it automatically registers the kernel in the
+    global KERNEL_REGISTRY, eliminating the need for manual registration.
+
+    To disable auto-registration for a specific class, set `auto_register = False`.
+    """
+
+    def __new__(mcs, name, bases, namespace, **kwargs):
+        cls = super().__new__(mcs, name, bases, namespace, **kwargs)
+
+        # Check if auto-registration is disabled
+        auto_register = namespace.get("auto_register", True)
+
+        # Only auto-register if the class has both type and device attributes defined
+        # and they are not None (skip base classes like MetaKernel itself)
+        # and auto_register is True
+        kernel_type = namespace.get("type")
+        device_type = namespace.get("device")
+
+        if auto_register and kernel_type is not None and device_type is not None:
+            # Auto-register this kernel
+            KERNEL_REGISTRY.register(kernel_type, device_type, cls)
+
+        return cls
+
+
+class MetaKernel(ABC, metaclass=AutoRegisterKernelMeta):
+    """Base class for all kernel implementations.
+
+    Subclasses are automatically registered when they define both `type` and `device`
+    attributes. To disable auto-registration, set `auto_register = False`.
+
+    Attributes:
+        type: The kernel type (e.g., KernelType.RMSNORM). Must be set in subclasses.
+        device: The device type (e.g., DeviceType.NPU). Must be set in subclasses.
+        kernel: The actual kernel function or implementation.
+        auto_register: Set to False to disable automatic registration (default: True).
+    """
+
     type: Optional[KernelType] = None
     device: Optional[DeviceType] = None
     kernel: Optional[Callable] = None
 
-    @classmethod
-    def register_kernel(cls, kernel_type: KernelType, device_type: DeviceType):
-        KERNEL_REGISTRY.register(kernel_type, device_type, cls)
-
     @classmethod
     @abstractmethod
     def apply(cls, model: HFModel, **kwargs) -> HFModel:
+        """Apply the kernel to the model.
+
+        This method should check if the kernel can be applied (e.g., dependencies
+        are installed, target modules exist) and perform the kernel replacement.
+
+        Args:
+            model: The HuggingFace model to optimize.
+            **kwargs: Additional arguments for kernel application.
+
+        Returns:
+            The optimized model (may be the same object with modifications).
+        """
         raise NotImplementedError
 
 
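For readers unfamiliar with metaclass hooks, here is a condensed, standalone sketch of the auto-registration flow. Plain strings and a dict stand in for `KernelType`, `DeviceType`, and `KernelRegistry`, so this is an illustration of the mechanism rather than the project's actual API:

```python
from abc import ABC, ABCMeta, abstractmethod

REGISTRY: dict[tuple[str, str], type] = {}


class AutoRegisterKernelMeta(ABCMeta):
    def __new__(mcs, name, bases, namespace, **kwargs):
        cls = super().__new__(mcs, name, bases, namespace, **kwargs)
        kernel_type = namespace.get("type")
        device_type = namespace.get("device")
        if namespace.get("auto_register", True) and kernel_type and device_type:
            REGISTRY[(kernel_type, device_type)] = cls  # registered at class-creation time
        return cls


class MetaKernel(ABC, metaclass=AutoRegisterKernelMeta):
    type = None
    device = None

    @classmethod
    @abstractmethod
    def apply(cls, model, **kwargs): ...


class MyRMSNormKernel(MetaKernel):  # picked up automatically
    type = "rmsnorm"
    device = "npu"

    @classmethod
    def apply(cls, model, **kwargs):
        return model


class ManualRoPEKernel(MetaKernel):  # opts out, must be applied explicitly
    type = "rope"
    device = "npu"
    auto_register = False

    @classmethod
    def apply(cls, model, **kwargs):
        return model


print(REGISTRY)  # {('rmsnorm', 'npu'): <class '__main__.MyRMSNormKernel'>}
```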
@@ -106,16 +155,75 @@ def apply(cls, model: HFModel, **kwargs) -> HFModel:
         raise NotImplementedError
 
 
-def discover_kernels(model: HFModel) -> list[MetaKernel]:
-    """Discover and construct MetaKernel instances for the current model/device.
+def _ensure_kernels_loaded() -> None:
+    """Ensure all kernel implementations are imported and registered.
 
-    This is a placeholder to be implemented: it should inspect the runtime
-    environment (device type, available extensions, model architecture) and
-    return an ordered list of MetaKernel instances to be applied. Each returned
-    MetaKernel must encapsulate its own replacement logic in `apply`.
+    This function dynamically imports all kernel implementation modules to trigger
+    their auto-registration. Python's module system ensures each module is only
+    executed once (cached in sys.modules), so repeated calls are safe and fast.
+    """
+    # List of kernel module paths to import
+    kernel_modules = [
+        "rms_norm.npu_rms_norm",
+        "rope.npu_rope",
+        "mlp.npu_swiglu",
+        "mlp.npu_fused_moe",
+        # Add new kernel modules here as they are created
+    ]
+
+    # Import each module to trigger kernel registration
+    # Python's import system caches modules, so this is fast on subsequent calls
+    for module_name in kernel_modules:
+        try:
+            __import__(f"{__package__}.{module_name}", fromlist=["*"])
+        except ImportError:
+            # Silently ignore import errors (e.g., missing dependencies like torch_npu)
+            pass
+
+
+def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]:
+    """Discover and return all kernel classes registered for the current device.
+
+    This function inspects the runtime environment (device type) and returns
+    all MetaKernel classes registered for that device. Each kernel's `apply()`
+    method is responsible for checking if it can actually be applied (e.g.,
+    required dependencies are installed, target modules exist in the model).
+
+    The function automatically discovers all kernels registered in KERNEL_REGISTRY
+    without requiring manual enumeration. On first call, it dynamically imports
+    all kernel implementation modules to trigger their auto-registration.
+
+    Args:
+        model: The HuggingFace model to apply kernels to.
+            TODO: implement the kernel route detection logic by model structure.
+
+    Returns:
+        A list of MetaKernel classes available for the current device.
     """
-    # TODO: Implement auto discovery logic based on registry and device capabilities.
-    return []
+    # Ensure all kernel modules are imported to trigger registration
+    _ensure_kernels_loaded()
+
+    discovered_kernels: list[type[MetaKernel]] = []
+
+    # Detect current device type
+    accelerator = get_available_accelerator()
+    try:
+        device_type = DeviceType(accelerator.type)
+    except ValueError:
+        # Unknown device type, return empty list
+        return discovered_kernels
+
+    # Skip CPU as it typically doesn't have optimized kernels
+    if device_type == DeviceType.CPU:
+        return discovered_kernels
+
+    # Iterate through registry and collect all kernels for current device
+    for kernel_type, devices in KERNEL_REGISTRY._registry.items():
+        kernel_cls = devices.get(device_type)
+        if kernel_cls is not None:
+            discovered_kernels.append(kernel_cls)
+
+    return discovered_kernels
 
 
 def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFModel":
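The loader above leans on Python's import cache: a second import of the same module is just a `sys.modules` lookup, and kernels whose optional dependencies are missing are silently skipped. A small standalone sketch of that idea (using `importlib.import_module` and the stdlib `collections.abc` module purely for illustration):

```python
import importlib
import sys


def ensure_loaded(package: str, module_names: list[str]) -> None:
    """Import each submodule once; later calls hit the sys.modules cache."""
    for name in module_names:
        try:
            importlib.import_module(f"{package}.{name}")
        except ImportError:
            # e.g. torch_npu not installed on this machine -> skip that kernel
            pass


ensure_loaded("collections", ["abc"])    # first call performs the import
assert "collections.abc" in sys.modules  # cached from now on
ensure_loaded("collections", ["abc"])    # effectively a dict lookup
```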
@@ -136,3 +244,10 @@ def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFMo
         raise ValueError(
             f"{kernel} must be a MetaKernel instance, or the kernel don't match the device type. got {kernel.device} and {get_available_accelerator().type} instead."
         )
+
+
+def apply_available_kernels(model: HFModel, **kwargs) -> "HFModel":
+    """Apply all available kernels to the model."""
+    for kernel in discover_kernels(model):
+        model = apply_kernel(model, kernel, **kwargs)
+    return model
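Taken together, callers only need the new entry point. A hedged usage sketch, mirroring the new test below and assuming an Ascend NPU environment with `torch_npu` installed (on unsupported devices `discover_kernels()` returns an empty list and the model comes back unchanged):

```python
from transformers import AutoModelForCausalLM

from llamafactory.v1.plugins.model_plugins.kernels.registry import (
    apply_available_kernels,
    discover_kernels,
)

model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")

# Inspect what was auto-discovered for the current accelerator ...
print([kernel.__name__ for kernel in discover_kernels(model)])

# ... or apply everything that is available in one call.
model = apply_available_kernels(model)
```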

src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/npu_rms_norm.py

Lines changed: 2 additions & 6 deletions
@@ -17,7 +17,7 @@
 from .....extras.types import HFModel
 from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
 from ..constants import DeviceType, KernelType
-from ..registry import KERNEL_REGISTRY, MetaRMSNormKernel
+from ..registry import MetaRMSNormKernel
 
 
 def _npu_rms_forward(self, hidden_states):
@@ -38,14 +38,10 @@ def _npu_rms_forward(self, hidden_states):
 class NpuRMSNormKernel(MetaRMSNormKernel):
     """NPU kernel wrapper for RMSNorm that applies the replacement within a model."""
 
+    type = KernelType.RMSNORM
     device = DeviceType.NPU
     kernel = _npu_rms_forward
 
-    @classmethod
-    def register_kernel(cls, kernel_type=KernelType.RMSNORM, device_type=DeviceType.NPU):
-        """Register the NPU RMSNorm forward implementation to the global registry."""
-        KERNEL_REGISTRY.register(kernel_type, device_type, cls)
-
     @classmethod
     def apply(cls, model, **kwargs) -> HFModel:
         """Iterate the model and apply NPU-optimized forward to matched RMSNorm modules.

src/llamafactory/v1/plugins/model_plugins/kernels/rope/npu_rope.py

Lines changed: 10 additions & 9 deletions
@@ -19,7 +19,7 @@
 from .....extras.types import HFModel
 from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
 from ..constants import DeviceType, KernelType
-from ..registry import KERNEL_REGISTRY, MetaRoPEKernel
+from ..registry import MetaRoPEKernel
 
 
 def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
@@ -51,13 +51,10 @@ def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, un
 
 
 class NpuRoPEKernel(MetaRoPEKernel):
+    type = KernelType.ROPE
     device = DeviceType.NPU
     kernel = _apply_rotary_pos_emb
 
-    @classmethod
-    def register_kernel(cls, kernel_type=KernelType.ROPE, device_type=DeviceType.NPU):
-        KERNEL_REGISTRY.register(kernel_type, device_type, cls)
-
     @classmethod
     def apply(cls, model, **kwargs) -> "HFModel":
         """Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
@@ -88,12 +85,16 @@ def apply(cls, model, **kwargs) -> "HFModel":
 
 
 class NpuQwen2VLRoPEKernel(MetaRoPEKernel):
+    """Qwen2-VL specific RoPE kernel - not auto-registered.
+
+    This kernel is for specific models (Qwen2-VL) and should be manually
+    applied when needed rather than auto-discovered.
+    """
+
+    type = KernelType.ROPE
     device = DeviceType.NPU
     kernel = _apply_multimodal_rotary_pos_emb_qwen25_vl
-
-    @classmethod
-    def register_kernel(cls, kernel_type=KernelType.ROPE, device_type=DeviceType.NPU):
-        KERNEL_REGISTRY.register(kernel_type, device_type, cls)
+    auto_register = False  # Disable auto-registration for model-specific kernel
 
     @classmethod
     def apply(cls, model, **kwargs) -> "HFModel":
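Because `auto_register = False`, this kernel never shows up in `discover_kernels()`; a caller that knows it is serving a Qwen2-VL model opts in explicitly via `apply_kernel`. A sketch under that assumption (Ascend NPU with `torch_npu`; the model id is illustrative):

```python
from transformers import Qwen2VLForConditionalGeneration

from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_kernel
from llamafactory.v1.plugins.model_plugins.kernels.rope.npu_rope import NpuQwen2VLRoPEKernel

# apply_kernel validates that the kernel's device matches the current accelerator,
# so this is expected to run on an NPU host.
model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
model = apply_kernel(model, NpuQwen2VLRoPEKernel)
```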

tests_v1/plugins/model_plugins/test_kernel_plugin.py

Lines changed: 20 additions & 0 deletions
@@ -42,3 +42,23 @@ def test_apply_kernel(self, mock_get_accelerator):
 
         model = apply_kernel(model, npu_swiglu.NpuSwiGluKernel)
         assert model.model.layers[0].mlp.forward is not original_swiglu_forward
+
+
+class Test_Use_V1_Kernels(unittest.TestCase):
+    @patch("torch.accelerator.current_accelerator")
+    def test_use_v1_kernels(self, mock_get_accelerator):
+        mock_device = MagicMock()
+        mock_device.type = "npu"
+        mock_get_accelerator.return_value = mock_device
+
+        model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
+
+        original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
+        original_swiglu_forward = model.model.layers[0].mlp.forward
+
+        from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_available_kernels
+
+        model = apply_available_kernels(model)
+
+        assert model.model.layers[0].input_layernorm.forward is not original_rmsnorm_forward
+        assert model.model.layers[0].mlp.forward is not original_swiglu_forward
