@@ -248,6 +248,40 @@ def _check_accelerate_version():
248248 )
249249
250250
_MXFP4_SUPPORTED_MODEL_TYPES = {"gpt_oss"}


def _is_mxfp4_model(model_path, trust_remote_code=True):
    """Check if a model is an MXFP4 quantized model supported for direct loading.

    Only checks when transformers >= 5.0.0. Returns False immediately for older
    versions, adding zero overhead to non-MXFP4 model loading.

    Args:
        model_path: Model name or local path accepted by ``AutoConfig.from_pretrained``.
        trust_remote_code: Forwarded to ``AutoConfig.from_pretrained``.

    Returns:
        bool: True iff the config's ``model_type`` is one of the supported MXFP4
        architectures and its ``quantization_config`` declares ``quant_method == "mxfp4"``.
    """
    if version.parse(transformers.__version__) < version.parse("5.0.0"):
        return False
    from transformers import AutoConfig

    try:  # config loading may fail for brand-new models not yet known to AutoConfig
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
    except Exception:  # narrowed from bare `except:` so SystemExit/KeyboardInterrupt propagate
        return False

    if getattr(config, "model_type", "") not in _MXFP4_SUPPORTED_MODEL_TYPES:
        return False

    quant_config = getattr(config, "quantization_config", None)
    if quant_config is None:
        return False

    # quantization_config can be a plain dict (parsed JSON) or a config object,
    # depending on how the checkpoint was saved — handle both.
    if isinstance(quant_config, dict):
        quant_method = quant_config.get("quant_method", "")
    else:
        quant_method = getattr(quant_config, "quant_method", "")
    # model_type membership was already verified above; no need to re-check it here.
    return quant_method == "mxfp4"
284+
251285def llm_load_model (
252286 pretrained_model_name_or_path : str ,
253287 platform : str = "hf" ,
@@ -284,6 +318,18 @@ def llm_load_model(
284318 if device_str is not None and "hpu" in device_str :
285319 torch_dtype = torch .bfloat16
286320
321+ is_mxfp4 = _is_mxfp4_model (pretrained_model_name_or_path , trust_remote_code = trust_remote_code )
322+ load_kwargs = {
323+ "torch_dtype" : torch_dtype ,
324+ "trust_remote_code" : trust_remote_code ,
325+ "device_map" : "auto" if use_auto_mapping else None ,
326+ }
327+ if is_mxfp4 :
328+ from transformers import Mxfp4Config
329+
330+ load_kwargs ["quantization_config" ] = Mxfp4Config (dequantized = True )
331+ logger .info ("Detected MXFP4 quantized model, using Mxfp4Config(dequantized=True) for loading." )
332+
287333 is_glm = bool (re .search ("chatglm" , pretrained_model_name_or_path .lower ()))
288334
289335 tokenizer = AutoTokenizer .from_pretrained (pretrained_model_name_or_path , trust_remote_code = trust_remote_code )
@@ -295,40 +341,22 @@ def llm_load_model(
295341 if is_hpex_available ():
296342 # For loading FP8 model on HPU
297343 with fake_cuda_for_hpu (), fake_triton_for_hpu (), override_cuda_device_capability ():
298- model = model_cls .from_pretrained (
299- pretrained_model_name_or_path ,
300- torch_dtype = torch_dtype ,
301- trust_remote_code = trust_remote_code ,
302- device_map = "auto" if use_auto_mapping else None ,
303- )
344+ model = model_cls .from_pretrained (pretrained_model_name_or_path , ** load_kwargs )
304345 else :
305346 try :
306- model = model_cls .from_pretrained (
307- pretrained_model_name_or_path ,
308- torch_dtype = torch_dtype ,
309- trust_remote_code = trust_remote_code ,
310- device_map = "auto" if use_auto_mapping else None ,
311- )
347+ model = model_cls .from_pretrained (pretrained_model_name_or_path , ** load_kwargs )
312348 except ValueError as e :
313349 if "FP8 quantized" in str (e ):
314350 with override_cuda_device_capability ():
315- model = model_cls .from_pretrained (
316- pretrained_model_name_or_path ,
317- torch_dtype = torch_dtype ,
318- trust_remote_code = trust_remote_code ,
319- device_map = "auto" if use_auto_mapping else None ,
320- )
351+ model = model_cls .from_pretrained (pretrained_model_name_or_path , ** load_kwargs )
321352 logger .warning ("the support for fp8 model as input is experimental, please use with caution." )
322353 else :
323354 raise
324355
325356 except OSError as e :
326357 logger .warning (f"fail to load { pretrained_model_name_or_path } , set trust_remote_code to False and retry." )
327358 model = model_cls .from_pretrained (
328- pretrained_model_name_or_path ,
329- torch_dtype = torch_dtype ,
330- trust_remote_code = False ,
331- device_map = "auto" if use_auto_mapping else None ,
359+ pretrained_model_name_or_path , ** {** load_kwargs , "trust_remote_code" : False }
332360 )
333361
334362 model = model .eval ()
0 commit comments