enable torchao test cases on XPU and switch to device agnostic APIs for test cases #11654
base: main
Conversation
Signed-off-by: Matrix YAO <[email protected]>
Signed-off-by: YAO Matrix <[email protected]>
Signed-off-by: YAO Matrix <[email protected]>
Signed-off-by: YAO Matrix <[email protected]>
@@ -269,6 +271,7 @@ def test_int4wo_quant_bfloat16_conversion(self):
     subfolder="transformer",
     quantization_config=quantization_config,
     torch_dtype=torch.bfloat16,
+    device_map=f"{torch_device}:0",
If the model is loaded to CPU, it fails with "NotImplementedError: Could not run 'aten::_convert_weight_to_int4pack' with arguments from the 'CPU' backend." This happens on both XPU and CUDA, so the model is loaded directly to the accelerator here.
Makes sense, we recently got this error on our CI too
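For illustration, a minimal sketch of the loading pattern this thread settles on. The checkpoint id, device-selection line, and quant-type string are illustrative (taken from the surrounding tests), not necessarily the PR's exact code:

```python
# Sketch: load the int4 torchao-quantized transformer straight onto the
# accelerator; int4 weight packing ("aten::_convert_weight_to_int4pack")
# has no CPU kernel, so a CPU device_map fails on both CUDA and XPU hosts.
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

torch_device = "cuda" if torch.cuda.is_available() else "xpu"  # assumes one of the two is present
quantization_config = TorchAoConfig("int4_weight_only")
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # illustrative checkpoint
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
    device_map=f"{torch_device}:0",  # accelerator, never CPU
)
```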
@@ -79,7 +81,7 @@ def test_post_init_check(self):
     Test kwargs validations in TorchAoConfig
     """
     _ = TorchAoConfig("int4_weight_only")
-    with self.assertRaisesRegex(ValueError, "is not supported yet"):
+    with self.assertRaisesRegex(ValueError, "is not supported"):
XPU's error message doesn't include "yet", so just match "is not supported" to cover both CUDA and XPU.
Makes sense. IIRC this test is failing in our CUDA CI too since the error message no longer includes "yet", so this should be okay.
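A self-contained illustration of why the loosened pattern is safe: assertRaisesRegex performs a regex search, so the shorter pattern matches both messages. The message strings below are examples, not the library's exact output:

```python
import unittest

class ErrorMessageTest(unittest.TestCase):
    def test_pattern_matches_both_backends(self):
        # CUDA's message historically ended in "yet"; XPU's does not.
        for msg in ("int4_weight_only is not supported yet",
                    "int4_weight_only is not supported"):
            with self.assertRaisesRegex(ValueError, "is not supported"):
                raise ValueError(msg)

if __name__ == "__main__":
    unittest.main()
```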
-    if not torch.cuda.is_available():
-        return unittest.skip(test_case)
-    else:
+    if torch.cuda.is_available():
Only check CUDA device compatibility; for non-CUDA devices, just pass. Non-CUDA devices that need a compatibility check should do it themselves.
            return minor >= 9
        return major >= 9
    else:
        return True
Only check device capability when the device is CUDA; non-CUDA devices should be checked in separate utilities. With the original implementation, cases on non-CUDA devices (like XPU) were skipped.
We should probably still raise an error if torchao is being used with mps or other devices, otherwise it leads to an obscure error somewhere deep in the code that common users will not understand
@a-r-r-o-w, I enhanced this utility per your comments, please help review again, thanks.
Signed-off-by: YAO Matrix <[email protected]>
Thanks, changes look good! Would just make sure that we error out early for devices that are not cuda/xpu since otherwise the errors are much more complicated for users to understand
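Putting the thread together, a hedged sketch of what the final gate could look like; the helper name and exact thresholds are illustrative, inferred from the hunk above rather than copied from the PR:

```python
import torch

def is_torchao_compatible_device() -> bool:
    # Illustrative helper, not the PR's actual code.
    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability()
        if major == 8:
            return minor >= 9  # within the 8.x line, require 8.9+
        return major >= 9
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return True  # torchao has official XPU support from 0.11
    # Fail fast for MPS and other backends instead of surfacing an obscure
    # error deep inside torchao, per the review comments above.
    raise ValueError("torchao quantization requires a CUDA or XPU device.")
```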
Signed-off-by: YAO Matrix <[email protected]>
This PR enables torchao test cases on XPU; torchao has official XPU support from 0.11. Tested locally with the same pass rate as CUDA. @a-r-r-o-w @DN6, please help review, thanks very much.
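As a footnote on the "device agnostic APIs" half of the title: the pattern is to replace hard-coded torch.cuda calls in the tests with diffusers' backend helpers. A small sketch, assuming the backend_empty_cache helper and torch_device constant exported by diffusers.utils.testing_utils:

```python
import gc
import unittest

from diffusers.utils.testing_utils import backend_empty_cache, torch_device

class ExampleQuantTest(unittest.TestCase):
    def tearDown(self):
        gc.collect()
        # before: torch.cuda.empty_cache()  (CUDA-only)
        backend_empty_cache(torch_device)  # works on CUDA and XPU alike
```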