Alternative approach to support torch.compile #1006

Merged (25 commits) on Sep 23, 2024
2 changes: 1 addition & 1 deletion noxfile.py
@@ -174,4 +174,4 @@ def tests_brevitas_end_to_end(session, pytorch):
install_pytorch(pytorch, session)
install_torchvision(pytorch, session)
session.install('--upgrade', '-e', '.[test, ort_integration]')
session.run('pytest', '-v', 'tests/brevitas_end_to_end')
session.run('pytest', '-n', 'logical', '-v', 'tests/brevitas_end_to_end')
6 changes: 6 additions & 0 deletions src/brevitas/__init__.py
@@ -23,6 +23,12 @@
else:
torch_version = version.parse(torch.__version__)

try:
# Attempt _dynamo import
is_dynamo_compiling = torch._dynamo.is_compiling
except:
is_dynamo_compiling = lambda: False

try:
__version__ = get_distribution(__name__).version
except DistributionNotFound:
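The shim above makes `brevitas.is_dynamo_compiling` safe to call on any PyTorch version: it resolves to `torch._dynamo.is_compiling` when available and to a constant `False` otherwise. A minimal sketch of how such a flag can be used to steer around torch.compile-unfriendly paths; `eager_impl`, `scripted_impl` and `quant_round` are illustrative names, not Brevitas API, and only `brevitas.is_dynamo_compiling` comes from this diff.

```python
# Illustrative only: guard a TorchScript path with the new shim.
import torch
import brevitas


def eager_impl(x: torch.Tensor) -> torch.Tensor:
    return torch.clamp(torch.round(x), -128, 127)


scripted_impl = torch.jit.script(eager_impl)


def quant_round(x: torch.Tensor) -> torch.Tensor:
    # torch.compile may not trace through TorchScript functions, so fall
    # back to the plain eager implementation while dynamo is compiling.
    if brevitas.is_dynamo_compiling():
        return eager_impl(x)
    return scripted_impl(x)
```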
49 changes: 30 additions & 19 deletions src/brevitas/core/function_wrapper/clamp.py
@@ -113,40 +113,51 @@ def __init__(
else:
self.max_available_float = None

def inf_nan_clamp(self, x, inf_mask, p_max_val_mask, n_max_val_mask):

# if non-saturating, we need to map values greater than max_val to nan or inf
if self.inf_values is not None:
# we have inf values, so we set abs values > max_value to +- inf, and leave inf at inf
x[p_max_val_mask] = torch.tensor(float('inf'))
x[n_max_val_mask] = torch.tensor(float('-inf'))
elif self.nan_values is not None:
# no inf values, so we need to map them to NaN
full_max_val_mask = torch.logical_or(p_max_val_mask, n_max_val_mask)
x[full_max_val_mask] = torch.tensor(float('nan'))

# we also map the inf values to NaN in this case
x[inf_mask] = torch.tensor(float('nan'))
else:
raise RuntimeError(
"Clamping is not saturating, but neither `inf_values` nor `nan_values` is specified"
)
return x

def saturating_clamp(self, x, max_value, min_value):
return self.tensor_clamp_impl(x, min_val=min_value, max_val=max_value)

nickfraser marked this conversation as resolved.
@brevitas.jit.script_method
def forward(
self,
x: Tensor,
exponent_bit_width: Tensor,
mantissa_bit_width: Tensor,
exponent_bias: Tensor):
inf_mask = x.isinf()

max_value = max_float(exponent_bit_width, mantissa_bit_width, exponent_bias)
max_value = max_value if self.max_available_float is None else torch.min(
max_value, self.max_available_float())
min_value = torch.tensor(0.) if not self.signed else -max_value

# Compute masks
inf_mask = x.isinf()
p_max_val_mask = x > max_value
n_max_val_mask = -x > max_value
min_float = torch.tensor(0.) if not self.signed else -max_value

# first clamp everything to +- max_value, basically the saturating case
x = self.tensor_clamp_impl(x, min_val=min_float, max_val=max_value)
x = self.saturating_clamp(x, max_value, min_value)

if not self.saturating:
# if non-saturating, we need to map values greater than max_val to nan or inf
if self.inf_values is not None:
# we have inf values, so we set abs values > max_value to +- inf, and leave inf at inf
x[p_max_val_mask] = torch.tensor(float('inf'))
x[n_max_val_mask] = torch.tensor(float('-inf'))
elif self.nan_values is not None:
# no inf values, so we need to map them to NaN
full_max_val_mask = torch.logical_or(p_max_val_mask, n_max_val_mask)
x[full_max_val_mask] = torch.tensor(float('nan'))

# we also map the inf values to NaN in this case
x[inf_mask] = torch.tensor(float('nan'))
else:
raise RuntimeError(
"Clamping is not saturating, but neither `inf_values` nor `nan_values` is specified"
)
x = self.inf_nan_clamp(x, inf_mask, p_max_val_mask, n_max_val_mask)

return x, self.saturating, self.inf_values, self.nan_values
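A standalone toy illustration (not Brevitas code) of the two steps the refactor separates into `saturating_clamp` and `inf_nan_clamp`: everything is first clamped to the representable range, then, for a non-saturating format with NaN but no inf encodings, out-of-range and inf inputs are remapped to NaN. The sample values and the max of 448.0 (the FP8 E4M3 maximum) are for illustration only.

```python
import torch

max_value = torch.tensor(448.)
x = torch.tensor([100., 500., -500., float('inf')])

# Masks are computed before clamping so out-of-range inputs can be recovered.
inf_mask = x.isinf()
p_max_val_mask = x > max_value
n_max_val_mask = -x > max_value

# Saturating step: clamp everything to [-max_value, max_value].
x = torch.clamp(x, -max_value, max_value)

# Non-saturating step for a NaN-only format: out-of-range and inf inputs
# become NaN instead of keeping the saturated value.
x[torch.logical_or(p_max_val_mask, n_max_val_mask)] = float('nan')
x[inf_mask] = float('nan')
# x -> tensor([100., nan, nan, nan])
```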
4 changes: 2 additions & 2 deletions src/brevitas/export/common/handler/qcdq.py
@@ -454,7 +454,7 @@ def prepare_for_export(self, module):
self.symbolic_kwargs['exponent_bit_width'] = module.exponent_bit_width()
self.symbolic_kwargs['mantissa_bit_width'] = module.mantissa_bit_width()
self.symbolic_kwargs['exponent_bias'] = module.exponent_bias()
self.symbolic_kwargs['saturating'] = module.saturating()
self.symbolic_kwargs['saturating'] = module.is_saturating()
self.symbolic_kwargs['inf_values'] = module.inf_values()
self.symbolic_kwargs['nan_values'] = module.nan_values()

@@ -659,7 +659,7 @@ def prepare_for_export(self, module):
'exponent_bit_width': module.exponent_bit_width(),
'mantissa_bit_width': module.mantissa_bit_width(),
'exponent_bias': module.exponent_bias(),
'saturating': module.saturating(),
'saturating': module.is_saturating(),
'inf_values': module.inf_values(),
'nan_values': module.nan_values()}

5 changes: 5 additions & 0 deletions src/brevitas/export/inference/__init__.py
@@ -0,0 +1,5 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from .manager import InferenceManager
from .manager import quant_inference_mode
153 changes: 153 additions & 0 deletions src/brevitas/export/inference/handler.py
@@ -0,0 +1,153 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from abc import ABC
from abc import abstractmethod
from typing import Tuple

import torch

from brevitas.function.ops import max_float
from brevitas.function.ops import max_int
from brevitas.function.ops import min_int
from brevitas.proxy.float_parameter_quant import WeightFloatQuantProxyFromInjector
from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjector
from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjectorBase
from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector
from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector
from brevitas.utils.torch_utils import float_internal_scale


class InferenceHandler(torch.nn.Module, ABC):

def attach_debug_info(self, module):
pass

@abstractmethod
def prepare_for_export(self, module):
pass

@abstractmethod
def quantize(self, x):
pass

@abstractmethod
def dequantize(self, x):
pass


class IntInferencetHandler(InferenceHandler):
handled_layer = (ActQuantProxyFromInjector, BiasQuantProxyFromInjector)

def attach_debug_info(self, module):
pass

def prepare_for_export(self, module):
if module.is_quant_enabled:
self.scale = module.scale()
self.zero_point = module.zero_point().to(self.scale.device)
self.bit_width = module.bit_width()
self.min_clamp = min_int(module.is_signed, module.is_narrow_range, self.bit_width)
self.max_clamp = max_int(module.is_signed, module.is_narrow_range, self.bit_width)

def quantize(self, x):
return torch.clamp(
torch.round(x / self.scale + self.zero_point), self.min_clamp, self.max_clamp)

Review thread on quantize:

Collaborator: It looks like these won't work with Groupwise quantization, correct? So inference_mode + MX won't work?

Collaborator (author): You're right, I forgot to add the export handler for MX INT and MX Float.

Collaborator (author): Postponed to another update.

def dequantize(self, x):
return (x - self.zero_point) * self.scale

def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]:
return self.dequantize(self.quantize(x)), self.scale, self.zero_point, self.bit_width
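A worked numeric example (illustrative, not part of the diff) of the quantize/dequantize round trip that IntInferencetHandler implements, using an unsigned 8-bit quantizer with scale 0.1 and zero point 0.

```python
import torch

scale, zero_point = torch.tensor(0.1), torch.tensor(0.)
min_clamp, max_clamp = 0, 255  # min_int / max_int for unsigned, non-narrow 8 bit

x = torch.tensor([0.04, 1.23, 99.9])
q = torch.clamp(torch.round(x / scale + zero_point), min_clamp, max_clamp)
# q   -> tensor([  0.,  12., 255.])
deq = (q - zero_point) * scale
# deq -> tensor([ 0.0000,  1.2000, 25.5000]); 99.9 saturates to 25.5
```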


class IntWeightInferencetHandler(IntInferencetHandler):
handled_layer = WeightQuantProxyFromInjector

def prepare_for_export(self, module):
if module.is_quant_enabled:
self.cached_weight = None
super().prepare_for_export(module)
if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only:
self.cached_weight = module._cached_weight.value

def forward(self, x) -> Tuple[torch.Tensor]:
if self.cached_weight is not None:
x = self.cached_weight
else:
x = self.dequantize(self.quantize(x))
return x, self.scale, self.zero_point, self.bit_width


class FloatInferencetHandler(InferenceHandler):
handled_layer = (ActFloatQuantProxyFromInjector, BiasQuantProxyFromInjector)

def prepare_for_export(self, module):
if module.is_quant_enabled:
self.scale = module.scale()
self.zero_point = module.zero_point().to(self.scale.device)
self.exponent_bit_width = module.exponent_bit_width()
self.mantissa_bit_width = module.mantissa_bit_width()
self.exponent_bias = module.exponent_bias()
self.saturating = module.is_saturating()
self.inf_values = module.inf_values()
self.nan_values = module.nan_values()
self.eps = torch.finfo(self.scale.dtype).tiny
if hasattr(module.tensor_quant, 'float_to_int_impl'):
self.float_to_int_impl = module.tensor_quant.float_to_int_impl
self.float_clamp_impl = module.tensor_quant.float_clamp_impl
elif hasattr(module, 'fused_activation_quant_proxy'):
self.float_to_int_impl = module.fused_activation_quant_proxy.tensor_quant.float_to_int_impl
self.float_clamp_impl = module.fused_activation_quant_proxy.tensor_quant.float_clamp_impl

self.max_clamp = max_float(
self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias)
self.min_clamp = -self.max_clamp
self.fp_internal_scale_min = 1. - self.exponent_bias - self.mantissa_bit_width
self.max_value = max_float(
self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias)
self.min_value = torch.tensor(0.) if not module.is_signed else -self.max_value

def quantize(self, x):
# Compute masks
inf_mask = x.isinf()
p_max_val_mask = x > self.max_value
n_max_val_mask = -x > self.max_value

# Quantize
x = x / self.scale
internal_scale = float_internal_scale(
x, self.mantissa_bit_width, self.fp_internal_scale_min, self.eps)
x = internal_scale * self.float_to_int_impl(x / internal_scale)

# Clamp
x = self.float_clamp_impl.saturating_clamp(x, self.max_value, self.min_value)
if not self.saturating:
x = self.float_clamp_impl.inf_nan_clamp(x, inf_mask, p_max_val_mask, n_max_val_mask)

return x

def dequantize(self, x):
return (x - self.zero_point) * self.scale

def forward(self, x) -> Tuple[torch.Tensor]:
return self.dequantize(self.quantize(x)), self.scale, self.zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values
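`quantize` above delegates the per-element rounding grid to `float_internal_scale` from `brevitas.utils.torch_utils`. As a rough sketch only, assuming the usual minifloat formulation (this is a reading of the intent, not a verbatim copy of that helper): the scale is 2 raised to floor(log2|x|) minus the mantissa width, clamped below at `fp_internal_scale_min` so subnormals share one fixed grid.

```python
import torch


def internal_scale_sketch(x, mantissa_bit_width, fp_internal_scale_min, eps):
    # Per-element exponent, with eps guarding log2(0).
    scale = torch.floor(torch.log2(torch.abs(x) + eps)) - mantissa_bit_width
    # Clamp to the smallest representable (subnormal) grid.
    scale = torch.clamp_min(scale, fp_internal_scale_min)
    return torch.exp2(scale)


# Rounding x onto the minifloat grid then reads:
#   x_q = internal_scale * round(x / internal_scale)
```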


class FloatWeightInferencetHandler(FloatInferencetHandler):
handled_layer = WeightFloatQuantProxyFromInjector

def prepare_for_export(self, module):
if module.is_quant_enabled:
self.cached_weight = None
super().prepare_for_export(module)
if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only:
self.cached_weight = module._cached_weight.value

def forward(self, x) -> Tuple[torch.Tensor]:
if self.cached_weight is not None:
x = self.cached_weight
else:
x = self.dequantize(self.quantize(x))
return x, self.scale, self.zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values
106 changes: 106 additions & 0 deletions src/brevitas/export/inference/manager.py
@@ -0,0 +1,106 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from torch.nn import Module
import torch.nn as nn

from brevitas.export.inference.handler import FloatInferencetHandler
from brevitas.export.inference.handler import FloatWeightInferencetHandler
from brevitas.export.inference.handler import IntInferencetHandler
from brevitas.export.inference.handler import IntWeightInferencetHandler
from brevitas.export.manager import _set_proxy_export_handler
from brevitas.export.manager import _set_proxy_export_mode
from brevitas.export.manager import _set_recurrent_layer_export_handler
from brevitas.export.manager import _set_recurrent_layer_export_mode
from brevitas.export.manager import BaseManager
from brevitas.graph.calibrate import disable_return_quant_tensor
from brevitas.graph.calibrate import restore_return_quant_tensor


def _override_caching_mode(m: nn.Module, attr: str, enabled: bool, metadata_only: bool = True):
cache_var = 'cache_inference_quant_' + attr
cache_var_metadata_only = cache_var + '_metadata_only'
if hasattr(m, cache_var):
setattr(m, cache_var, enabled)
setattr(m, cache_var_metadata_only, metadata_only)


def _override_bias_caching_mode(m: nn.Module, enabled: bool, metadata_only: bool = True):
_override_caching_mode(m, 'bias', enabled, metadata_only)


def _override_act_caching_mode(m: nn.Module, enabled: bool, metadata_only: bool = True):
_override_caching_mode(m, 'act', enabled, metadata_only)


def _override_weight_caching_mode(m: nn.Module, enabled: bool, metadata_only: bool = False):
_override_caching_mode(m, 'weight', enabled, metadata_only)


class quant_inference_mode:

def __init__(self, model, cache_quant_weight=False, enabled=True):
self.model = model
self.enabled = enabled
self.cache_quant_weight = cache_quant_weight
self.export_manager = InferenceManager
self.hook_list = []
self.return_quant_tensor_state = dict()

def __enter__(self):
if self.enabled:
# Register the hook and store it in the list so that it can be removed by the hook itself when called
handle = self.model.register_forward_hook(self.hook)
self.hook_list.append(handle)

# Enable bias for everything. Optionally, store the fully fake-quantized weights
self.model.apply(
lambda m: _override_bias_caching_mode(m, enabled=True, metadata_only=True))
self.model.apply(lambda m: _override_act_caching_mode(m, enabled=True))
self.model.apply(
lambda m: _override_weight_caching_mode(
m, enabled=True, metadata_only=not self.cache_quant_weight))

def __exit__(self, type, value, traceback):
# Disable all caching
# deactivate export mode
# restore return quant tensor
self.model.apply(
lambda m: _override_bias_caching_mode(m, enabled=False, metadata_only=False))
self.model.apply(
lambda m: _override_act_caching_mode(m, enabled=False, metadata_only=False))
if self.cache_quant_weight:
self.model.apply(
lambda m: _override_weight_caching_mode(m, enabled=False, metadata_only=False))
InferenceManager.set_export_mode(self.model, enabled=False)
restore_return_quant_tensor(self.model, self.return_quant_tensor_state)

def hook(self, module, inp, out):
# After one forward pass with caching enabled, we can:
# - Set the model in export mode
# - Attach export handlers
# - Disable return quant tensor since all quant metadata is cached
assert len(self.hook_list) == 1
self.hook_list[0].remove()
self.model.apply(InferenceManager.set_export_handler)
InferenceManager.set_export_mode(self.model, enabled=True)
self.return_quant_tensor_state = disable_return_quant_tensor(self.model)


# Inheritance from BaseManager is not technically needed
class InferenceManager(BaseManager):
handlers = [
IntInferencetHandler,
FloatInferencetHandler,
IntWeightInferencetHandler,
FloatWeightInferencetHandler]

@classmethod
def set_export_mode(cls, model: Module, enabled: bool):
_set_proxy_export_mode(model, enabled)
_set_recurrent_layer_export_mode(model, enabled)

@classmethod
def set_export_handler(cls, module: Module):
_set_proxy_export_handler(cls, module)
_set_recurrent_layer_export_handler(cls, module)
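A hedged usage sketch inferred from the hook above: the first forward pass inside the context caches quant metadata and switches the proxies to the inference export handlers, after which the model runs through the simplified handlers and can be wrapped with torch.compile. The QuantLinear configuration, input shape, and whether the compiled graph is break-free are placeholders/assumptions, not claims from the diff.

```python
import torch
from brevitas.export.inference import quant_inference_mode
from brevitas.nn import QuantLinear

model = QuantLinear(64, 64, bias=True, weight_bit_width=4).eval()
inp = torch.randn(8, 64)

with torch.no_grad(), quant_inference_mode(model):
    # First call triggers the forward hook: quant metadata is cached and the
    # proxies switch to the inference export handlers.
    model(inp)
    # Subsequent calls use the simplified handlers; requires PyTorch >= 2.0.
    compiled_model = torch.compile(model)
    out = compiled_model(inp)
```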
8 changes: 6 additions & 2 deletions src/brevitas/export/manager.py
@@ -166,11 +166,15 @@ def _trace_fn_dispatcher(cls, fn, input, *args, **kwargs):
@classmethod
def handler_from_module(cls, module: Module, no_inheritance=False):
for handler in cls.handlers:
if not isinstance(handler.handled_layer, tuple):
handled_classes = (handler.handled_layer,)
else:
handled_classes = handler.handled_layer
if no_inheritance:
if type(module) == handler.handled_layer:
if type(module) in handled_classes:
return handler
else:
if isinstance(module, handler.handled_layer):
if any([isinstance(module, handler) for handler in handled_classes]):
return handler
return None
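A minimal standalone sketch (toy classes, not Brevitas types) of the dispatch rule this change enables: `handled_layer` may now be either a single class or a tuple, matched by exact type when `no_inheritance` is set and by `isinstance` otherwise.

```python
class Base: ...
class Child(Base): ...


class TupleHandler:
    handled_layer = (Base, int)  # tuples are now accepted


def matches(handler, module, no_inheritance=False):
    handled = handler.handled_layer
    handled = handled if isinstance(handled, tuple) else (handled,)
    if no_inheritance:
        return type(module) in handled
    return any(isinstance(module, cls) for cls in handled)


assert matches(TupleHandler, Child())            # isinstance accepts subclasses
assert not matches(TupleHandler, Child(), True)  # exact-type match rejects them
```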
