
Commit 2ce56cc

all tests pass.
1 parent: 60a9b96

File tree: 2 files changed, +116 -69 lines changed


test/test_pallas.py

Lines changed: 37 additions & 54 deletions
@@ -890,11 +890,10 @@ def _test_quantized_matmul(
       in_block_size=None,
       atol=1.5,
       n_bits=8,
-  ):
+  ):
     x = torch.randn((bs, n_input_features), dtype=dtype)
     w = torch.randn((n_output_features, n_input_features), dtype=dtype)
-    min_val, max_val = torch.aminmax(
-        w, dim=1)  # min_val, max_val [out_dim]
+    min_val, max_val = torch.aminmax(w, dim=1)  # min_val, max_val [out_dim]
     int_min = -2**(n_bits - 1)
     int_max = 2**(n_bits - 1) - 1
     scalar, zero_point = determine_qparams(
@@ -913,21 +912,30 @@ def _test_quantized_matmul(
     x_copy = x.clone()
     w_copy = w.clone()
     expected = F.linear(x_copy, w_copy)
-
+
     x_xla = x.to("xla")
     w_int_xla = w_int.to("xla")
     scalar_xla = scalar.to("xla")
     if use_dynamo:
-      def quantized_matmul_wrapper(x, w_int, scalar):
-        return torch.ops.xla.quantized_matmul(
-            x, w_int, scalar, quantize_activation=quantize_activation, batch_block_size=batch_block_size,
-            out_block_size=out_block_size, in_block_size=in_block_size)
 
-      quantized_matmul = torch.compile(quantized_matmul_wrapper, backend="openxla")
+      def quantized_matmul_wrapper(x, w_int, scalar, quantize_activation,
+                                   batch_block_size, out_block_size,
+                                   in_block_size):
+        return torch.ops.xla.quantized_matmul(
+            x,
+            w_int,
+            scalar,
+            quantize_activation=quantize_activation,
+            batch_block_size=batch_block_size,
+            out_block_size=out_block_size,
+            in_block_size=in_block_size)
+
+      quantized_matmul = torch.compile(
+          quantized_matmul_wrapper, backend="openxla")
     else:
       from torch_xla.experimental.custom_kernel import quantized_matmul
       quantized_matmul = quantized_matmul
-
+
     actual = quantized_matmul(
         x_xla,
         w_int_xla,
@@ -936,68 +944,43 @@ def quantized_matmul_wrapper(x, w_int, scalar):
         batch_block_size=batch_block_size,
         out_block_size=out_block_size,
         in_block_size=in_block_size).cpu()
-
+
     self.assertEqual(actual.shape, expected.shape)
     self.assertEqual(actual.dtype, expected.dtype)
-    self.assertTrue(
-        torch.allclose(
-            actual, expected, atol=atol))
+    self.assertTrue(torch.allclose(actual, expected, atol=atol))
 
-
-  @parameterized.product(
-      seq_lens=[[(1, 1328), (5, 18), (500, 563)]],
-      num_heads=[(32, 8), (8, 1)],
-      dtype=[(torch.bfloat16, torch.bfloat16),
-             (torch.bfloat16, torch.float8_e5m2)],
-      sm_scale=[1.0, 0.5],
-      sliding_window=[None, 128],
-      soft_cap=[None, 10.0],
-      pad_tokens_and_seqs=[False, True])
-  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
-                   "This test only works on TPUv4+.")
-  def test_quantized_matmul_with_dynamo(
-      self,
-      seq_lens,
-      num_heads,
-      dtype,
-      sm_scale,
-      sliding_window,
-      soft_cap,
-      pad_tokens_and_seqs,
-  ):
-    ...
-
-  # @parameterized.product(
-  #     dtype=[torch.bfloat16],
-  #     bs=[128],
-  #     n_input_features=[128],
-  #     n_output_features=[128],
-  #     quantize_activation=[True],
-  #     # block_sizes=[(None, None, None), (128, 128, 128)],
-  #     kernel_block_sizes=[(128, 128, 128)],
-  # )
   @parameterized.product(
       dtype=[torch.bfloat16, torch.float32],
-      bs=[128, 256],
-      n_input_features=[128, 256],
-      n_output_features=[128, 256],
+      bs=[256, 512],
+      n_input_features=[256, 512],
+      n_output_features=[256, 512],
       quantize_activation=[True],
-      # block_sizes=[(None, None, None), (128, 128, 128)],
-      kernel_block_sizes=[(128, 128, 128)],
+      kernel_block_sizes=[(None, None, None), (256, 256, 256)],
+      use_dynamo=[True, False],
   )
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 5,
                    "This test only works on TPUv5+.")
-  def test_quantized_matmul_wrapper_without_dynamo(
+  def test_quantized_matmul_wrapper(
       self,
       dtype,
       bs,
       n_input_features,
       n_output_features,
       quantize_activation,
      kernel_block_sizes,
+      use_dynamo,
   ):
     batch_block_size, out_block_size, in_block_size = kernel_block_sizes
-    self._test_quantized_matmul(dtype, bs, n_input_features, n_output_features, quantize_activation, use_dynamo=False, batch_block_size=batch_block_size, out_block_size=out_block_size, in_block_size=in_block_size)
+    self._test_quantized_matmul(
+        dtype,
+        bs,
+        n_input_features,
+        n_output_features,
+        quantize_activation,
+        use_dynamo=use_dynamo,
+        batch_block_size=batch_block_size,
+        out_block_size=out_block_size,
+        in_block_size=in_block_size)
 
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                    "This test only works on TPUv4+.")

torch_xla/experimental/custom_kernel.py

Lines changed: 79 additions & 15 deletions
@@ -1061,26 +1061,32 @@ def ragged_paged_attention(
   ])
   return output[0]
 
+
 @requires_jax
 def quantized_matmul(
-  x: torch.Tensor,
-  w: torch.Tensor,
-  scalar: torch.Tensor,
-  zero_point: torch.Tensor | None = None,
-  block_size: torch.Tensor | None = None,
-  quantize_activation: bool = False,
-  batch_block_size: int | None = None,
-  out_block_size: int | None = None,
-  in_block_size: int | None = None,
-  vmem_limit_bytes: int | None = 64 * 1024 * 1024,
+    x: torch.Tensor,
+    w: torch.Tensor,
+    scalar: torch.Tensor,
+    zero_point: torch.Tensor | None = None,
+    block_size: torch.Tensor | None = None,
+    quantize_activation: bool = False,
+    batch_block_size: int | None = None,
+    out_block_size: int | None = None,
+    in_block_size: int | None = None,
+    vmem_limit_bytes: int | None = 64 * 1024 * 1024,
 ) -> torch.Tensor:
   from torch_xla.experimental.pallas_kernels.quantized_matmul_kernel import quantized_matmul
   return xb.call_jax(
-      quantized_matmul,
-      (x, w, scalar, zero_point, block_size, quantize_activation),
-      {"batch_block_size": batch_block_size, "out_block_size": out_block_size, "in_block_size": in_block_size, "vmem_limit_bytes": vmem_limit_bytes}
-  )
-
+      quantized_matmul, (x, w, scalar), {
+          "zero_point": zero_point,
+          "block_size": block_size,
+          "quantize_activation": quantize_activation,
+          "batch_block_size": batch_block_size,
+          "out_block_size": out_block_size,
+          "in_block_size": in_block_size,
+          "vmem_limit_bytes": vmem_limit_bytes
+      })
+
 
 def _multi_queries_paged_attention_nonkernel(
     q,  # [batch_size, query_len, num_heads, head_size]
@@ -1646,3 +1652,61 @@ def gmm_non_xla(lhs: torch.Tensor,
 
   # we only need to return the tensor with correct shape for meta tensor.
   return torch.empty(lhs.size()[0], rhs_dim_size, device=lhs.device)
+
+
+# @requires_jax
+# def quantized_matmul(
+#     x: torch.Tensor,
+#     w: torch.Tensor,
+#     scalar: torch.Tensor,
+#     zero_point: torch.Tensor | None = None,
+#     block_size: torch.Tensor | None = None,
+#     quantize_activation: bool = False,
+#     batch_block_size: int | None = None,
+#     out_block_size: int | None = None,
+#     in_block_size: int | None = None,
+#     vmem_limit_bytes: int | None = 64 * 1024 * 1024,
+# ) -> torch.Tensor:
+
+XLA_LIB.define(
+    "quantized_matmul(Tensor x, Tensor w, Tensor scalar, Tensor? zero_point=None, Tensor? block_size=None, bool quantize_activation=False, int? batch_block_size=None, int? out_block_size=None, int? in_block_size=None, int? vmem_limit_bytes=None) -> Tensor",
+)
+
+
+@impl(XLA_LIB, "quantized_matmul", "XLA")
+def quantized_matmul_xla(
+    x: torch.Tensor,
+    w: torch.Tensor,
+    scalar: torch.Tensor,
+    zero_point: torch.Tensor | None = None,
+    block_size: torch.Tensor | None = None,
+    quantize_activation: bool = False,
+    batch_block_size: int | None = None,
+    out_block_size: int | None = None,
+    in_block_size: int | None = None,
+    vmem_limit_bytes: int | None = 64 * 1024 * 1024,
+) -> torch.Tensor:
+  return quantized_matmul(x, w, scalar, zero_point, block_size,
+                          quantize_activation, batch_block_size, out_block_size,
+                          in_block_size, vmem_limit_bytes)
+
+
+@impl(XLA_LIB, "quantized_matmul", "CompositeExplicitAutograd")
+def quantized_matmul_non_xla(
+    x: torch.Tensor,
+    w: torch.Tensor,
+    scalar: torch.Tensor,
+    zero_point: torch.Tensor | None = None,
+    block_size: torch.Tensor | None = None,
+    quantize_activation: bool = False,
+    batch_block_size: int | None = None,
+    out_block_size: int | None = None,
+    in_block_size: int | None = None,
+    vmem_limit_bytes: int | None = 64 * 1024 * 1024,
+) -> torch.Tensor:
+  # This will be called when dynamo use fake tensor to construct the fake output.
+  # We need to make sure output tensor's shape is correct.
+  if x.device != torch.device("meta"):
+    warnings.warn(
+        f'XLA quantized_matmul should only be applied to tensors on XLA device')
+  return torch.empty(x.shape[0], w.shape[0], device=x.device)
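A hedged usage sketch of the op this commit registers (not taken from the repo): once torch_xla.experimental.custom_kernel has been imported, torch.ops.xla.quantized_matmul should be callable on XLA tensors both eagerly and through torch.compile with the openxla backend, matching the two paths the updated test exercises. The op name and keyword arguments follow the XLA_LIB.define schema above; tensor shapes and the per-output-channel scale layout are assumptions based on the test.

# Assumes a TPU-backed torch_xla install; shapes and scale layout are illustrative.
import torch
import torch_xla.experimental.custom_kernel  # noqa: F401  (registers torch.ops.xla.quantized_matmul)

x = torch.randn(256, 256, dtype=torch.bfloat16).to("xla")                   # activations [bs, in]
w_int = torch.randint(-128, 128, (512, 256), dtype=torch.int8).to("xla")    # int8 weights [out, in]
scalar = torch.randn(512, dtype=torch.float32).abs().to("xla")              # per-output-channel scales

# Eager (LazyTensor) path, analogous to the test's use_dynamo=False branch.
out = torch.ops.xla.quantized_matmul(
    x, w_int, scalar,
    quantize_activation=True,
    batch_block_size=256, out_block_size=256, in_block_size=256)

# Dynamo path, mirroring quantized_matmul_wrapper in the updated test.
def wrapper(x, w_int, scalar):
  return torch.ops.xla.quantized_matmul(x, w_int, scalar, quantize_activation=True)

compiled = torch.compile(wrapper, backend="openxla")
out_compiled = compiled(x, w_int, scalar)
print(out.shape, out_compiled.shape)  # both should be [256, 512]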
