@@ -118,3 +118,35 @@ def test_marlin_int4_weight_qbits_tensor_linear(batch_size, tokens, embeddings,
     qout = torch.nn.functional.linear(inputs, marlin_qweight, bias)
     out = torch.nn.functional.linear(inputs, qbt.dequantize(), bias)
     assert_similar(out, qout)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.parametrize("tokens", [16, 32, 33])
+def test_marlin_int4_weight_qbits_tensor_linear_bug(tokens):
+    device = torch.device("cuda")
+    dtype = torch.float16
+    weight_qtype = qint4
+    group_size = 128
+    in_features = 4096
+    out_features = 2048
+    inputs = torch.rand((tokens, in_features), dtype=dtype, device=device)
+    # Create a MarlinInt4WeightQBitsTensor from a QBitsTensor on CUDA
+    qbt = random_qweight((out_features, in_features), weight_qtype, dtype, group_size=group_size, device=device)
+    marlin_qweight = MarlinInt4WeightQBitsTensor(
+        qtype=qbt.qtype,
+        axis=qbt.axis,
+        group_size=qbt._group_size,
+        size=qbt.size(),
+        stride=qbt.stride(),
+        data=qbt._data.unpack(),
+        scale=qbt._scale,
+        shift=qbt._shift,
+    )
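+    # Compare the output of the Marlin int4 weight tensor against the dequantized reference weights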
+    qout = torch.nn.functional.linear(inputs, marlin_qweight, bias=None)
+    out = torch.nn.functional.linear(inputs, qbt.dequantize(), bias=None)
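+    # The largest absolute error must stay below 1% of the largest absolute reference value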
+    max_val = out.abs().max()
+    max_err = (out - qout).abs().max()
+    print(max_val, max_err)
+    assert max_err / max_val < 1e-2