Commit 2507169

Fix qwen3 tests (#38862)
* fix
* update
* update
* update
* update
* update
* update
* format

---------

Co-authored-by: ydshieh <[email protected]>
1 parent 41e0c92 commit 2507169

File tree

1 file changed: +25 -28 lines changed


tests/models/qwen3/test_modeling_qwen3.py

Lines changed: 25 additions & 28 deletions
@@ -13,7 +13,6 @@
 # limitations under the License.
 """Testing suite for the PyTorch Qwen3 model."""
 
-import gc
 import unittest
 
 import pytest
@@ -23,7 +22,7 @@
 from transformers.generation.configuration_utils import GenerationConfig
 from transformers.testing_utils import (
     Expectations,
-    backend_empty_cache,
+    cleanup,
     require_bitsandbytes,
     require_flash_attn,
     require_torch,
@@ -109,6 +108,12 @@ def test_flash_attn_2_inference_equivalence_right_padding(self):
 
 @require_torch
 class Qwen3IntegrationTest(unittest.TestCase):
+    def setUp(self):
+        cleanup(torch_device, gc_collect=True)
+
+    def tearDown(self):
+        cleanup(torch_device, gc_collect=True)
+
     @slow
     def test_model_600m_logits(self):
         input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
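
Note: the setUp/tearDown pair added above replaces the manual del model / backend_empty_cache(torch_device) / gc.collect() teardown that the individual tests carried before this commit. Below is a minimal sketch of the same pattern in a hypothetical integration test class; the class name, checkpoint, and test body are illustrative placeholders, not part of this diff.

import unittest

import torch

from transformers import AutoModelForCausalLM
from transformers.testing_utils import cleanup, require_torch, slow, torch_device


@require_torch
class ExampleIntegrationTest(unittest.TestCase):
    def setUp(self):
        # Clear the accelerator cache and run the garbage collector before each test,
        # so leftovers from a previous test do not distort memory-sensitive checks.
        cleanup(torch_device, gc_collect=True)

    def tearDown(self):
        # Same cleanup after each test; no per-test del model / gc.collect() is needed.
        cleanup(torch_device, gc_collect=True)

    @slow
    def test_short_generation(self):
        model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B-Base", device_map="auto")
        input_ids = torch.tensor([[1, 306, 338]]).to(model.device)
        _ = model.generate(input_ids, max_new_tokens=4)
        # No manual teardown here: tearDown() frees the memory.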
@@ -117,15 +122,12 @@ def test_model_600m_logits(self):
         with torch.no_grad():
             out = model(input_ids).logits.float().cpu()
         # Expected mean on dim = -1
-        EXPECTED_MEAN = torch.tensor([[-1.4577, 1.3261, 3.8498, 3.4229, 2.9009, 1.8813, 2.1530, 2.1431]])
-        torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, rtol=1e-2, atol=1e-2)
+        EXPECTED_MEAN = torch.tensor([[-1.3789, 1.3029, 3.8262, 3.4637, 2.8796, 1.8357, 2.1290, 2.1814]])
+        torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, rtol=1e-4, atol=1e-4)
         # slicing logits[0, 0, 0:30]
-        EXPECTED_SLICE = torch.tensor([5.9062, 6.0938, 5.5625, 3.8594, 2.6094, 1.9531, 4.3125, 4.9375, 3.8906, 3.1094, 3.6719, 5.1562, 6.9062, 5.7500, 5.4062, 7.0625, 8.7500, 8.7500, 8.1250, 7.9375, 8.0625, 7.5312, 7.3750, 7.2188, 7.2500, 5.8750, 2.8750, 4.3438, 2.3438, 2.2500]) # fmt: skip
-        torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, rtol=1e-4, atol=1e-4)
+        EXPECTED_SLICE = torch.tensor([4.6905, 4.9243, 4.7101, 3.2052, 2.2683, 1.6576, 3.6529, 3.9800, 3.2605, 2.6475, 3.0468, 4.2296, 5.7443, 4.8940, 4.4883, 6.0323, 7.4057, 7.3710, 6.8373, 6.6323, 6.7114, 6.3069, 6.1751, 6.0416, 6.0793, 4.6975, 2.3286, 3.6387, 2.0757, 1.9813]) # fmt: skip
 
-        del model
-        backend_empty_cache(torch_device)
-        gc.collect()
+        torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, rtol=1e-4, atol=1e-4)
 
     @slow
     def test_model_600m_generation(self):
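
Aside: together with the regenerated reference logits, the tolerance on the mean comparison tightens from 1e-2 to 1e-4. torch.testing.assert_close treats rtol and atol as an elementwise bound of the form |actual - expected| <= atol + rtol * |expected|; a tiny standalone sketch with made-up values:

import torch

actual = torch.tensor([1.00004, 2.00004])
expected = torch.tensor([1.0, 2.0])

# Passes: each difference (about 4e-5) stays within atol + rtol * |expected|,
# i.e. 2e-4 for the first element and 3e-4 for the second.
torch.testing.assert_close(actual, expected, rtol=1e-4, atol=1e-4)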
@@ -140,10 +142,6 @@ def test_model_600m_generation(self):
         text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
         self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
 
-        del model
-        backend_empty_cache(torch_device)
-        gc.collect()
-
     @require_bitsandbytes
     @slow
     @require_flash_attn
@@ -169,33 +167,29 @@ def test_model_600m_long_prompt(self):
         generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
         self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
 
-        del assistant_model
-        del model
-        backend_empty_cache(torch_device)
-        gc.collect()
-
     @slow
     @require_torch_sdpa
     def test_model_600m_long_prompt_sdpa(self):
-        EXPECTED_OUTPUT_TOKEN_IDS = [306, 338]
+        EXPECTED_OUTPUT_TOKEN_IDS = [198, 198]
         # An input with 4097 tokens that is above the size of the sliding window
         input_ids = [1] + [306, 338] * 2048
         model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-0.6B-Base", device_map="auto", attn_implementation="sdpa")
         input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
         generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
+
         self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
 
         # Assisted generation
         assistant_model = model
         assistant_model.generation_config.num_assistant_tokens = 2
         assistant_model.generation_config.num_assistant_tokens_schedule = "constant"
         generated_ids = assistant_model.generate(input_ids, max_new_tokens=4, temperature=0)
+
         self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
 
         del assistant_model
 
-        backend_empty_cache(torch_device)
-        gc.collect()
+        cleanup(torch_device, gc_collect=True)
 
         EXPECTED_TEXT_COMPLETION = "My favourite condiment is 100% plain, unflavoured, and unadulterated. It is"
         prompt = "My favourite condiment is "
@@ -206,13 +200,19 @@ def test_model_600m_long_prompt_sdpa(self):
         # greedy generation outputs
         generated_ids = model.generate(input_ids, max_new_tokens=20, temperature=0)
         text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+
         self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
 
     @slow
     def test_speculative_generation(self):
-        EXPECTED_TEXT_COMPLETION = (
-            "My favourite condiment is 100% peanut butter. I love it so much that I can't help but use it"
-        )
+        EXPECTED_TEXT_COMPLETIONS = Expectations(
+            {
+                ("cuda", 7): "My favourite condiment is 100% natural. It's a little spicy and a little sweet, but it's the",
+                ("cuda", 8): "My favourite condiment is 100% peanut butter. I love it so much that I can't help but use it",
+            }
+        ) # fmt: skip
+        EXPECTED_TEXT_COMPLETION = EXPECTED_TEXT_COMPLETIONS.get_expectation()
+
         prompt = "My favourite condiment is "
         tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B-Base", use_fast=False)
         model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-0.6B-Base", device_map="auto", torch_dtype=torch.float16)
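
Note: Expectations (imported above from transformers.testing_utils) keys expected outputs by properties of the test runner, here ("cuda", 7) and ("cuda", 8) for compute-capability 7.x and 8.x GPUs, and get_expectation() returns the entry matching the hardware the test actually runs on. A usage sketch mirroring the diff follows; the strings below are placeholders, not real model outputs.

from transformers.testing_utils import Expectations

# Expected completions keyed by (device type, major version); only the tuple keys
# used in the diff above are assumed here.
EXPECTED_TEXTS = Expectations(
    {
        ("cuda", 7): "completion observed on compute-capability 7.x GPUs (e.g. V100, T4)",
        ("cuda", 8): "completion observed on compute-capability 8.x GPUs (e.g. A100)",
    }
)

# Resolve the expectation for the current device before comparing against generated text.
expected_text = EXPECTED_TEXTS.get_expectation()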
@@ -227,11 +227,8 @@ def test_speculative_generation(self):
             input_ids, max_new_tokens=20, do_sample=True, temperature=0.3, assistant_model=assistant_model
         )
         text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
-        self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
 
-        del model
-        backend_empty_cache(torch_device)
-        gc.collect()
+        self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
 
     @slow
     def test_export_static_cache(self):
