test: illustrate int4 marlin kernel bug
dacorvo committed Oct 6, 2024
1 parent 476a9dd commit 9bfb77e
Showing 1 changed file with 30 additions and 0 deletions.
@@ -118,3 +118,33 @@ def test_marlin_int4_weight_qbits_tensor_linear(batch_size, tokens, embeddings,
    qout = torch.nn.functional.linear(inputs, marlin_qweight, bias)
    out = torch.nn.functional.linear(inputs, qbt.dequantize(), bias)
    assert_similar(out, qout)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("tokens", [16, 32, 33])
def test_marlin_int4_weight_qbits_tensor_linear_bug(tokens):
    device = torch.device("cuda")
    dtype = torch.float16
    weight_qtype = qint4
    group_size = 128
    in_features = 4096
    out_features = 2048
    inputs = torch.rand((tokens, in_features), dtype=dtype, device=device)
    # Create a MarlinInt4WeightQBitsTensor from a QBitsTensor on CUDA
    qbt = random_qweight((out_features, in_features), weight_qtype, dtype, group_size=group_size, device=device)
    marlin_qweight = MarlinInt4WeightQBitsTensor(
        qtype=qbt.qtype,
        axis=qbt.axis,
        group_size=qbt._group_size,
        size=qbt.size(),
        stride=qbt.stride(),
        data=qbt._data.unpack(),
        scale=qbt._scale,
        shift=qbt._shift,
    )
    # Compare the Marlin int4 kernel output against the dequantized reference matmul
    qout = torch.nn.functional.linear(inputs, marlin_qweight, bias=None)
    out = torch.nn.functional.linear(inputs, qbt.dequantize(), bias=None)
    max_val = out.abs().max()
    max_err = (out - qout).abs().max()
    print(max_val, max_err)
    assert max_err / max_val < 1e-2
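
To reproduce the discrepancy locally, a minimal invocation sketch (a CUDA device is required; the exact test file location in the repository is not shown here, so selecting the test by name with -k is assumed to be sufficient):

pytest -s -k test_marlin_int4_weight_qbits_tensor_linear_bug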
