Commit 60a9b96

non-dynamo case passed.
1 parent f105d46 commit 60a9b96

2 files changed: +13 -13 lines changed


test/test_pallas.py

Lines changed: 12 additions & 11 deletions

@@ -939,8 +939,9 @@ def quantized_matmul_wrapper(x, w_int, scalar):
 
     self.assertEqual(actual.shape, expected.shape)
     self.assertEqual(actual.dtype, expected.dtype)
-    torch.testing.assert_close(
-        actual, expected, atol=1.5)
+    self.assertTrue(
+        torch.allclose(
+            actual, expected, atol=atol))
 
 
   @parameterized.product(
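
The new check wraps torch.allclose in self.assertTrue instead of calling torch.testing.assert_close, and compares against the atol chosen by the test rather than a hard-coded 1.5. A minimal, self-contained sketch of the pattern (the tolerance value and the helper name check_close are illustrative assumptions, not part of this commit):

import torch

def check_close(actual: torch.Tensor, expected: torch.Tensor, atol: float) -> bool:
  # Mirrors the assertion above: shapes and dtypes must match exactly,
  # values only need to agree within an absolute tolerance.
  assert actual.shape == expected.shape
  assert actual.dtype == expected.dtype
  return torch.allclose(actual, expected, atol=atol)

expected = torch.randn(128, 128, dtype=torch.bfloat16)
actual = expected + 0.01 * torch.randn(128, 128, dtype=torch.bfloat16)
print(check_close(actual, expected, atol=1.5))  # atol=1.5 is illustrative only
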
@@ -967,19 +968,19 @@ def test_quantized_matmul_with_dynamo(
     ...
 
   # @parameterized.product(
-  #     dtype=[torch.bfloat16, torch.float32],
-  #     bs=[128, 256],
-  #     n_input_features=[128, 256],
-  #     n_output_features=[128, 256],
+  #     dtype=[torch.bfloat16],
+  #     bs=[128],
+  #     n_input_features=[128],
+  #     n_output_features=[128],
   #     quantize_activation=[True],
   #     # block_sizes=[(None, None, None), (128, 128, 128)],
-  #     block_sizes=[(128, 128, 128)],
+  #     kernel_block_sizes=[(128, 128, 128)],
   # )
   @parameterized.product(
-      dtype=[torch.bfloat16],
-      bs=[128],
-      n_input_features=[128],
-      n_output_features=[128],
+      dtype=[torch.bfloat16, torch.float32],
+      bs=[128, 256],
+      n_input_features=[128, 256],
+      n_output_features=[128, 256],
       quantize_activation=[True],
       # block_sizes=[(None, None, None), (128, 128, 128)],
       kernel_block_sizes=[(128, 128, 128)],
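
For reference, the restored @parameterized.product sweep runs the test once per combination of the listed values (2 dtypes x 2 batch sizes x 2 input widths x 2 output widths = 16 cases). A self-contained sketch of that mechanism, using only the shape axes and a plain matmul as a stand-in for the Pallas kernel (the class and test names are hypothetical):

import torch
from absl.testing import absltest, parameterized

class QuantizedMatmulSweepSketch(parameterized.TestCase):

  @parameterized.product(
      dtype=[torch.bfloat16, torch.float32],
      bs=[128, 256],
      n_input_features=[128, 256],
      n_output_features=[128, 256],
  )
  def test_output_shape(self, dtype, bs, n_input_features, n_output_features):
    # One invocation per element of the cross product of the keyword lists.
    x = torch.randn(bs, n_input_features, dtype=dtype)
    w = torch.randn(n_output_features, n_input_features, dtype=dtype)
    self.assertEqual((x @ w.T).shape, (bs, n_output_features))

if __name__ == "__main__":
  absltest.main()
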

torch_xla/experimental/custom_kernel.py

Lines changed: 1 addition & 2 deletions

@@ -1068,7 +1068,6 @@ def quantized_matmul(
     scalar: torch.Tensor,
     zero_point: torch.Tensor | None = None,
     block_size: torch.Tensor | None = None,
-    int4_weight: bool = False,
     quantize_activation: bool = False,
     batch_block_size: int | None = None,
     out_block_size: int | None = None,

@@ -1078,7 +1077,7 @@ def quantized_matmul(
   from torch_xla.experimental.pallas_kernels.quantized_matmul_kernel import quantized_matmul
   return xb.call_jax(
       quantized_matmul,
-      (x, w, scalar, zero_point, block_size, int4_weight, quantize_activation),
+      (x, w, scalar, zero_point, block_size, quantize_activation),
       {"batch_block_size": batch_block_size, "out_block_size": out_block_size, "in_block_size": in_block_size, "vmem_limit_bytes": vmem_limit_bytes}
   )
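
With int4_weight removed from both the Python signature and the positional tuple forwarded through xb.call_jax, callers pass only the remaining arguments. A rough usage sketch, assuming an XLA device is available; the shapes, dtypes, weight layout, and block sizes below are illustrative and not taken from this commit:

import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import quantized_matmul

device = xm.xla_device()

# Illustrative layout: [bs, in_features] activations, [out_features, in_features] int8 weights.
x = torch.randn(128, 256, dtype=torch.bfloat16, device=device)
w_int = torch.randint(-128, 127, (512, 256), dtype=torch.int8, device=device)
scalar = torch.randn(512, dtype=torch.bfloat16, device=device)  # per-output-channel scales

# No int4_weight argument anymore; the keywords match the updated signature.
out = quantized_matmul(
    x,
    w_int,
    scalar,
    quantize_activation=True,
    batch_block_size=128,
    out_block_size=128,
    in_block_size=128,
)
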
