
Commit 309e8c9

Fix phi4_multimodal tests (#38816)
* fix
* fix
* fix
* fix
* fix

---------

Co-authored-by: ydshieh <[email protected]>
1 parent 3526e25 · commit 309e8c9

File tree

1 file changed: +26 −10 lines changed

tests/models/phi4_multimodal/test_modeling_phi4_multimodal.py

Lines changed: 26 additions & 10 deletions
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import gc
 import tempfile
 import unittest
 
@@ -31,7 +30,15 @@
     is_torch_available,
     is_vision_available,
 )
-from transformers.testing_utils import backend_empty_cache, require_soundfile, require_torch, slow, torch_device
+from transformers.testing_utils import (
+    Expectations,
+    cleanup,
+    require_soundfile,
+    require_torch,
+    require_torch_large_accelerator,
+    slow,
+    torch_device,
+)
 from transformers.utils import is_soundfile_available
 
 from ...generation.test_utils import GenerationTesterMixin
@@ -276,13 +283,14 @@ def test_flex_attention_with_grads(self):
 @slow
 class Phi4MultimodalIntegrationTest(unittest.TestCase):
     checkpoint_path = "microsoft/Phi-4-multimodal-instruct"
+    revision = "refs/pr/70"
     image_url = "https://www.ilankelman.org/stopsigns/australia.jpg"
     audio_url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/f2641_0_throatclearing.wav"
 
     def setUp(self):
         # Currently, the Phi-4 checkpoint on the hub is not working with the latest Phi-4 code, so the slow integration tests
         # won't pass without using the correct revision (refs/pr/70)
-        self.processor = AutoProcessor.from_pretrained(self.checkpoint_path)
+        self.processor = AutoProcessor.from_pretrained(self.checkpoint_path, revision=self.revision)
         self.generation_config = GenerationConfig(max_new_tokens=20, do_sample=False)
         self.user_token = "<|user|>"
         self.assistant_token = "<|assistant|>"
@@ -294,13 +302,14 @@ def setUp(self):
             tmp.seek(0)
             self.audio, self.sampling_rate = soundfile.read(tmp.name)
 
+        cleanup(torch_device, gc_collect=True)
+
     def tearDown(self):
-        gc.collect()
-        backend_empty_cache(torch_device)
+        cleanup(torch_device, gc_collect=True)
 
     def test_text_only_generation(self):
         model = AutoModelForCausalLM.from_pretrained(
-            self.checkpoint_path, torch_dtype=torch.float16, device_map=torch_device
+            self.checkpoint_path, revision=self.revision, torch_dtype=torch.float16, device_map=torch_device
         )
 
         prompt = f"{self.user_token}What is the answer for 1+1? Explain it.{self.end_token}{self.assistant_token}"
@@ -319,7 +328,7 @@ def test_text_only_generation(self):
 
     def test_vision_text_generation(self):
         model = AutoModelForCausalLM.from_pretrained(
-            self.checkpoint_path, torch_dtype=torch.float16, device_map=torch_device
+            self.checkpoint_path, revision=self.revision, torch_dtype=torch.float16, device_map=torch_device
         )
 
         prompt = f"{self.user_token}<|image|>What is shown in this image?{self.end_token}{self.assistant_token}"
@@ -332,13 +341,20 @@ def test_vision_text_generation(self):
         output = output[:, inputs["input_ids"].shape[1] :]
         response = self.processor.batch_decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
 
-        EXPECTED_RESPONSE = "The image shows a vibrant scene at a street intersection in a city with a Chinese-influenced architectural"
+        EXPECTED_RESPONSES = Expectations(
+            {
+                ("cuda", 7): 'The image shows a vibrant scene at a traditional Chinese-style street entrance, known as a "gate"',
+                ("cuda", 8): 'The image shows a vibrant scene at a street intersection in a city with a Chinese-influenced architectural',
+            }
+        )  # fmt: skip
+        EXPECTED_RESPONSE = EXPECTED_RESPONSES.get_expectation()
 
         self.assertEqual(response, EXPECTED_RESPONSE)
 
+    @require_torch_large_accelerator
     def test_multi_image_vision_text_generation(self):
         model = AutoModelForCausalLM.from_pretrained(
-            self.checkpoint_path, torch_dtype=torch.float16, device_map=torch_device
+            self.checkpoint_path, revision=self.revision, torch_dtype=torch.float16, device_map=torch_device
         )
 
         images = []
@@ -365,7 +381,7 @@ def test_multi_image_vision_text_generation(self):
     @require_soundfile
     def test_audio_text_generation(self):
         model = AutoModelForCausalLM.from_pretrained(
-            self.checkpoint_path, torch_dtype=torch.float16, device_map=torch_device
+            self.checkpoint_path, revision=self.revision, torch_dtype=torch.float16, device_map=torch_device
         )
 
         prompt = f"{self.user_token}<|audio|>What is happening in this audio?{self.end_token}{self.assistant_token}"
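
Notes on the changes (the code sketches below are editorial illustrations, not part of the commit):

Pinning the Hub revision. Per the comment in setUp, the default branch of microsoft/Phi-4-multimodal-instruct does not currently work with the latest Phi-4 code, so every from_pretrained call now pins revision="refs/pr/70", a Hub pull-request ref. A minimal standalone usage example:

    from transformers import AutoProcessor

    # Load the processor from the PR ref rather than the default branch.
    # Any Hub revision works here: a branch name, a tag, a commit SHA,
    # or a PR ref such as "refs/pr/70".
    processor = AutoProcessor.from_pretrained(
        "microsoft/Phi-4-multimodal-instruct",
        revision="refs/pr/70",
    )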
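Consolidated memory cleanup. The old tearDown called gc.collect() and backend_empty_cache(torch_device) by hand; the commit replaces that pair with a single cleanup(torch_device, gc_collect=True) call from transformers.testing_utils, and also runs it at the end of setUp so each test starts with an empty accelerator cache. A minimal sketch of what such a helper needs to do (the real implementation in transformers.testing_utils may differ in details):

    import gc

    from transformers.testing_utils import backend_empty_cache


    def cleanup(device, gc_collect=False):
        # Drop dead Python references first, so the backend can actually
        # release the tensors those references were keeping alive...
        if gc_collect:
            gc.collect()
        # ...then return cached accelerator memory to the device.
        backend_empty_cache(device)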
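Per-device expected outputs. float16 generation is not bit-identical across GPU generations, so a single hard-coded expected string can fail on some CI runners. The commit switches test_vision_text_generation to Expectations from transformers.testing_utils, keyed by device type and CUDA compute-capability major version (capability 8 covers A100-class hardware, for example), and resolves the right string with get_expectation(). A minimal sketch of that lookup using a hypothetical stand-in class (the real Expectations has more elaborate fallback logic):

    import torch


    class DeviceKeyedExpectations:
        """Map (device_type, capability_major) keys to expected strings."""

        def __init__(self, data):
            self.data = data

        def get_expectation(self):
            if torch.cuda.is_available():
                major, _minor = torch.cuda.get_device_capability()
                key = ("cuda", major)
                if key in self.data:
                    return self.data[key]
            # No exact match: fall back to any registered expectation.
            return next(iter(self.data.values()))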
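Gating by accelerator size. test_multi_image_vision_text_generation now carries @require_torch_large_accelerator, which restricts the test to machines with a large accelerator, since the multi-image inputs need more memory than the other tests. A sketch of how such a gate can be built with unittest.skipUnless (the 40 GB threshold is an assumption for illustration, not the value transformers uses):

    import unittest

    import torch


    def require_large_accelerator(test_case, min_memory_gb=40.0):
        # Skip unless a CUDA device with enough total memory is present.
        big_enough = (
            torch.cuda.is_available()
            and torch.cuda.get_device_properties(0).total_memory
            >= min_memory_gb * 1024**3
        )
        reason = f"test requires at least {min_memory_gb} GB of accelerator memory"
        return unittest.skipUnless(big_enough, reason)(test_case)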
