enable on xpu
faaany authored and dacorvo committed Nov 12, 2024
1 parent ad9d6f6 commit 4508702
Showing 1 changed file with 8 additions and 3 deletions.
test/tensor/weights/test_weight_qbits_tensor_dispatch.py (11 changes: 8 additions & 3 deletions)
@@ -73,15 +73,20 @@ def test_weight_qbits_tensor_linear(dtype, batch_size, tokens, in_features, out_
     check_weight_qtensor_linear(qbt, batch_size, tokens, use_bias)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test is too slow on non-CUDA devices")
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
 @pytest.mark.parametrize("batch_size", [1, 2])
 @pytest.mark.parametrize("tokens", [16, 32, 48, 64])
 @pytest.mark.parametrize("in_features", [1024, 4096, 16384])
 @pytest.mark.parametrize("out_features", [1024, 2048, 4096])
 @pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"])
-def test_weight_qbits_tensor_linear_cuda(dtype, batch_size, tokens, in_features, out_features, use_bias):
-    device = torch.device("cuda")
+def test_weight_qbits_tensor_linear_gpu(dtype, batch_size, tokens, in_features, out_features, use_bias):
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+    elif torch.xpu.is_available():
+        device = torch.device("xpu")
+    else:
+        pytest.skip(reason="Test is too slow on non-GPU devices")
+
     weight_qtype = qint4
     group_size = 128
     # Create a QBitsTensor
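The change replaces the collection-time skipif marker with a runtime check, so a single test body can target either CUDA or XPU and pick the device when it runs. Below is a minimal sketch of that selection pattern factored into a reusable helper; the helper name require_gpu_device and the hasattr guard are assumptions of this sketch (for PyTorch builds that ship without a torch.xpu module), not part of the commit.

import pytest
import torch


def require_gpu_device() -> torch.device:
    """Return the first available GPU-class device, skipping the test otherwise."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # torch.xpu only exists in PyTorch builds with Intel XPU support,
    # so guard the attribute access before probing availability.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device("xpu")
    pytest.skip(reason="Test is too slow on non-GPU devices")

One trade-off of the runtime approach: on CPU-only machines the parametrized cases are still collected and reported as skipped, whereas the old skipif marker expressed the condition declaratively; in exchange, one test body now serves both CUDA and XPU backends.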
