
Commit 9bfb77e

test: illustrate int4 marlin kernel bug
1 parent 476a9dd commit 9bfb77e

File tree

1 file changed: +30 −0 lines changed

test/tensor/weights/optimized/test_marlin_int4_weight_qbits_tensor.py

Lines changed: 30 additions & 0 deletions
@@ -118,3 +118,33 @@ def test_marlin_int4_weight_qbits_tensor_linear(batch_size, tokens, embeddings,
     qout = torch.nn.functional.linear(inputs, marlin_qweight, bias)
     out = torch.nn.functional.linear(inputs, qbt.dequantize(), bias)
     assert_similar(out, qout)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.parametrize("tokens", [16, 32, 33])
+def test_marlin_int4_weight_qbits_tensor_linear_bug(tokens):
+    device = torch.device("cuda")
+    dtype = torch.float16
+    weight_qtype = qint4
+    group_size = 128
+    in_features = 4096
+    out_features = 2048
+    inputs = torch.rand((tokens, in_features), dtype=dtype, device=device)
+    # Create a MarlinInt4WeightQBitsTensor from a QBitsTensor on CUDA
+    qbt = random_qweight((out_features, in_features), weight_qtype, dtype, group_size=group_size, device=torch.device("cuda"))
+    marlin_qweight = MarlinInt4WeightQBitsTensor(
+        qtype=qbt.qtype,
+        axis=qbt.axis,
+        group_size=qbt._group_size,
+        size=qbt.size(),
+        stride=qbt.stride(),
+        data=qbt._data.unpack(),
+        scale=qbt._scale,
+        shift=qbt._shift,
+    )
+    qout = torch.nn.functional.linear(inputs, marlin_qweight, bias=None)
+    out = torch.nn.functional.linear(inputs, qbt.dequantize(), bias=None)
+    max_val = out.abs().max()
+    max_err = (out - qout).abs().max()
+    print(max_val, max_err)
+    assert max_err / max_val < 1e-2
