[WIP] Test out thunder.jit w/ NeMo models. #1694
@@ -445,7 +454,8 @@ def test_hf_for_nemo(model_id):
 # fullgraph=True used to work with transformers 4.45.2, but it doesn't work
 # with 4.46.2 because of re.findall usage in the loss function
 fullgraph = False
-compiled_model = thunderfx(model, fullgraph=fullgraph)
+# compiled_model = thunderfx(model, fullgraph=fullgraph)
+compiled_model = thunder.jit(model, fullgraph=fullgraph)
Wasn't this failing due to an unsupported argument? I thought thunder.jit doesn't have a fullgraph argument.
Or does **compile_options take arbitrary keyword arguments, so there wouldn't be errors for unsupported args?
Nope. One can see the logs by clicking through the CI links below:
=========================== short test summary info ============================
FAILED thunder/tests/test_networks.py::test_hf_for_nemo[bigcode/starcoder2-7b] - AssertionError
FAILED thunder/tests/test_networks.py::test_hf_for_nemo[microsoft/Phi-3-mini-128k-instruct] - AssertionError: expected tensor with (48,), cuda:0, torch.float32, requires_grad=False, got (1,), cuda:0, torch.bfloat16, False
FAILED thunder/tests/test_networks.py::test_thunderfx_mistral_nemo_small - AssertionError
============ 3 failed, 29 passed, 172 warnings in 172.63s (0:02:52) ============
So I guess your theory is correct: we do not error out on unsupported (kw)args.
I have a vague memory of us discussing doing that, though; maybe it's just a warning?
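To illustrate the behaviour discussed above, here is a minimal sketch in plain Python of how a `**compile_options` catch-all can silently accept an option such as `fullgraph`, and how a warning could surface it instead. Everything here is hypothetical: `lenient_jit`, `warning_jit`, and `KNOWN_OPTIONS` are stand-ins invented for illustration and do not reflect thunder's actual signature, option list, or warning behaviour.

```python
import warnings

# Hypothetical set of recognised options (an assumption for this sketch,
# NOT thunder's real option list).
KNOWN_OPTIONS = {"option_a", "option_b"}


def lenient_jit(fn, **compile_options):
    """Stand-in for an entry point that accepts arbitrary keyword options."""
    # Unknown keys are stored and never inspected, so no TypeError is raised.
    return fn, dict(compile_options)


def warning_jit(fn, **compile_options):
    """Same stand-in, but it warns about options it does not recognise."""
    unknown = sorted(set(compile_options) - KNOWN_OPTIONS)
    if unknown:
        warnings.warn(f"unused compile options: {unknown}", stacklevel=2)
    return fn, dict(compile_options)


def model(x):
    return x * 2


# The call succeeds even though "fullgraph" is not a recognised option.
compiled, options = lenient_jit(model, fullgraph=False)
```

With a plain signature such as `def jit(fn, *, option_a=None)`, the same call would raise `TypeError` for the unexpected keyword; the catch-all trades that early failure for flexibility.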
That is one of the things I dislike about the options.
What does this PR do?
Demonstrates how to test thunder.jit with NeMo models.
PR review
This isn't meant to be merged or reviewed. If we actually want to do this, someone would want to extend the testing to cover both thunder.jit and thunder.dynamo.thunderfx, not remove the thunderfx testing (as done here).
Did you have fun?
With this group, always :-)
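
As a rough sketch of the suggestion in the PR review section (exercise both compile paths rather than swapping one for the other), the same check can be looped over a table of compile functions. The two compile functions below are hypothetical identity stand-ins so the snippet runs without thunder installed; in a real test they would presumably be replaced by thunder.jit and thunder.dynamo.thunderfx, and the loop by test parametrization.

```python
def eager_model(x):
    """Tiny stand-in for a model's forward pass."""
    return 3.0 * x + 1.0


def fake_jit(fn, **options):
    # Hypothetical placeholder for thunder.jit: returns fn unchanged.
    return fn


def fake_thunderfx(fn, **options):
    # Hypothetical placeholder for thunder.dynamo.thunderfx: returns fn unchanged.
    return fn


# One entry per compile path under test.
COMPILERS = {"jit": fake_jit, "thunderfx": fake_thunderfx}


def run_for_all_compilers(model, compilers, sample):
    """Compile `model` with each entry and compare against the eager result."""
    expected = model(sample)
    results = {}
    for name, compile_fn in compilers.items():
        compiled = compile_fn(model)
        results[name] = compiled(sample) == expected
    return results


results = run_for_all_compilers(eager_model, COMPILERS, 2.0)
```

Keeping the compile functions in one table means a failure is reported per path, so a regression in one does not hide the status of the other.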