diff --git a/optimum/quanto/library/qbytes_mm.py b/optimum/quanto/library/qbytes_mm.py
index 62da6f2e..22319345 100644
--- a/optimum/quanto/library/qbytes_mm.py
+++ b/optimum/quanto/library/qbytes_mm.py
@@ -92,7 +92,7 @@ def qbytes_mm_impl_cuda(activations: torch.Tensor, weights: torch.Tensor, output
 def qbytes_mm_impl_cpu(activations: torch.Tensor, weights: torch.Tensor, output_scales: torch.Tensor) -> torch.Tensor:
     if (
         # FIXME: accuracy issues with 2.4.x
-        version.parse(torch.__version__).release > version.parse("2.5.0").release
+        version.parse(torch.__version__).release >= version.parse("2.6.0").release
         and activations.dtype == torch.int8
         and weights.dtype == torch.int8
     ):
diff --git a/test/conftest.py b/test/conftest.py
index d8bb234c..5e9b4367 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -21,6 +21,8 @@
     devices += ["cuda"]
 elif torch.backends.mps.is_available():
     devices += ["mps"]
+elif torch.xpu.is_available():
+    devices += ["xpu"]


 @pytest.fixture(scope="module", params=devices)
diff --git a/test/helpers.py b/test/helpers.py
index 1693d6a9..3e6635f3 100644
--- a/test/helpers.py
+++ b/test/helpers.py
@@ -103,6 +103,9 @@ def get_device_memory(device):
     elif device.type == "mps":
         torch.mps.empty_cache()
         return torch.mps.current_allocated_memory()
+    elif device.type == "xpu":
+        torch.xpu.empty_cache()
+        return torch.xpu.memory_allocated()
     return None


diff --git a/test/tensor/weights/weight_helpers.py b/test/tensor/weights/weight_helpers.py
index 761cbea3..762836e7 100644
--- a/test/tensor/weights/weight_helpers.py
+++ b/test/tensor/weights/weight_helpers.py
@@ -31,7 +31,7 @@ def check_weight_qtensor_linear(qweight, batch_size, tokens, use_bias, rel_max_e
     max_err = (out - qout).abs().max()
     rel_max_err = max_err / mean_val
     # These values were evaluated empirically without any optimized kernels.
-    rtol = {"cpu": 1e-2, "cuda": 2e-2, "mps": 1e-2}[device.type]
+    rtol = {"cpu": 1e-2, "cuda": 2e-2, "mps": 1e-2, "xpu": 2e-2}[device.type]
     assert (
         rel_max_err < rtol
     ), f"Maximum error {max_err:.2f} is too high for input of mean value {mean_val:.2f} ({rel_max_err*100:.2f} %)"
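
For context, a minimal standalone sketch of the device-selection and memory-query pattern these hunks extend to Intel XPU. It is not part of the patch: the helper names pick_device and allocated_memory are hypothetical, and the hasattr guard is an assumption so the snippet also runs on PyTorch builds without the torch.xpu module; the torch.* calls themselves are the same ones used in test/conftest.py and test/helpers.py.

import torch


def pick_device() -> torch.device:
    # Same priority order as the devices list in test/conftest.py: cuda, then mps, then xpu, else cpu.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device("xpu")
    return torch.device("cpu")


def allocated_memory(device: torch.device):
    # Mirrors the shape of test/helpers.py::get_device_memory: flush the caching
    # allocator, then report the bytes currently allocated on that device (None for cpu).
    if device.type == "cuda":
        torch.cuda.empty_cache()
        return torch.cuda.memory_allocated()
    if device.type == "mps":
        torch.mps.empty_cache()
        return torch.mps.current_allocated_memory()
    if device.type == "xpu":
        torch.xpu.empty_cache()
        return torch.xpu.memory_allocated()
    return None


if __name__ == "__main__":
    dev = pick_device()
    print(dev, allocated_memory(dev))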