
Commit 2504888

Add w8a8 quantized matmul kernel (#9278)
1 parent e51af25 commit 2504888

2 files changed: +467 -0 lines changed
Lines changed: 157 additions & 0 deletions
@@ -0,0 +1,157 @@
from typing import List, Optional, Tuple
import unittest
from unittest.mock import patch

from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax._src import test_util as jtu
from torch_xla.experimental.pallas_kernels.quantized_matmul_kernel import (
    quantized_matmul_int8,
    get_tuned_block_sizes,
    TUNED_BLOCK_SIZES,
)
import jax.numpy as jnp
import numpy as np

jax.config.parse_flags_with_absl()


def quantize_array(x, n_bits: int = 8, dim: int = -1):
  max_val = jnp.max(jnp.abs(x), axis=dim, keepdims=True)
  int_min = -2**(n_bits - 1)
  int_max = 2**(n_bits - 1) - 1
  scale = max_val / int_max
  x_int = jnp.clip(jnp.round((x / scale)), int_min, int_max).astype(jnp.int8)
  return x_int, scale.astype(x.dtype)


@jtu.with_config(jax_numpy_dtype_promotion="standard")
class QuantizedMatmulKernelTest(jtu.JaxTestCase):

  def setUp(self):
    super().setUp()
    if not jtu.is_device_tpu_at_least(5):
      self.skipTest(
          'This kernel requires a Mosaic feature not available for TPU v4 or earlier.'
      )

  def _test_quantized_matmul(self,
                             dtype,
                             bs,
                             n_input_features,
                             n_output_features,
                             quantize_activation,
                             batch_block_size=None,
                             out_block_size=None,
                             in_block_size=None,
                             atol=1.5):

    prng_key = jax.random.key(1234)
    k0, k1 = jax.random.split(prng_key, 2)
    x = jax.random.normal(k0, (bs, n_input_features), dtype=dtype)
    w = jax.random.normal(
        k1, (n_output_features, n_input_features), dtype=dtype)
    x_copy = x.copy()
    w_copy = w.copy()
    q_w, scalar_w = quantize_array(w)
    scalar_w = jnp.squeeze(scalar_w)
    assert scalar_w.shape == (n_output_features,)

    output = quantized_matmul_int8(
        x,
        q_w,
        scalar_w,
        quantize_activation=quantize_activation,
        batch_block_size=batch_block_size,
        out_block_size=out_block_size,
        in_block_size=in_block_size).block_until_ready()
    expected = jax.lax.dot_general(
        x_copy, w_copy, dimension_numbers=(((1,), (1,)), ((), ())))

    self.assertEqual(output.dtype, expected.dtype)
    self.assertEqual(output.shape, expected.shape)
    self.assertAllClose(output, expected, atol=atol)

  @parameterized.product(
      dtype=[jnp.bfloat16, jnp.float32],
      bs=[128, 256, 512],
      n_input_features=[128, 256, 512],
      n_output_features=[128, 256, 512],
      quantize_activation=[True],
  )
  def test_quantized_matmul_various_input_shapes(self, dtype, bs,
                                                 n_input_features,
                                                 n_output_features,
                                                 quantize_activation):
    self._test_quantized_matmul(
        dtype,
        bs,
        n_input_features,
        n_output_features,
        quantize_activation=quantize_activation,
        batch_block_size=128,
        out_block_size=128,
        in_block_size=128)

  @parameterized.product(
      dtype=[jnp.bfloat16, jnp.float32],
      bs=[64, 192],
      n_input_features=[64, 192],
      n_output_features=[64, 192],
      quantize_activation=[True],
  )
  def test_quantized_matmul_unaligned_input_shapes(self, dtype, bs,
                                                   n_input_features,
                                                   n_output_features,
                                                   quantize_activation):
    self._test_quantized_matmul(
        dtype,
        bs,
        n_input_features,
        n_output_features,
        quantize_activation=quantize_activation,
        batch_block_size=128,
        out_block_size=128,
        in_block_size=128)

  @patch(
      'torch_xla.experimental.pallas_kernels.quantized_matmul_kernel.get_tpu_version'
  )
  def test_quantized_matmul_retrieve_block_sizes(self, get_tpu_version):
    tpu_version_to_use = 6
    get_tpu_version.return_value = tpu_version_to_use
    key0 = None
    for key, expected_block_sizes in TUNED_BLOCK_SIZES.items():
      if key[0] == tpu_version_to_use:
        key0 = key
        break
    expected_block_sizes = TUNED_BLOCK_SIZES[key0]
    _, bs, n_output_features, n_input_features, activation_dtype, quantize_activation = key0
    actual_block_sizes = get_tuned_block_sizes(bs, n_output_features,
                                               n_input_features,
                                               activation_dtype,
                                               quantize_activation)
    assert actual_block_sizes == expected_block_sizes, f"Expected block sizes {expected_block_sizes}, but got {actual_block_sizes} for key {key0}"

  @parameterized.product(
      dtype=[jnp.bfloat16],
      bs=[16],
      n_input_features=[128, 256],
      n_output_features=[128, 256],
      quantize_activation=[True],
  )
  def test_quantized_matmul_use_tuned_block_sizes(self, dtype, bs,
                                                  n_input_features,
                                                  n_output_features,
                                                  quantize_activation):
    self._test_quantized_matmul(
        dtype,
        bs,
        n_input_features,
        n_output_features,
        quantize_activation=quantize_activation)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())
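
For readers skimming the diff, a minimal plain-JAX sketch of the w8a8 arithmetic this test exercises, assuming the same symmetric int8 quantization as quantize_array above (weights quantized per output feature, activations per row). The helper name reference_w8a8_matmul is hypothetical and for illustration only; the kernel added by this commit is quantized_matmul_int8, exercised in the tests above.

import jax
import jax.numpy as jnp


def reference_w8a8_matmul(x, q_w, scale_w):
  # Hypothetical reference for the w8a8 math only, not the Pallas kernel.
  # x: (bs, n_in) float activations; q_w: (n_out, n_in) int8 weights;
  # scale_w: (n_out,) per-output-feature scales, as produced by quantize_array.
  # Symmetric per-row int8 quantization of the activations (the "a8" part).
  max_x = jnp.max(jnp.abs(x), axis=-1, keepdims=True)
  scale_x = (max_x / 127.0).astype(x.dtype)
  q_x = jnp.clip(jnp.round(x / scale_x), -128, 127).astype(jnp.int8)
  # Integer matmul contracting the n_in dimension, accumulated in int32.
  acc = jax.lax.dot_general(
      q_x.astype(jnp.int32), q_w.astype(jnp.int32),
      dimension_numbers=(((1,), (1,)), ((), ())))
  # Undo both quantization scales to return to the activation dtype.
  return (acc * scale_x * scale_w[None, :]).astype(x.dtype)

Under this model, the test's assertAllClose(output, expected, atol=1.5) bounds the combined weight- and activation-quantization error against the full-precision jax.lax.dot_general reference.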
