Skip to content

Commit

Permalink
fix(test): correctly handle beam_search in generate text
Browse files Browse the repository at this point in the history
  • Loading branch information
g-prz committed Dec 11, 2024
1 parent 69ec787 commit 5d3142d
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions tests/generate/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,10 +279,19 @@ def get_inputs(fixture_name, batch_size=None):
@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
def test_generate_text(request, model_fixture, sampler_name):
model = request.getfixturevalue(model_fixture)
generator = generate.text(model, getattr(samplers, sampler_name)())
with enforce_not_implemented(model_fixture, sampler_name):
res = generator(**get_inputs(model_fixture), max_tokens=10)
assert isinstance(res, str)
if sampler_name == "beam_search":
num_head = 2
generator = generate.text(model, getattr(samplers, sampler_name)(num_head))
res = generator(**get_inputs(model_fixture), max_tokens=10)
assert isinstance(res, list)
assert len(res) == num_head
for elt in res:
assert isinstance(elt, str)
else:
generator = generate.text(model, getattr(samplers, sampler_name)())
res = generator(**get_inputs(model_fixture), max_tokens=10)
assert isinstance(res, str)


@pytest.mark.parametrize("pattern", REGEX_PATTERNS)
Expand Down

0 comments on commit 5d3142d

Please sign in to comment.