Skip to content

Commit

Permalink
add xpu support for testing
Browse files Browse the repository at this point in the history
  • Loading branch information
faaany committed Nov 6, 2024
1 parent 7aaf99e commit 35fa8e1
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 1 deletion.
2 changes: 2 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
devices += ["cuda"]
elif torch.backends.mps.is_available():
devices += ["mps"]
elif torch.xpu.is_available():
devices += ["xpu"]


@pytest.fixture(scope="module", params=devices)
Expand Down
3 changes: 3 additions & 0 deletions test/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ def get_device_memory(device):
elif device.type == "mps":
torch.mps.empty_cache()
return torch.mps.current_allocated_memory()
elif device.type == "xpu":
torch.xpu.empty_cache()
return torch.xpu.memory_allocated()
return None


Expand Down
2 changes: 1 addition & 1 deletion test/tensor/weights/weight_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def check_weight_qtensor_linear(qweight, batch_size, tokens, use_bias, rel_max_e
max_err = (out - qout).abs().max()
rel_max_err = max_err / mean_val
# These values were evaluated empirically without any optimized kernels.
rtol = {"cpu": 1e-2, "cuda": 2e-2, "mps": 1e-2}[device.type]
rtol = {"cpu": 1e-2, "cuda": 2e-2, "mps": 1e-2, "xpu": 2e-2}[device.type]
assert (
rel_max_err < rtol
), f"Maximum error {max_err:.2f} is too high for input of mean value {mean_val:.2f} ({rel_max_err*100:.2f} %)"

0 comments on commit 35fa8e1

Please sign in to comment.