Skip to content

Commit

Permalink
Fix local loss tests + JIT
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Sep 11, 2024
1 parent 1d5d365 commit 858bb8c
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 1 deletion.
2 changes: 1 addition & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def tests_brevitas_cpu(session, pytorch, jit_status):
@nox.parametrize("pytorch", PYTORCH_VERSIONS, ids=PYTORCH_IDS)
@nox.parametrize("jit_status", JIT_STATUSES, ids=JIT_IDS)
def tests_brevitas_examples_cpu(session, pytorch, jit_status):
session.env['PYTORCH_JIT'] = '{}'.format(int(jit_status == 'jit_enabled'))
session.env['BREVITAS_JIT'] = '{}'.format(int(jit_status == 'jit_enabled'))
install_pytorch(pytorch, session)
install_torchvision(pytorch, session) # For CV eval scripts
session.install('--upgrade', '.[test, tts, stt, vision]')
Expand Down
2 changes: 2 additions & 0 deletions tests/brevitas_examples/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ def run_test_models_run_args(args, model_with_ppl):
def toggle_run_args(default_run_args, request):
args = default_run_args
args.update(**request.param)
if args.weight_param_method == 'hqo' and config.JIT_ENABLED:
pytest.skip("Local loss mode requires JIT to be disabled")
yield args


Expand Down
3 changes: 3 additions & 0 deletions tests/brevitas_examples/test_quantize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from brevitas.nn import QuantReLU
from brevitas.quant_tensor import QuantTensor
from brevitas_examples.imagenet_classification.ptq.ptq_common import quantize_model
from tests.marker import jit_disabled_for_local_loss
from tests.marker import jit_disabled_for_mock

# CONSTANTS
IMAGE_DIM = 16
Expand Down Expand Up @@ -568,6 +570,7 @@ def test_layerwise_percentile_for_calibration(simple_model, act_quant_percentile


@pytest.mark.parametrize("quant_granularity", ["per_tensor", "per_channel"])
@jit_disabled_for_local_loss()
def test_layerwise_param_method_mse(simple_model, quant_granularity):
"""
We test layerwise quantization, with the weight and activation quantization `mse` parameter
Expand Down
10 changes: 10 additions & 0 deletions tests/marker.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,16 @@ def skip_wrapper(f):
return skip_wrapper


def jit_disabled_for_local_loss():
skip = config.JIT_ENABLED

def skip_wrapper(f):
return pytest.mark.skipif(
skip, reason=f'Local loss functions (e.g., MSE) require JIT to be disabled')(f)

return skip_wrapper


def jit_disabled_for_dynamic_quant_act():
skip = config.JIT_ENABLED

Expand Down

0 comments on commit 858bb8c

Please sign in to comment.