From 5d3142d6ed1287c30da7534f24d00564f92f13ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABl?= Date: Wed, 11 Dec 2024 10:29:31 +0000 Subject: [PATCH] fix(test): correctly handle beam_search in generate text --- tests/generate/test_generate.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/tests/generate/test_generate.py b/tests/generate/test_generate.py index 7f5f108d1..3615a3605 100644 --- a/tests/generate/test_generate.py +++ b/tests/generate/test_generate.py @@ -279,10 +279,19 @@ def get_inputs(fixture_name, batch_size=None): @pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES) def test_generate_text(request, model_fixture, sampler_name): model = request.getfixturevalue(model_fixture) - generator = generate.text(model, getattr(samplers, sampler_name)()) with enforce_not_implemented(model_fixture, sampler_name): - res = generator(**get_inputs(model_fixture), max_tokens=10) - assert isinstance(res, str) + if sampler_name == "beam_search": + num_head = 2 + generator = generate.text(model, getattr(samplers, sampler_name)(num_head)) + res = generator(**get_inputs(model_fixture), max_tokens=10) + assert isinstance(res, list) + assert len(res) == num_head + for elt in res: + assert isinstance(elt, str) + else: + generator = generate.text(model, getattr(samplers, sampler_name)()) + res = generator(**get_inputs(model_fixture), max_tokens=10) + assert isinstance(res, str) @pytest.mark.parametrize("pattern", REGEX_PATTERNS)