Skip to content

Commit 35fa8e1

Browse files
committed
add xpu support for testing
1 parent 7aaf99e commit 35fa8e1

File tree

3 files changed

+6
-1
lines changed

3 files changed

+6
-1
lines changed

test/conftest.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,8 @@
2121
devices += ["cuda"]
2222
elif torch.backends.mps.is_available():
2323
devices += ["mps"]
24+
elif torch.xpu.is_available():
25+
devices += ["xpu"]
2426

2527

2628
@pytest.fixture(scope="module", params=devices)

test/helpers.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -103,6 +103,9 @@ def get_device_memory(device):
103103
elif device.type == "mps":
104104
torch.mps.empty_cache()
105105
return torch.mps.current_allocated_memory()
106+
elif device.type == "xpu":
107+
torch.xpu.empty_cache()
108+
return torch.xpu.memory_allocated()
106109
return None
107110

108111

test/tensor/weights/weight_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@ def check_weight_qtensor_linear(qweight, batch_size, tokens, use_bias, rel_max_e
3131
max_err = (out - qout).abs().max()
3232
rel_max_err = max_err / mean_val
3333
# These values were evaluated empirically without any optimized kernels.
34-
rtol = {"cpu": 1e-2, "cuda": 2e-2, "mps": 1e-2}[device.type]
34+
rtol = {"cpu": 1e-2, "cuda": 2e-2, "mps": 1e-2, "xpu": 2e-2}[device.type]
3535
assert (
3636
rel_max_err < rtol
3737
), f"Maximum error {max_err:.2f} is too high for input of mean value {mean_val:.2f} ({rel_max_err*100:.2f} %)"

0 commit comments

Comments (0)