Skip to content

Commit b2fb664

Browse files
authored
Add int8 dynamic activation + int8 weight only test to TensorParallel (#1657)
1 parent 7e54629 commit b2fb664

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

test/dtypes/test_affine_quantized_tensor_parallel.py

+13
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
float8_dynamic_activation_float8_weight,
1414
float8_weight_only,
1515
int4_weight_only,
16+
int8_dynamic_activation_int8_weight,
1617
int8_weight_only,
1718
)
1819
from torchao.quantization.observer import PerRow, PerTensor
@@ -166,9 +167,21 @@ def test_tp_gemlite(self, dtype):
166167
return self._test_tp(dtype)
167168

168169

class TestInt8dqAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
    """Tensor-parallel test for int8 dynamic activation + int8 weight quantization.

    Follows the same pattern as the sibling TP test classes in this file:
    the base class's ``_test_tp`` drives the actual sharded-model check, and
    this subclass only selects the quantization method and dtypes to cover.
    """

    # Quantization config under test; staticmethod() prevents the function
    # from being bound as an instance method when accessed via the class.
    QUANT_METHOD_FN = staticmethod(int8_dynamic_activation_int8_weight)
    # Only bfloat16 is exercised here — TODO(review): confirm float16/float32
    # are intentionally excluded for this quantization scheme.
    COMMON_DTYPES = [torch.bfloat16]

    @common_utils.parametrize("dtype", COMMON_DTYPES)
    @with_comms
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_tp(self, dtype):
        # Delegate to the shared tensor-parallel test harness in the base class.
        return self._test_tp(dtype)
# Expand each @parametrize'd test class into concrete unittest cases
# (one per dtype) so the test runner can discover and run them.
common_utils.instantiate_parametrized_tests(TestInt8woAffineQuantizedTensorParallel)
common_utils.instantiate_parametrized_tests(TestInt4woAffineQuantizedTensorParallel)
common_utils.instantiate_parametrized_tests(TestGemliteLayoutTensorParallel)
common_utils.instantiate_parametrized_tests(TestInt8dqAffineQuantizedTensorParallel)

173186
# Run only on H100
174187
if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0):

0 commit comments

Comments
 (0)