-
Notifications
You must be signed in to change notification settings - Fork 18
✅ CB tests refactoring + adding batch test #257
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Signed-off-by: Prashant Gupta <[email protected]>
Signed-off-by: Prashant Gupta <[email protected]>
👋 Hi! Thank you for contributing to vLLM support on Spyre.
Or this can be done with
Now you are good to go 🚀 |
performance improvement Signed-off-by: Prashant Gupta <[email protected]>
Signed-off-by: Prashant Gupta <[email protected]>
Signed-off-by: Prashant Gupta <[email protected]>
Signed-off-by: Prashant Gupta <[email protected]>
Signed-off-by: Prashant Gupta <[email protected]>
Signed-off-by: Prashant Gupta <[email protected]>
Signed-off-by: Joe Runde <[email protected]>
Signed-off-by: Joe Runde <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm in general, provided some comments which would be nice to be addressed before merging.
@@ -52,6 +38,7 @@ def test_output( | |||
test using 'pytest --capture=no tests/spyre/test_spyre_basic.py' | |||
After debugging, DISABLE_ASSERTS should be reset to 'False'. | |||
''' | |||
prompts = get_chicken_soup_prompts(4) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good idea to not have prompts as parameters!
tests/e2e/test_spyre_basic.py
Outdated
sampling_params=vllm_sampling_params, | ||
tensor_parallel_size=1, | ||
backend=backend, | ||
monkeypatch=monkeypatch) | ||
max_num_seqs=2, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could this be packed into kwargs
only in case cb == 1
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
probably!
prompts = [ | ||
"7 6 5 4", | ||
"10 9 8 7", | ||
"8 7 6 5", | ||
"10 9 8 7 ", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice to switch to comparing to hf outputs here too!
template.format("Convert char to string in Java."), | ||
]]) | ||
def test_cb_handling( | ||
@pytest.mark.parametrize("max_num_seqs", [2, 4], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what is the motivation of removing max_num_seqs
3 here? it will be expected to fail with xfail anyway, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just reducing the total number of tests running, since these all take quite a while to run and I didn't think there was anything specific about max_num_seqs=3
that we needed to test.
Do you think there's a good chance we'll miss a bug if we don't run with 3?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sounds good!
@@ -85,20 +72,18 @@ def test_cb_handling( | |||
|
|||
|
|||
@pytest.mark.cb | |||
@pytest.mark.parametrize("max_num_seqs", [2]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This could also stay here as we want to support batch size > 2 soon (and test it here). Could use xfail
similar in the above test_cb_output
.
Or is it not required to also test for batch size > 2 for test_cb_max_tokens
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or is it not required to also test for batch size > 2 for test_cb_max_tokens?
Right, we're testing that the prompts are rejected before even running the model so I don' think it's relevant to parameterize this on max batch size.
@@ -643,7 +628,6 @@ def augment_checked_steps( | |||
@pytest.mark.cb | |||
@pytest.mark.parametrize("model", get_spyre_model_list()) | |||
@pytest.mark.parametrize("backend", get_spyre_backend_list()) | |||
@pytest.mark.parametrize("max_num_seqs", [2]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same question (see above)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, my understanding here is that the get_params_test_* methods are all already assuming max batch size 2, so this can't be parameterized higher right now
tests/e2e/test_spyre_cb.py
Outdated
@@ -643,7 +628,6 @@ def augment_checked_steps( | |||
@pytest.mark.cb | |||
@pytest.mark.parametrize("model", get_spyre_model_list()) | |||
@pytest.mark.parametrize("backend", get_spyre_backend_list()) | |||
@pytest.mark.parametrize("max_num_seqs", [2]) | |||
@pytest.mark.parametrize( | |||
"seqs_max_tokens,prompts_lengths,steps_add_reqs,checked_steps," | |||
"max_model_len", [ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(I know this is not your change, but) could we move max_model_len
out of the get_params*
functions as part of this refactoring PR? It is not test case specific and should not be needed to set at 5 different places (currently it is set to 256 in all functions)
Signed-off-by: Joe Runde <[email protected]>
mergin'! |
Description
max_model_len
to256
pytest.xfail
for failing CB testsIssues