|
| 1 | +from typing import List, Optional, Tuple |
| 2 | +import unittest |
| 3 | +from unittest.mock import patch |
| 4 | + |
| 5 | +from absl.testing import absltest |
| 6 | +from absl.testing import parameterized |
| 7 | +import jax |
| 8 | +from jax._src import test_util as jtu |
| 9 | +from torch_xla.experimental.pallas_kernels.quantized_matmul_kernel import ( |
| 10 | + quantized_matmul_int8, |
| 11 | + get_tuned_block_sizes, |
| 12 | + TUNED_BLOCK_SIZES, |
| 13 | +) |
| 14 | +import jax.numpy as jnp |
| 15 | +import numpy as np |
| 16 | + |
| 17 | +jax.config.parse_flags_with_absl() |
| 18 | + |
| 19 | + |
| 20 | +def quantize_array(x, n_bits: int = 8, dim: int = -1): |
| 21 | + max_val = jnp.max(jnp.abs(x), axis=dim, keepdims=True) |
| 22 | + int_min = -2**(n_bits - 1) |
| 23 | + int_max = 2**(n_bits - 1) - 1 |
| 24 | + scale = max_val / int_max |
| 25 | + x_int = jnp.clip(jnp.round((x / scale)), int_min, int_max).astype(jnp.int8) |
| 26 | + return x_int, scale.astype(x.dtype) |
| 27 | + |
| 28 | + |
| 29 | +@jtu.with_config(jax_numpy_dtype_promotion="standard") |
| 30 | +class QuantizedMatmulKernelTest(jtu.JaxTestCase): |
| 31 | + |
| 32 | + def setUp(self): |
| 33 | + super().setUp() |
| 34 | + if not jtu.is_device_tpu_at_least(5): |
| 35 | + self.skipTest( |
| 36 | + 'This kernel requires a Mosaic feature not available for TPU v4 or earlier.' |
| 37 | + ) |
| 38 | + |
| 39 | + def _test_quantized_matmul(self, |
| 40 | + dtype, |
| 41 | + bs, |
| 42 | + n_input_features, |
| 43 | + n_output_features, |
| 44 | + quantize_activation, |
| 45 | + batch_block_size=None, |
| 46 | + out_block_size=None, |
| 47 | + in_block_size=None, |
| 48 | + atol=1.5): |
| 49 | + |
| 50 | + prng_key = jax.random.key(1234) |
| 51 | + k0, k1 = jax.random.split(prng_key, 2) |
| 52 | + x = jax.random.normal(k0, (bs, n_input_features), dtype=dtype) |
| 53 | + w = jax.random.normal( |
| 54 | + k1, (n_output_features, n_input_features), dtype=dtype) |
| 55 | + x_copy = x.copy() |
| 56 | + w_copy = w.copy() |
| 57 | + q_w, scalar_w = quantize_array(w) |
| 58 | + scalar_w = jnp.squeeze(scalar_w) |
| 59 | + assert scalar_w.shape == (n_output_features,) |
| 60 | + |
| 61 | + output = quantized_matmul_int8( |
| 62 | + x, |
| 63 | + q_w, |
| 64 | + scalar_w, |
| 65 | + quantize_activation=quantize_activation, |
| 66 | + batch_block_size=batch_block_size, |
| 67 | + out_block_size=out_block_size, |
| 68 | + in_block_size=in_block_size).block_until_ready() |
| 69 | + expected = jax.lax.dot_general( |
| 70 | + x_copy, w_copy, dimension_numbers=(((1,), (1,)), ((), ()))) |
| 71 | + |
| 72 | + self.assertEqual(output.dtype, expected.dtype) |
| 73 | + self.assertEqual(output.shape, expected.shape) |
| 74 | + self.assertAllClose(output, expected, atol=atol) |
| 75 | + |
| 76 | + @parameterized.product( |
| 77 | + dtype=[jnp.bfloat16, jnp.float32], |
| 78 | + bs=[128, 256, 512], |
| 79 | + n_input_features=[128, 256, 512], |
| 80 | + n_output_features=[128, 256, 512], |
| 81 | + quantize_activation=[True], |
| 82 | + ) |
| 83 | + def test_quantized_matmul_various_input_shapes(self, dtype, bs, |
| 84 | + n_input_features, |
| 85 | + n_output_features, |
| 86 | + quantize_activation): |
| 87 | + self._test_quantized_matmul( |
| 88 | + dtype, |
| 89 | + bs, |
| 90 | + n_input_features, |
| 91 | + n_output_features, |
| 92 | + quantize_activation=quantize_activation, |
| 93 | + batch_block_size=128, |
| 94 | + out_block_size=128, |
| 95 | + in_block_size=128) |
| 96 | + |
| 97 | + @parameterized.product( |
| 98 | + dtype=[jnp.bfloat16, jnp.float32], |
| 99 | + bs=[64, 192], |
| 100 | + n_input_features=[64, 192], |
| 101 | + n_output_features=[64, 192], |
| 102 | + quantize_activation=[True], |
| 103 | + ) |
| 104 | + def test_quantized_matmul_unaligned_input_shapes(self, dtype, bs, |
| 105 | + n_input_features, |
| 106 | + n_output_features, |
| 107 | + quantize_activation): |
| 108 | + self._test_quantized_matmul( |
| 109 | + dtype, |
| 110 | + bs, |
| 111 | + n_input_features, |
| 112 | + n_output_features, |
| 113 | + quantize_activation=quantize_activation, |
| 114 | + batch_block_size=128, |
| 115 | + out_block_size=128, |
| 116 | + in_block_size=128) |
| 117 | + |
| 118 | + @patch( |
| 119 | + 'torch_xla.experimental.pallas_kernels.quantized_matmul_kernel.get_tpu_version' |
| 120 | + ) |
| 121 | + def test_quantized_matmul_retrieve_block_sizes(self, get_tpu_version): |
| 122 | + tpu_version_to_use = 6 |
| 123 | + get_tpu_version.return_value = tpu_version_to_use |
| 124 | + key0 = None |
| 125 | + for key, expected_block_sizes in TUNED_BLOCK_SIZES.items(): |
| 126 | + if key[0] == tpu_version_to_use: |
| 127 | + key0 = key |
| 128 | + break |
| 129 | + expected_block_sizes = TUNED_BLOCK_SIZES[key0] |
| 130 | + _, bs, n_output_features, n_input_features, activation_dtype, quantize_activation = key0 |
| 131 | + actual_block_sizes = get_tuned_block_sizes(bs, n_output_features, |
| 132 | + n_input_features, |
| 133 | + activation_dtype, |
| 134 | + quantize_activation) |
| 135 | + assert actual_block_sizes == expected_block_sizes, f"Expected block sizes {expected_block_sizes}, but got {actual_block_sizes} for key {key0}" |
| 136 | + |
| 137 | + @parameterized.product( |
| 138 | + dtype=[jnp.bfloat16], |
| 139 | + bs=[16], |
| 140 | + n_input_features=[128, 256], |
| 141 | + n_output_features=[128, 256], |
| 142 | + quantize_activation=[True], |
| 143 | + ) |
| 144 | + def test_quantized_matmul_use_tuned_block_sizes(self, dtype, bs, |
| 145 | + n_input_features, |
| 146 | + n_output_features, |
| 147 | + quantize_activation): |
| 148 | + self._test_quantized_matmul( |
| 149 | + dtype, |
| 150 | + bs, |
| 151 | + n_input_features, |
| 152 | + n_output_features, |
| 153 | + quantize_activation=quantize_activation) |
| 154 | + |
| 155 | + |
| 156 | +if __name__ == "__main__": |
| 157 | + absltest.main(testLoader=jtu.JaxTestLoader()) |
0 commit comments