@@ -890,11 +890,10 @@ def _test_quantized_matmul(
       in_block_size=None,
       atol=1.5,
       n_bits=8,
-  ):
+  ):
     x = torch.randn((bs, n_input_features), dtype=dtype)
     w = torch.randn((n_output_features, n_input_features), dtype=dtype)
-    min_val, max_val = torch.aminmax(
-        w, dim=1)  # min_val, max_val [out_dim]
+    min_val, max_val = torch.aminmax(w, dim=1)  # min_val, max_val [out_dim]
     int_min = -2**(n_bits - 1)
     int_max = 2**(n_bits - 1) - 1
     scalar, zero_point = determine_qparams(
@@ -913,21 +912,30 @@ def _test_quantized_matmul(
     x_copy = x.clone()
     w_copy = w.clone()
     expected = F.linear(x_copy, w_copy)
-
+
     x_xla = x.to("xla")
     w_int_xla = w_int.to("xla")
     scalar_xla = scalar.to("xla")
     if use_dynamo:
-      def quantized_matmul_wrapper(x, w_int, scalar):
-        return torch.ops.xla.quantized_matmul(
-            x, w_int, scalar, quantize_activation=quantize_activation, batch_block_size=batch_block_size,
-            out_block_size=out_block_size, in_block_size=in_block_size)

-      quantized_matmul = torch.compile(quantized_matmul_wrapper, backend="openxla")
+      def quantized_matmul_wrapper(x, w_int, scalar, quantize_activation,
+                                   batch_block_size, out_block_size,
+                                   in_block_size):
+        return torch.ops.xla.quantized_matmul(
+            x,
+            w_int,
+            scalar,
+            quantize_activation=quantize_activation,
+            batch_block_size=batch_block_size,
+            out_block_size=out_block_size,
+            in_block_size=in_block_size)
+
+      quantized_matmul = torch.compile(
+          quantized_matmul_wrapper, backend="openxla")
     else:
       from torch_xla.experimental.custom_kernel import quantized_matmul
       quantized_matmul = quantized_matmul
-
+
     actual = quantized_matmul(
         x_xla,
         w_int_xla,
@@ -936,68 +944,43 @@ def quantized_matmul_wrapper(x, w_int, scalar):
         batch_block_size=batch_block_size,
         out_block_size=out_block_size,
         in_block_size=in_block_size).cpu()
-
+
     self.assertEqual(actual.shape, expected.shape)
     self.assertEqual(actual.dtype, expected.dtype)
-    self.assertTrue(
-        torch.allclose(
-            actual, expected, atol=atol))
+    self.assertTrue(torch.allclose(actual, expected, atol=atol))

-
-  @parameterized.product(
-      seq_lens=[[(1, 1328), (5, 18), (500, 563)]],
-      num_heads=[(32, 8), (8, 1)],
-      dtype=[(torch.bfloat16, torch.bfloat16),
-             (torch.bfloat16, torch.float8_e5m2)],
-      sm_scale=[1.0, 0.5],
-      sliding_window=[None, 128],
-      soft_cap=[None, 10.0],
-      pad_tokens_and_seqs=[False, True])
-  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
-                   "This test only works on TPUv4+.")
-  def test_quantized_matmul_with_dynamo(
-      self,
-      seq_lens,
-      num_heads,
-      dtype,
-      sm_scale,
-      sliding_window,
-      soft_cap,
-      pad_tokens_and_seqs,
-  ):
-    ...
-
-  # @parameterized.product(
-  #     dtype=[torch.bfloat16],
-  #     bs=[128],
-  #     n_input_features=[128],
-  #     n_output_features=[128],
-  #     quantize_activation=[True],
-  #     # block_sizes=[(None, None, None), (128, 128, 128)],
-  #     kernel_block_sizes=[(128, 128, 128)],
-  # )
   @parameterized.product(
       dtype=[torch.bfloat16, torch.float32],
-      bs=[128, 256],
-      n_input_features=[128, 256],
-      n_output_features=[128, 256],
+      bs=[256, 512],
+      n_input_features=[256, 512],
+      n_output_features=[256, 512],
       quantize_activation=[True],
-      # block_sizes=[(None, None, None), (128, 128, 128)],
-      kernel_block_sizes=[(128, 128, 128)],
+      kernel_block_sizes=[(None, None, None), (256, 256, 256)],
+      use_dynamo=[True, False],
   )
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 5,
                    "This test only works on TPUv5+.")
-  def test_quantized_matmul_wrapper_without_dynamo(
+  def test_quantized_matmul_wrapper(
       self,
       dtype,
       bs,
       n_input_features,
       n_output_features,
       quantize_activation,
       kernel_block_sizes,
+      use_dynamo,
   ):
     batch_block_size, out_block_size, in_block_size = kernel_block_sizes
-    self._test_quantized_matmul(dtype, bs, n_input_features, n_output_features, quantize_activation, use_dynamo=False, batch_block_size=batch_block_size, out_block_size=out_block_size, in_block_size=in_block_size)
+    self._test_quantized_matmul(
+        dtype,
+        bs,
+        n_input_features,
+        n_output_features,
+        quantize_activation,
+        use_dynamo=use_dynamo,
+        batch_block_size=batch_block_size,
+        out_block_size=out_block_size,
+        in_block_size=in_block_size)

   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                    "This test only works on TPUv4+.")