Commit 6e459ed

Add handling for Qwen3VLMoe on older transformers versions (#2040)
SUMMARY: #1981 added Qwen3VLMoe with associated tests; however, this model isn't available on all transformers versions that we support. Therefore (similar to #2030), this PR ensures we don't import or test the model when using a transformers version that doesn't support it.

TEST PLAN: Confirmed that this change fixes `import llmcompressor` when using the oldest supported transformers version, 4.54.0. Ran the test with an old transformers version (test gets skipped) and with a new transformers version (test passes).

Signed-off-by: Fynn Schmitt-Ulms <[email protected]>
1 parent 148efda commit 6e459ed
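
The fix works because `typing.TYPE_CHECKING` is `False` at runtime, so imports guarded by it never execute when `llmcompressor` is loaded, while static type checkers treat the constant as `True` and still resolve the annotations. A minimal sketch of the pattern (the helper `hidden_size_of` is illustrative, not code from this PR):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Never executed at runtime, so a transformers version without
    # Qwen3VLMoe cannot break this module's import. Type checkers
    # assume TYPE_CHECKING is True and follow the import normally.
    from transformers import Qwen3VLMoeConfig


def hidden_size_of(config: "Qwen3VLMoeConfig") -> int:
    # The quoted annotation is a forward reference stored as a string;
    # it is never evaluated when the function is defined or called.
    return config.get_text_config().hidden_size
```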

File tree

2 files changed: +23 −13 lines changed


src/llmcompressor/modeling/qwen3_vl_moe.py

Lines changed: 13 additions & 9 deletions
```diff
@@ -1,12 +1,16 @@
+from typing import TYPE_CHECKING
+
 import torch
-from transformers import Qwen3VLMoeConfig, Qwen3VLMoeTextConfig
-from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
-    Qwen3VLMoeTextSparseMoeBlock as OriginalQwen3VLMoeTextSparseMoeBlock,
-)
 
 from llmcompressor.modeling.moe_context import MoECalibrationModule
 from llmcompressor.utils.dev import skip_weights_initialize
 
+if TYPE_CHECKING:
+    from transformers import Qwen3VLMoeConfig, Qwen3VLMoeTextConfig
+    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
+        Qwen3VLMoeTextSparseMoeBlock,
+    )
+
 
 @MoECalibrationModule.register("Qwen3VLMoeTextSparseMoeBlock")
 class CalibrateQwen3VLMoeTextSparseMoeBlock(MoECalibrationModule):
@@ -19,12 +23,12 @@ class CalibrateQwen3VLMoeTextSparseMoeBlock(MoECalibrationModule):
 
     def __init__(
         self,
-        original: OriginalQwen3VLMoeTextSparseMoeBlock,
-        config: Qwen3VLMoeConfig,
+        original: "Qwen3VLMoeTextSparseMoeBlock",
+        config: "Qwen3VLMoeConfig",
         calibrate_all_experts: bool,
     ):
         super().__init__()
-        text_config: Qwen3VLMoeTextConfig = config.get_text_config()
+        text_config: "Qwen3VLMoeTextConfig" = config.get_text_config()
 
         self.hidden_size = text_config.hidden_size
         self.num_experts = text_config.num_experts
@@ -115,8 +119,8 @@ def __init__(self, config, original):
 
 
 def replace(
-    config: Qwen3VLMoeConfig,
-    original: OriginalQwen3VLMoeTextSparseMoeBlock,
+    config: "Qwen3VLMoeConfig",
+    original: "Qwen3VLMoeTextSparseMoeBlock",
     calibrate_all_experts: bool,
 ):
     return CalibrateQwen3VLMoeTextSparseMoeBlock(
```
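
With the transformers imports deferred behind `TYPE_CHECKING`, the module itself should now import on any supported transformers version. A hedged sketch of the kind of check described in the test plan, assuming only the module path and class name shown in the diff above:

```python
import importlib

# Should succeed even on transformers versions that lack Qwen3VLMoe,
# since the model imports above are never executed at runtime.
module = importlib.import_module("llmcompressor.modeling.qwen3_vl_moe")
assert hasattr(module, "CalibrateQwen3VLMoeTextSparseMoeBlock")
print("import ok")
```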

tests/llmcompressor/modeling/test_calib_qwen3_vl_moe.py

Lines changed: 10 additions & 4 deletions
```diff
@@ -1,13 +1,19 @@
+import pytest
 import torch
-from transformers import Qwen3VLMoeConfig
-from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
-    Qwen3VLMoeTextSparseMoeBlock,
-)
 
 from llmcompressor.modeling.qwen3_vl_moe import CalibrateQwen3VLMoeTextSparseMoeBlock
 from llmcompressor.utils.helpers import calibration_forward_context
 from tests.testing_utils import requires_gpu
 
+Qwen3VLMoeConfig = pytest.importorskip(
+    "transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe",
+    reason="Qwen3VLMoeConfig not available in this version of transformers",
+).Qwen3VLMoeConfig
+Qwen3VLMoeTextSparseMoeBlock = pytest.importorskip(
+    "transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe",
+    reason="Qwen3VLMoeTextSparseMoeBlock not available in this version of transformers",
+).Qwen3VLMoeTextSparseMoeBlock
+
 
 @requires_gpu
 def test_calib_qwen3_vl_moe_module():
```
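
`pytest.importorskip` imports the named module at collection time and skips the test file with the given reason if that import fails; the returned module object lets names be bound exactly as with a plain import, which is how this test is gated on transformers versions that ship Qwen3VLMoe. A self-contained sketch of the same pattern, using `numpy` as a stand-in dependency:

```python
import pytest

# If numpy is missing, every test in this file is skipped (not errored),
# with the reason shown in the pytest report.
np = pytest.importorskip("numpy", reason="numpy not installed")


def test_add():
    # Runs only when the guarded import above succeeded.
    assert np.add(2, 3) == 5
```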
