1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- from abc import ABC , abstractmethod
15+ from abc import ABC , ABCMeta , abstractmethod
1616from typing import Any , Callable , Optional
1717
1818from ....extras .types import HFModel
@@ -61,18 +61,67 @@ def get_kernel(self, kernel_type: KernelType, device_type: DeviceType) -> Option
6161KERNEL_REGISTRY = KernelRegistry ()
6262
6363
64- class MetaKernel (ABC ):
class AutoRegisterKernelMeta(ABCMeta):
    """Metaclass that registers kernel classes in ``KERNEL_REGISTRY`` on creation.

    A newly created class is registered only when its own class body sets both
    ``type`` and ``device`` to non-``None`` values. Classes that leave either
    unset (such as an abstract base) are created normally but not registered.

    A class opts out by setting ``auto_register = False`` in its body. All
    three values are read from the class body only, so values inherited from
    a parent class are not taken into account.
    """

    def __new__(mcs, name, bases, namespace, **kwargs):
        cls = super().__new__(mcs, name, bases, namespace, **kwargs)

        # Look only at what this class body declared, never at inherited values.
        enabled = namespace.get("auto_register", True)
        declared_type = namespace.get("type")
        declared_device = namespace.get("device")

        # Opted out, or not a concrete kernel (type/device missing): nothing to register.
        if not enabled or declared_type is None or declared_device is None:
            return cls

        KERNEL_REGISTRY.register(declared_type, declared_device, cls)
        return cls
91+
92+
class MetaKernel(ABC, metaclass=AutoRegisterKernelMeta):
    """Base class for all kernel implementations.

    Subclasses are automatically registered when they define both `type` and `device`
    attributes. To disable auto-registration, set `auto_register = False`.

    Attributes:
        type: The kernel type (e.g., KernelType.RMSNORM). Must be set in subclasses.
        device: The device type (e.g., DeviceType.NPU). Must be set in subclasses.
        kernel: The actual kernel function or implementation.
        auto_register: Set to False to disable automatic registration (default: True).
    """

    type: Optional[KernelType] = None
    device: Optional[DeviceType] = None
    kernel: Optional[Callable] = None
    # Declared here so the documented default actually exists on the class.
    # The base itself is still not registered because `type`/`device` are None.
    # The metaclass reads this flag from each class body, so a subclass must
    # set it itself to opt out.
    auto_register: bool = True

    @classmethod
    @abstractmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        """Apply the kernel to the model.

        This method should check if the kernel can be applied (e.g., dependencies
        are installed, target modules exist) and perform the kernel replacement.

        Args:
            model: The HuggingFace model to optimize.
            **kwargs: Additional arguments for kernel application.

        Returns:
            The optimized model (may be the same object with modifications).

        Raises:
            NotImplementedError: Always, in this base implementation; subclasses
                must override it.
        """
        raise NotImplementedError
77126
78127
@@ -106,16 +155,75 @@ def apply(cls, model: HFModel, **kwargs) -> HFModel:
106155 raise NotImplementedError
107156
108157
109- def discover_kernels ( model : HFModel ) -> list [ MetaKernel ] :
110- """Discover and construct MetaKernel instances for the current model/device .
158+ def _ensure_kernels_loaded ( ) -> None :
159+ """Ensure all kernel implementations are imported and registered .
111160
112- This is a placeholder to be implemented: it should inspect the runtime
113- environment (device type, available extensions, model architecture) and
114- return an ordered list of MetaKernel instances to be applied. Each returned
115- MetaKernel must encapsulate its own replacement logic in `apply`.
161+ This function dynamically imports all kernel implementation modules to trigger
162+ their auto-registration. Python's module system ensures each module is only
163+ executed once (cached in sys.modules), so repeated calls are safe and fast.
164+ """
165+ # List of kernel module paths to import
166+ kernel_modules = [
167+ "rms_norm.npu_rms_norm" ,
168+ "rope.npu_rope" ,
169+ "mlp.npu_swiglu" ,
170+ "mlp.npu_fused_moe" ,
171+ # Add new kernel modules here as they are created
172+ ]
173+
174+ # Import each module to trigger kernel registration
175+ # Python's import system caches modules, so this is fast on subsequent calls
176+ for module_name in kernel_modules :
177+ try :
178+ __import__ (f"{ __package__ } .{ module_name } " , fromlist = ["*" ])
179+ except ImportError :
180+ # Silently ignore import errors (e.g., missing dependencies like torch_npu)
181+ pass
182+
183+
def discover_kernels(model: Optional[HFModel] = None) -> list[type[MetaKernel]]:
    """Discover and return all kernel classes registered for the current device.

    The kernel implementation modules are imported first so that they are
    registered, then the current accelerator type is mapped to a ``DeviceType``
    and every kernel class registered for that device is collected from
    ``KERNEL_REGISTRY``.

    Each kernel's `apply()` method is responsible for checking whether it can
    actually be applied (e.g., required dependencies are installed, target
    modules exist in the model).

    Args:
        model: The HuggingFace model to apply kernels to. Currently unused.
            TODO: implement the kernel route detection logic by model structure.

    Returns:
        A list of MetaKernel classes available for the current device. The list
        is empty when the accelerator type is not a known ``DeviceType`` or
        when the device is the CPU.
    """
    # Ensure all kernel modules are imported to trigger registration
    _ensure_kernels_loaded()

    # Detect current device type
    accelerator = get_available_accelerator()
    try:
        device_type = DeviceType(accelerator.type)
    except ValueError:
        # Unknown device type: nothing can be registered for it
        return []

    # Skip CPU as it typically doesn't have optimized kernels
    if device_type == DeviceType.CPU:
        return []

    # Only the per-device mappings matter here; the kernel-type keys are not needed.
    candidates = (devices.get(device_type) for devices in KERNEL_REGISTRY._registry.values())
    return [kernel_cls for kernel_cls in candidates if kernel_cls is not None]
119227
120228
121229def apply_kernel (model : HFModel , kernel : type [MetaKernel ], / , ** kwargs ) -> "HFModel" :
@@ -136,3 +244,10 @@ def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFMo
136244 raise ValueError (
137245 f"{ kernel } must be a MetaKernel instance, or the kernel don't match the device type. got { kernel .device } and { get_available_accelerator ().type } instead."
138246 )
247+
248+
def apply_available_kernels(model: HFModel, **kwargs) -> "HFModel":
    """Apply every kernel discovered for the current device to the model.

    Kernels are applied in the order returned by ``discover_kernels``; each one
    receives the model returned by the previous application.

    Args:
        model: The HuggingFace model to optimize.
        **kwargs: Extra arguments forwarded to every ``apply_kernel`` call.

    Returns:
        The model after all discovered kernels have been applied.
    """
    patched = model
    for kernel_cls in discover_kernels(model):
        patched = apply_kernel(patched, kernel_cls, **kwargs)
    return patched
0 commit comments