
enable torchao test cases on XPU and switch to device agnostic APIs for test cases #11654


Open · yao-matrix wants to merge 8 commits into main

Conversation

@yao-matrix (Contributor) commented on Jun 4, 2025

  1. Enable torchao test cases on XPU; torchao has official XPU support since 0.11. Tested locally, with the same pass rate as CUDA.
  2. Switch all test cases to device-agnostic APIs. A sketch of the pattern follows the review request below.

@a-r-r-o-w @DN6, please help review, thanks very much.
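As an illustration of item 2, a minimal sketch of the device-agnostic pattern, assuming the `backend_empty_cache` and `torch_device` helpers in `diffusers.utils.testing_utils`; the test class itself is hypothetical:

```python
import gc
import unittest

from diffusers.utils.testing_utils import backend_empty_cache, torch_device

class TorchAoLikeTest(unittest.TestCase):  # hypothetical test class
    def tearDown(self):
        gc.collect()
        # Device-agnostic replacement for torch.cuda.empty_cache():
        # dispatches to the right backend (torch.cuda, torch.xpu, ...)
        # based on the detected torch_device.
        backend_empty_cache(torch_device)
```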

Signed-off-by: YAO Matrix <[email protected]>
@@ -269,6 +271,7 @@ def test_int4wo_quant_bfloat16_conversion(self):
             subfolder="transformer",
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
+            device_map=f"{torch_device}:0",
yao-matrix (Contributor, Author):

If the model is loaded on CPU, we hit "NotImplementedError: Could not run 'aten::_convert_weight_to_int4pack' with arguments from the 'CPU' backend." This happens on both XPU and CUDA, so we load the model directly onto the accelerator here.
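For illustration, a minimal sketch of loading the quantized transformer straight onto the accelerator; the checkpoint name is an assumption, but the arguments mirror the test change above:

```python
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig
from diffusers.utils.testing_utils import torch_device  # "cuda", "xpu", ...

quantization_config = TorchAoConfig("int4_weight_only")
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # assumed checkpoint with a "transformer" subfolder
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
    # CPU has no kernel for aten::_convert_weight_to_int4pack, so place the
    # weights on the accelerator at load time instead of loading on CPU first.
    device_map=f"{torch_device}:0",
)
```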

Member:

Makes sense, we recently got this error on our CI too

@@ -79,7 +81,7 @@ def test_post_init_check(self):
         Test kwargs validations in TorchAoConfig
         """
         _ = TorchAoConfig("int4_weight_only")
-        with self.assertRaisesRegex(ValueError, "is not supported yet"):
+        with self.assertRaisesRegex(ValueError, "is not supported"):
yao-matrix (Contributor, Author):

XPU's error message doesn't contain "yet", so we match just "is not supported", which covers both CUDA and XPU.
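A quick check of why the shorter pattern covers both backends (the exact messages below are illustrative assumptions): `assertRaisesRegex` matches via `re.search`, so the pattern only needs to appear somewhere in the message:

```python
import re

cuda_msg = "int4_weight_only quantization is not supported yet"  # assumed CUDA wording
xpu_msg = "int4_weight_only quantization is not supported"       # assumed XPU wording

assert re.search("is not supported", cuda_msg)          # matches on CUDA
assert re.search("is not supported", xpu_msg)           # matches on XPU
assert not re.search("is not supported yet", xpu_msg)   # old pattern misses XPU
```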

Member:

Makes sense. IIRC this test is failing in our CUDA CI too since the error message does not contain "yet" any more, so this should be okay.

-        if not torch.cuda.is_available():
-            return unittest.skip(test_case)
-        else:
+        if torch.cuda.is_available():
yao-matrix (Contributor, Author):

Only check device compatibility on CUDA; on a non-CUDA device, just pass. Non-CUDA devices that need a compatibility check should do it in their own utilities.

             return minor >= 9
         return major >= 9
+    else:
+        return True

yao-matrix (Contributor, Author):

Only check device capability when the device is CUDA; non-CUDA devices should be checked in separate utilities. With the original implementation, cases on a non-CUDA device (like XPU) would be skipped.
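A minimal sketch of the gating pattern this comment describes; the function name is hypothetical and the capability thresholds are read off the diff above:

```python
import torch

def device_supports_required_capability() -> bool:
    # Hypothetical helper mirroring this PR's pattern: gate on compute
    # capability only for CUDA devices.
    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability()
        if major == 8:           # assumed guard; e.g. Ada (8.9) qualifies
            return minor >= 9
        return major >= 9        # Hopper and newer
    else:
        # Non-CUDA accelerators (e.g. XPU) pass here; they should run
        # their own capability checks in separate utilities.
        return True
```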

Member:

We should probably still raise an error if torchao is being used with mps or other devices; otherwise it leads to an obscure error somewhere deep in the code that common users will not understand.

yao-matrix (Contributor, Author):

@a-r-r-o-w, I enhanced this utility per your comments; please help review again, thanks.

@a-r-r-o-w (Member) left a comment:

Thanks, changes look good! Would just make sure that we error out early for devices that are not cuda/xpu, since otherwise the errors are much more complicated for users to understand.
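For illustration, a sketch of the early validation being asked for here; the function name and message are assumptions, not the PR's actual code:

```python
def validate_torchao_device(device_type: str) -> None:
    # Hypothetical fail-fast check: reject backends torchao does not support,
    # instead of surfacing an obscure kernel error deep inside model code.
    if device_type not in ("cuda", "xpu"):
        raise ValueError(
            f"torchao quantization is not supported on {device_type!r} devices; "
            "please use a CUDA or XPU device."
        )

validate_torchao_device("xpu")  # passes
validate_torchao_device("mps")  # raises ValueError immediately
```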


@a-r-r-o-w requested a review from DN6 on June 4, 2025 at 10:37