# limitations under the License.
"""Testing suite for the PyTorch Qwen3 model."""

-import gc
import unittest

import pytest

@@ -23,7 +22,7 @@
from transformers.generation.configuration_utils import GenerationConfig
from transformers.testing_utils import (
    Expectations,
-    backend_empty_cache,
+    cleanup,
    require_bitsandbytes,
    require_flash_attn,
    require_torch,
@@ -109,6 +108,12 @@ def test_flash_attn_2_inference_equivalence_right_padding(self):

@require_torch
class Qwen3IntegrationTest(unittest.TestCase):
+    def setUp(self):
+        cleanup(torch_device, gc_collect=True)
+
+    def tearDown(self):
+        cleanup(torch_device, gc_collect=True)
+
    @slow
    def test_model_600m_logits(self):
        input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
@@ -117,15 +122,12 @@ def test_model_600m_logits(self):
        with torch.no_grad():
            out = model(input_ids).logits.float().cpu()
        # Expected mean on dim = -1
-        EXPECTED_MEAN = torch.tensor([[-1.4577, 1.3261, 3.8498, 3.4229, 2.9009, 1.8813, 2.1530, 2.1431]])
-        torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, rtol=1e-2, atol=1e-2)
+        EXPECTED_MEAN = torch.tensor([[-1.3789, 1.3029, 3.8262, 3.4637, 2.8796, 1.8357, 2.1290, 2.1814]])
+        torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, rtol=1e-4, atol=1e-4)
        # slicing logits[0, 0, 0:30]
-        EXPECTED_SLICE = torch.tensor([5.9062, 6.0938, 5.5625, 3.8594, 2.6094, 1.9531, 4.3125, 4.9375, 3.8906, 3.1094, 3.6719, 5.1562, 6.9062, 5.7500, 5.4062, 7.0625, 8.7500, 8.7500, 8.1250, 7.9375, 8.0625, 7.5312, 7.3750, 7.2188, 7.2500, 5.8750, 2.8750, 4.3438, 2.3438, 2.2500])  # fmt: skip
-        torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, rtol=1e-4, atol=1e-4)
+        EXPECTED_SLICE = torch.tensor([4.6905, 4.9243, 4.7101, 3.2052, 2.2683, 1.6576, 3.6529, 3.9800, 3.2605, 2.6475, 3.0468, 4.2296, 5.7443, 4.8940, 4.4883, 6.0323, 7.4057, 7.3710, 6.8373, 6.6323, 6.7114, 6.3069, 6.1751, 6.0416, 6.0793, 4.6975, 2.3286, 3.6387, 2.0757, 1.9813])  # fmt: skip

-        del model
-        backend_empty_cache(torch_device)
-        gc.collect()
+        torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, rtol=1e-4, atol=1e-4)

    @slow
    def test_model_600m_generation(self):
@@ -140,10 +142,6 @@ def test_model_600m_generation(self):
        text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        self.assertEqual(EXPECTED_TEXT_COMPLETION, text)

-        del model
-        backend_empty_cache(torch_device)
-        gc.collect()
-
    @require_bitsandbytes
    @slow
    @require_flash_attn
@@ -169,33 +167,29 @@ def test_model_600m_long_prompt(self):
        generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
        self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())

-        del assistant_model
-        del model
-        backend_empty_cache(torch_device)
-        gc.collect()
-
    @slow
    @require_torch_sdpa
    def test_model_600m_long_prompt_sdpa(self):
-        EXPECTED_OUTPUT_TOKEN_IDS = [306, 338]
+        EXPECTED_OUTPUT_TOKEN_IDS = [198, 198]
        # An input with 4097 tokens that is above the size of the sliding window
        input_ids = [1] + [306, 338] * 2048
        model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-0.6B-Base", device_map="auto", attn_implementation="sdpa")
        input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
        generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
+
        self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())

        # Assisted generation
        assistant_model = model
        assistant_model.generation_config.num_assistant_tokens = 2
        assistant_model.generation_config.num_assistant_tokens_schedule = "constant"
        generated_ids = assistant_model.generate(input_ids, max_new_tokens=4, temperature=0)
+
        self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())

        del assistant_model
        del model
-        backend_empty_cache(torch_device)
-        gc.collect()
+        cleanup(torch_device, gc_collect=True)

        EXPECTED_TEXT_COMPLETION = "My favourite condiment is 100% plain, unflavoured, and unadulterated. It is"
        prompt = "My favourite condiment is "
@@ -206,13 +200,19 @@ def test_model_600m_long_prompt_sdpa(self):
        # greedy generation outputs
        generated_ids = model.generate(input_ids, max_new_tokens=20, temperature=0)
        text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+
        self.assertEqual(EXPECTED_TEXT_COMPLETION, text)

    @slow
    def test_speculative_generation(self):
-        EXPECTED_TEXT_COMPLETION = (
-            "My favourite condiment is 100% peanut butter. I love it so much that I can't help but use it"
-        )
+        EXPECTED_TEXT_COMPLETIONS = Expectations(
+            {
+                ("cuda", 7): "My favourite condiment is 100% natural. It's a little spicy and a little sweet, but it's the",
+                ("cuda", 8): "My favourite condiment is 100% peanut butter. I love it so much that I can't help but use it",
+            }
+        )  # fmt: skip
+        EXPECTED_TEXT_COMPLETION = EXPECTED_TEXT_COMPLETIONS.get_expectation()
+
        prompt = "My favourite condiment is "
        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B-Base", use_fast=False)
        model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-0.6B-Base", device_map="auto", torch_dtype=torch.float16)
@@ -227,11 +227,8 @@ def test_speculative_generation(self):
            input_ids, max_new_tokens=20, do_sample=True, temperature=0.3, assistant_model=assistant_model
        )
        text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
-        self.assertEqual(EXPECTED_TEXT_COMPLETION, text)

-        del model
-        backend_empty_cache(torch_device)
-        gc.collect()
+        self.assertEqual(EXPECTED_TEXT_COMPLETION, text)

    @slow
    def test_export_static_cache(self):