Skip to content

Commit

Permalink
Fix tests + JIT
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Sep 1, 2024
1 parent b0be412 commit e97b733
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 15 deletions.
4 changes: 2 additions & 2 deletions src/brevitas/core/function_wrapper/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def __init__(self, expanded_groupwise_shape, group_size, group_dim) -> None:
@brevitas.jit.script_method
def forward(self, x: torch.Tensor):
y = torch.nn.functional.pad(
x, padding(x, self.group_size, self.group_dim), mode='constant', value=0)
x, padding(x, self.group_size, self.group_dim), mode='constant', value=0.)
y = y.view(self.expanded_groupwise_shape)
return y

Expand All @@ -186,7 +186,7 @@ def forward(self, x):
tensor_shape_list = list(tensor_shape)
pad = padding(x, self.group_size, self.group_dim)

x = torch.nn.functional.pad(x, pad, mode='constant', value=0)
x = torch.nn.functional.pad(x, pad, mode='constant', value=0.)

tensor_shape = x.shape
tensor_shape_list = list(tensor_shape)
Expand Down
4 changes: 2 additions & 2 deletions src/brevitas/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

import copy
from typing import Optional, Tuple
from typing import List, Optional, Tuple

import torch
from torch.nn import Sequential
Expand Down Expand Up @@ -105,7 +105,7 @@ def float_internal_scale(


@brevitas.jit.ignore
def padding(x, group_size, group_dim):
def padding(x: torch.Tensor, group_size: int, group_dim: int) -> List[int]:
# Given a tensor X, compute the padding aloing group_dim so that groupwise shaping is possible
padding = [0, 0] * len(x.shape)
size = x.shape
Expand Down
12 changes: 6 additions & 6 deletions tests/brevitas/nn/nn_quantizers_fixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def build_case_model(
weight_quant_name, weight_quantizer = weight_quantizer
bias_quant_name, bias_quantizer = bias_quantizer
io_quant_name, io_quantizer = io_quantizer
print(io_quant_name)

if ((io_quantizer is None and not input_quantized) or
'float' in io_quant_name) and weight_quant_name in A2Q_WBIOL_WEIGHT_QUANTIZER:
pytest.skip(
Expand All @@ -134,8 +134,6 @@ def build_case_model(
'mx' not in io_quant_name) or ('mx' not in weight_quant_name and 'mx' in io_quant_name):
pytest.skip("MX requires input and weights quantization to be aligned")
elif weight_quantizer == MXInt8Weight:
if config.JIT_ENABLED:
pytest.skip("Dynamic act quant is not compatible with JIT")
if bias_quant_name != 'quant_internal':
pytest.skip("MX quant does not support external scaled bias")
elif weight_quantizer == Fp8e4m3WeightPerTensorFloat or io_quantizer == Fp8e4m3ActPerTensorFloat:
Expand Down Expand Up @@ -640,16 +638,18 @@ def case_mha(

# Change the case_id based on current value of Parameters
set_case_id(request.node.callspec.id, case_mha)
k, weight_quantizer = weight_quantizer
weight_quant_name, weight_quantizer = weight_quantizer
_, bias_quantizer = bias_quantizer
_, io_quantizer = io_quantizer

if io_quantizer is None and k in A2Q_WBIOL_WEIGHT_QUANTIZER:
if io_quantizer is None and weight_quant_name in A2Q_WBIOL_WEIGHT_QUANTIZER:
# Can't rely on a QuantTensor input for quant_mha at this point
pytest.skip(
"A2Q uses an input-aware decoupled weight proxy that requires a quantized input tensor."
)

# TODO: restore compatibility
if ('mx' in weight_quant_name or 'float' in weight_quant_name):
pytest.skip("MX/Float quant not supported for MHA")
# BatchQuant1d works over 3d input but not 2d, so we have a separate quantizer for out_proj
if isinstance(io_quantizer, tuple):
io_quantizer, out_proj_io_quantizer = io_quantizer
Expand Down
5 changes: 0 additions & 5 deletions tests/brevitas/nn/test_nn_quantizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,6 @@ def test_quant_mha(model_input, current_cases):
args = case_id.split('-')[1:] # Exclude first argument
kwargs = parse_args(args)

# TODO: restore compatibility
skipped_quant = ['quant_mx', 'quant_float']
if kwargs['io_quant'] in skipped_quant or kwargs['weight_quant'] in skipped_quant:
pytest.skip("MX and Float quant not supported for MHA")

is_input_quanttensor = kwargs['io_quant'] is not None or kwargs['input_quantized']
if (not is_input_quanttensor or
kwargs['weight_quant'] is None) and kwargs['bias_quant'] == 'quant_external':
Expand Down

0 comments on commit e97b733

Please sign in to comment.