# limitations under the License.
"""Testing suite for the PyTorch Qwen3MoE model."""

-import gc
import unittest

import pytest

from transformers import AutoTokenizer, Qwen3MoeConfig, is_torch_available, set_seed
from transformers.testing_utils import (
-    backend_empty_cache,
+    cleanup,
    require_bitsandbytes,
    require_flash_attn,
    require_torch,
    require_torch_gpu,
+    require_torch_large_accelerator,
+    require_torch_multi_accelerator,
    require_torch_sdpa,
    slow,
    torch_device,
@@ -143,45 +144,61 @@ def test_load_balancing_loss(self):
        self.assertNotAlmostEqual(include_padding_result.aux_loss.item(), result.aux_loss.item())


+# Run on runners with larger accelerators (for example A10 instead of T4) and a lot of CPU RAM (e.g. g5-12xlarge)
+@require_torch_multi_accelerator
+@require_torch_large_accelerator
@require_torch
class Qwen3MoeIntegrationTest(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = None
+
+    @classmethod
+    def tearDownClass(cls):
+        del cls.model
+        cleanup(torch_device, gc_collect=True)
+
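+    # `cleanup` empties the accelerator cache (and runs the garbage collector
+    # when gc_collect=True), replacing the manual del / backend_empty_cache /
+    # gc.collect teardown removed further down in this diff.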
+    def tearDown(self):
+        cleanup(torch_device, gc_collect=True)
+
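+    # Load the checkpoint lazily and cache it at the class level: reloading a
+    # 30B-parameter model for every test would dominate runtime, and 4-bit
+    # quantization keeps it within the memory of the large-accelerator runners.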
+    @classmethod
+    def get_model(cls):
+        if cls.model is None:
+            cls.model = Qwen3MoeForCausalLM.from_pretrained(
+                "Qwen/Qwen3-30B-A3B-Base", device_map="auto", load_in_4bit=True
+            )
+
+        return cls.model
+
    @slow
    def test_model_15b_a2b_logits(self):
        input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
-        model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-15B-A2B-Base", device_map="auto")
+        model = self.get_model()
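+        # With device_map="auto" the weights may be sharded across devices, so
+        # the inputs are placed on the same device as the embedding layer.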
        input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
        with torch.no_grad():
            out = model(input_ids).logits.float().cpu()
+
        # Expected mean on dim = -1
-        EXPECTED_MEAN = torch.tensor([[-1.1184, 1.1356, 9.2112, 8.0254, 5.1663, 7.9287, 8.9245, 10.0671]])
+        EXPECTED_MEAN = torch.tensor([[0.3244, 0.4406, 9.0972, 7.3597, 4.9985, 8.0314, 8.2148, 9.2134]])
        torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, rtol=1e-2, atol=1e-2)
+
        # slicing logits[0, 0, 0:30]
-        EXPECTED_SLICE = torch.tensor([7.5938, 2.6094, 4.0312, 4.0938, 2.5156, 2.7812, 2.9688, 1.5547, 1.3984, 2.2344, 3.0156, 3.1562, 1.1953, 3.2500, 1.0938, 8.4375, 9.5625, 9.0625, 7.5625, 7.5625, 7.9062, 7.2188, 7.0312, 6.9375, 8.0625, 1.7266, 0.9141, 3.7969, 5.3438, 3.9844])  # fmt: skip
+        EXPECTED_SLICE = torch.tensor([6.8984, 4.8633, 4.7734, 4.5898, 2.5664, 2.9902, 4.8828, 5.9414, 4.6250, 3.0840, 5.1602, 6.0117, 4.9453, 5.3008, 3.3145, 11.3906, 12.8359, 12.4844, 11.2891, 11.0547, 11.0391, 10.3359, 10.3438, 10.2578, 10.7969, 5.9688, 3.7676, 5.5938, 5.3633, 5.8203])  # fmt: skip
        torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, rtol=1e-4, atol=1e-4)

-        del model
-        backend_empty_cache(torch_device)
-        gc.collect()
-
    @slow
    def test_model_15b_a2b_generation(self):
-        EXPECTED_TEXT_COMPLETION = (
-            """To be or not to be, that is the question. Whether 'tis nobler in the mind to suffer the sl"""
-        )
+        EXPECTED_TEXT_COMPLETION = "To be or not to be: the role of the cell cycle in the regulation of apoptosis.\n The cell cycle is a highly"
        prompt = "To be or not to"
-        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-15B-A2B-Base", use_fast=False)
-        model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-15B-A2B-Base", device_map="auto")
+        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-30B-A3B-Base", use_fast=False)
+        model = self.get_model()
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device)

        # greedy generation outputs
        generated_ids = model.generate(input_ids, max_new_tokens=20, temperature=0)
        text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        self.assertEqual(EXPECTED_TEXT_COMPLETION, text)

-        del model
-        backend_empty_cache(torch_device)
-        gc.collect()
-
    @require_bitsandbytes
    @slow
    @require_flash_attn
@@ -191,7 +208,7 @@ def test_model_15b_a2b_long_prompt(self):
        # An input with 4097 tokens that is above the size of the sliding window
        input_ids = [1] + [306, 338] * 2048
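+        # 1 leading token + 2 * 2048 repeated tokens = 4097 tokens in total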
        model = Qwen3MoeForCausalLM.from_pretrained(
-            "Qwen/Qwen3-15B-A2B-Base",
+            "Qwen/Qwen3-30B-A3B-Base",
            device_map="auto",
            load_in_4bit=True,
            attn_implementation="flash_attention_2",
@@ -200,50 +217,20 @@ def test_model_15b_a2b_long_prompt(self):
        generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
        self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())

-        # Assisted generation
-        assistant_model = model
-        assistant_model.generation_config.num_assistant_tokens = 2
-        assistant_model.generation_config.num_assistant_tokens_schedule = "constant"
-        generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
-        self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
-
-        del assistant_model
-        del model
-        backend_empty_cache(torch_device)
-        gc.collect()
-
    @slow
    @require_torch_sdpa
    def test_model_15b_a2b_long_prompt_sdpa(self):
        EXPECTED_OUTPUT_TOKEN_IDS = [306, 338]
        # An input with 4097 tokens that is above the size of the sliding window
        input_ids = [1] + [306, 338] * 2048
-        model = Qwen3MoeForCausalLM.from_pretrained(
-            "Qwen/Qwen3-15B-A2B-Base",
-            device_map="auto",
-            attn_implementation="sdpa",
-        )
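+        # The shared 4-bit model relies on the library's default attention
+        # backend (SDPA when available), so the explicit
+        # attn_implementation="sdpa" load is no longer needed here.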
+        model = self.get_model()
        input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
        generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
        self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())

-        # Assisted generation
-        assistant_model = model
-        assistant_model.generation_config.num_assistant_tokens = 2
-        assistant_model.generation_config.num_assistant_tokens_schedule = "constant"
-        generated_ids = assistant_model.generate(input_ids, max_new_tokens=4, temperature=0)
-        self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
-
-        del assistant_model
-
-        backend_empty_cache(torch_device)
-        gc.collect()
-
-        EXPECTED_TEXT_COMPLETION = (
-            """To be or not to be, that is the question. Whether 'tis nobler in the mind to suffer the sl"""
-        )
+        EXPECTED_TEXT_COMPLETION = "To be or not to be: the role of the cell cycle in the regulation of apoptosis.\n The cell cycle is a highly"
        prompt = "To be or not to"
-        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-15B-A2B-Base", use_fast=False)
+        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-30B-A3B-Base", use_fast=False)

        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device)

@@ -255,16 +242,12 @@ def test_model_15b_a2b_long_prompt_sdpa(self):
    @slow
    def test_speculative_generation(self):
        EXPECTED_TEXT_COMPLETION = (
-            "To be or not to be, that is the question: whether 'tis nobler in the mind to suffer the sl"
+            "To be or not to be: the role of the liver in the pathogenesis of obesity and type 2 diabetes.\n The "
        )
        prompt = "To be or not to"
-        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-15B-A2B-Base", use_fast=False)
-        model = Qwen3MoeForCausalLM.from_pretrained(
-            "Qwen/Qwen3-15B-A2B-Base", device_map="auto", torch_dtype=torch.float16
-        )
-        assistant_model = Qwen3MoeForCausalLM.from_pretrained(
-            "Qwen/Qwen3-15B-A2B-Base", device_map="auto", torch_dtype=torch.float16
-        )
+        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-30B-A3B-Base", use_fast=False)
+        model = self.get_model()
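+        # Reusing the same model as its own assistant still exercises the
+        # assisted-decoding path without keeping a second checkpoint in memory.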
+        assistant_model = model
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device)

        # greedy generation outputs
@@ -274,7 +257,3 @@ def test_speculative_generation(self):
        )
        text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
-
-        del model
-        backend_empty_cache(torch_device)
-        gc.collect()