Skip to content

Commit e14e7db

Browse files
support direct loading of gpt-oss MXFP4 models (#1401)
Signed-off-by: Xin He <xin3.he@intel.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent eec1bed commit e14e7db

File tree

6 files changed

+78
-34
lines changed

6 files changed

+78
-34
lines changed

auto_round/modeling/fused_moe/moe_experts_interface.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,9 @@ def _unfuse_single_projection(
458458
return linears
459459

460460

461+
_logged_memory_before_replacement = False # To ensure we only log memory once before replacements
462+
463+
461464
def _unfuse_experts_weights_inplace(
462465
module: nn.Module,
463466
check_decorator: bool = True,
@@ -494,8 +497,12 @@ def _unfuse_experts_weights_inplace(
494497
logger.debug(f"Skipping unfuse for {module.__class__.__name__}: does not support @use_experts_implementation")
495498
return False
496499

497-
memory_monitor.update()
498-
memory_monitor.log_summary("Before applying custom replacements")
500+
global _logged_memory_before_replacement
501+
if not _logged_memory_before_replacement:
502+
_logged_memory_before_replacement = True
503+
memory_monitor.update()
504+
memory_monitor.log_summary("Before applying custom replacements")
505+
499506
# Get first projection to determine num_experts and layout
500507
first_proj_name = next(iter(detected_projections))
501508
first_param = getattr(module, first_proj_name)
@@ -594,9 +601,6 @@ def _unfuse_experts_weights_inplace(
594601
# Install compact repr to collapse identical expert containers in print output
595602
_install_compact_expert_repr(module)
596603

597-
memory_monitor.update()
598-
memory_monitor.log_summary("After applying custom replacements")
599-
600604
return True
601605

602606

@@ -645,6 +649,9 @@ def prepare_model_for_moe_quantization(model: nn.Module, implementation: str = L
645649
model.config._experts_implementation = impl_to_set
646650
logger.debug(f"Set model.config._experts_implementation = '{impl_to_set}'")
647651

652+
memory_monitor.update()
653+
memory_monitor.log_summary("After applying custom replacements")
654+
648655
return unfused_modules
649656

650657

auto_round/modeling/fused_moe/replace_modules.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,8 @@ def _log_first_moe_block(model: torch.nn.Module, label: str) -> None:
101101
for name, module in model.named_modules():
102102
if name.endswith(".experts"):
103103
logger.info(f"Experts ({label}) [{name}] ({module.__class__.__name__}):\n{module}")
104-
return
104+
return True
105+
return False
105106

106107

107108
@dump_mem_usage("Materializing model", log_level="debug")
@@ -289,15 +290,24 @@ def apply_replacements(
289290
The model with modules replaced.
290291
"""
291292
_import_required_replacements(model)
293+
_raw_expert_is_logged = False
292294

293295
# Custom replacements first
294296
if is_custom_model(model):
295-
_log_first_moe_block(model, "before replacement")
297+
298+
if not _raw_expert_is_logged:
299+
_raw_expert_is_logged = _log_first_moe_block(model, "before replacement")
300+
296301
_apply_custom_replacements(model)
297-
_log_first_moe_block(model, "after replacement")
302+
298303
if auto_detect_moe and is_transformers_version_greater_or_equal_5():
299-
_log_first_moe_block(model, "before replacement")
304+
305+
if not _raw_expert_is_logged:
306+
_raw_expert_is_logged = _log_first_moe_block(model, "before replacement")
307+
300308
_handle_moe_modules(model)
309+
310+
if _raw_expert_is_logged:
301311
_log_first_moe_block(model, "after replacement")
302312

303313
return model

auto_round/special_model_handler.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,6 @@ def get_predefined_ignore_layers(model: torch.nn.Module) -> list[str]:
274274
if not layers and is_moe_model_via_config(config):
275275
for name, _ in model.named_modules():
276276
if name.endswith(".gate"):
277-
print(name)
278277
layers.append(name)
279278

280279
return list(dict.fromkeys(layers))

auto_round/utils/model.py

Lines changed: 50 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,40 @@ def _check_accelerate_version():
248248
)
249249

250250

251+
_MXFP4_SUPPORTED_MODEL_TYPES = {"gpt_oss"}
252+
253+
254+
def _is_mxfp4_model(model_path, trust_remote_code=True):
255+
"""Check if a model is an MXFP4 quantized model supported for direct loading.
256+
257+
Only checks when transformers >= 5.0.0. Returns False immediately for older versions,
258+
adding zero overhead to non-MXFP4 model loading.
259+
"""
260+
if version.parse(transformers.__version__) < version.parse("5.0.0"):
261+
return False
262+
from transformers import AutoConfig
263+
264+
try: # in case of config loading failure for new models
265+
config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
266+
except:
267+
return False
268+
269+
model_type = getattr(config, "model_type", "")
270+
if model_type not in _MXFP4_SUPPORTED_MODEL_TYPES:
271+
return False
272+
273+
quant_config = getattr(config, "quantization_config", None)
274+
if quant_config is None:
275+
return False
276+
277+
quant_method = (
278+
quant_config.get("quant_method", "")
279+
if isinstance(quant_config, dict)
280+
else getattr(quant_config, "quant_method", "")
281+
)
282+
return quant_method == "mxfp4" and model_type in _MXFP4_SUPPORTED_MODEL_TYPES
283+
284+
251285
def llm_load_model(
252286
pretrained_model_name_or_path: str,
253287
platform: str = "hf",
@@ -284,6 +318,18 @@ def llm_load_model(
284318
if device_str is not None and "hpu" in device_str:
285319
torch_dtype = torch.bfloat16
286320

321+
is_mxfp4 = _is_mxfp4_model(pretrained_model_name_or_path, trust_remote_code=trust_remote_code)
322+
load_kwargs = {
323+
"torch_dtype": torch_dtype,
324+
"trust_remote_code": trust_remote_code,
325+
"device_map": "auto" if use_auto_mapping else None,
326+
}
327+
if is_mxfp4:
328+
from transformers import Mxfp4Config
329+
330+
load_kwargs["quantization_config"] = Mxfp4Config(dequantized=True)
331+
logger.info("Detected MXFP4 quantized model, using Mxfp4Config(dequantized=True) for loading.")
332+
287333
is_glm = bool(re.search("chatglm", pretrained_model_name_or_path.lower()))
288334

289335
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code)
@@ -295,40 +341,22 @@ def llm_load_model(
295341
if is_hpex_available():
296342
# For loading FP8 model on HPU
297343
with fake_cuda_for_hpu(), fake_triton_for_hpu(), override_cuda_device_capability():
298-
model = model_cls.from_pretrained(
299-
pretrained_model_name_or_path,
300-
torch_dtype=torch_dtype,
301-
trust_remote_code=trust_remote_code,
302-
device_map="auto" if use_auto_mapping else None,
303-
)
344+
model = model_cls.from_pretrained(pretrained_model_name_or_path, **load_kwargs)
304345
else:
305346
try:
306-
model = model_cls.from_pretrained(
307-
pretrained_model_name_or_path,
308-
torch_dtype=torch_dtype,
309-
trust_remote_code=trust_remote_code,
310-
device_map="auto" if use_auto_mapping else None,
311-
)
347+
model = model_cls.from_pretrained(pretrained_model_name_or_path, **load_kwargs)
312348
except ValueError as e:
313349
if "FP8 quantized" in str(e):
314350
with override_cuda_device_capability():
315-
model = model_cls.from_pretrained(
316-
pretrained_model_name_or_path,
317-
torch_dtype=torch_dtype,
318-
trust_remote_code=trust_remote_code,
319-
device_map="auto" if use_auto_mapping else None,
320-
)
351+
model = model_cls.from_pretrained(pretrained_model_name_or_path, **load_kwargs)
321352
logger.warning("the support for fp8 model as input is experimental, please use with caution.")
322353
else:
323354
raise
324355

325356
except OSError as e:
326357
logger.warning(f"fail to load {pretrained_model_name_or_path}, set trust_remote_code to False and retry.")
327358
model = model_cls.from_pretrained(
328-
pretrained_model_name_or_path,
329-
torch_dtype=torch_dtype,
330-
trust_remote_code=False,
331-
device_map="auto" if use_auto_mapping else None,
359+
pretrained_model_name_or_path, **{**load_kwargs, "trust_remote_code": False}
332360
)
333361

334362
model = model.eval()

test/test_cpu/models/test_moe_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from ...helpers import get_model_path, transformers_version
1212

13-
gpt_oss_name_or_path = get_model_path("unsloth/gpt-oss-20b-BF16")
13+
gpt_oss_name_or_path = get_model_path("openai/gpt-oss-20b")
1414
llama4_name_or_path = get_model_path("meta-llama/Llama-4-Scout-17B-16E-Instruct")
1515
qwen3_vl_moe_name_or_path = get_model_path("Qwen/Qwen3-VL-30B-A3B-Instruct")
1616
# local path for debug

test/test_cuda/models/test_moe_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
@pytest.fixture
1515
def setup_gpt_oss():
1616
"""Fixture to set up the GPT-OSS model and tokenizer."""
17-
model_name = "/models/gpt-oss-20b-BF16"
17+
model_name = "openai/gpt-oss-20b"
1818
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
1919
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
2020
config.num_hidden_layers = 1 # Reduce layers for testing

0 commit comments

Comments (0)