Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

HQO for scale/zero point #937

Merged
merged 3 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def tests_brevitas_cpu(session, pytorch, jit_status):
@nox.parametrize("pytorch", PYTORCH_VERSIONS, ids=PYTORCH_IDS)
@nox.parametrize("jit_status", JIT_STATUSES, ids=JIT_IDS)
def tests_brevitas_examples_cpu(session, pytorch, jit_status):
session.env['PYTORCH_JIT'] = '{}'.format(int(jit_status == 'jit_enabled'))
session.env['BREVITAS_JIT'] = '{}'.format(int(jit_status == 'jit_enabled'))
install_pytorch(pytorch, session)
install_torchvision(pytorch, session) # For CV eval scripts
session.install('--upgrade', '.[test, tts, stt, vision]')
Expand Down
222 changes: 222 additions & 0 deletions src/brevitas/core/stats/stats_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

import brevitas
from brevitas import config
from brevitas.core.function_wrapper.misc import Identity
from brevitas.core.function_wrapper.ops_ste import ScalarClampMinSte
from brevitas.core.utils import StatelessBuffer
from brevitas.function.ops import max_int
from brevitas.quant_tensor import _unpack_quant_tensor
Expand Down Expand Up @@ -544,3 +546,223 @@ def forward(self, x):
x = self.input_view_shape_impl(x)
self.internal_candidate = self.mse_init_op(x).detach()
return self.internal_candidate


class HalfQuadraticOptimizerScale(torch.nn.Module):
# References:
# https://mobiusml.github.io/hqq_blog/
# https://github.com/mobiusml/hqq?tab=readme-ov-file

def __init__(
self,
proxy_module,
hqo_init_op_scale,
keepdim: bool,
inner_stats_input_view_shape_impl: torch.nn.Module,
scaling_min_val: Optional[float] = None,
stats_reduce_dim: Optional[int] = None,
int_scaling_impl=None,
bit_width_impl=None,
hqo_beta_scale: float = 1e5,
hqo_kappa_scale: float = 1.01,
hqo_lp_norm_scale: float = .7,
hqo_iters_scale: int = 1000):
super(HalfQuadraticOptimizerScale, self).__init__()
self.hqo_init_op = hqo_init_op_scale
self.input_view_shape_impl = inner_stats_input_view_shape_impl
self.proxy_forward = proxy_module.forward
self.set_local_loss_mode = lambda enabled: _set_local_loss_mode(proxy_module, enabled)
self.internal_candidate = None
self.hqo_iters = hqo_iters_scale
self.stats_reduce_dim = stats_reduce_dim
self.local_loss_mode: bool = False

self.beta = hqo_beta_scale
self.kappa = hqo_kappa_scale
self.lp_norm = hqo_lp_norm_scale

self.int_scaling_impl = int_scaling_impl
self.msb_clamp_bit_width_impl = bit_width_impl
if scaling_min_val is not None and scaling_min_val != 0:
self.clamp_min_ste = ScalarClampMinSte(scaling_min_val)
else:
self.clamp_min_ste = Identity()
self.keepdim = keepdim

def parameter_search(self, xl, x):
best_loss = torch.tensor(float('inf'), device=x.device, dtype=x.dtype)
candidate = xl
best_candidate = candidate
beta = self.beta
with torch.no_grad():
for i in range(0, self.hqo_iters):
self.internal_candidate = candidate
self.set_local_loss_mode(True)
quant_tensor = self.proxy_forward(x).detach()
self.set_local_loss_mode(False)
loss = torch.abs(quant_tensor.value - x).mean()

best_candidate = torch.where(loss < best_loss, candidate, best_candidate)
if loss >= best_loss:
break
best_loss = torch.min(loss, best_loss)
W_e = shrink_lp_op(x - quant_tensor.value, beta, self.lp_norm)
zero_point = quant_tensor.zero_point
num = self.input_view_shape_impl(x - W_e).detach()
den = self.input_view_shape_impl(
torch.round(quant_tensor.value / quant_tensor.scale) - zero_point).detach()
mask = (num != 0.) & (den != 0.)
if self.stats_reduce_dim is None:
candidate = masked_median(num / den, mask)
else:
candidate = masked_median(
num / den, mask, dim=self.stats_reduce_dim, keepdim=self.keepdim)
candidate = candidate.type_as(self.internal_candidate)
candidate = self.clamp_min_ste(candidate)
bit_width = self.msb_clamp_bit_width_impl()
int_threshold = self.int_scaling_impl(bit_width)
Copy link
Collaborator Author

@Giuseppe5 Giuseppe5 Apr 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Int specific implementation, maybe it should be generalized to also cover minifloat case

Copy link
Collaborator

@nickfraser nickfraser Sep 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe it should be generalized to also cover minifloat case

This looks like it hasn't been resolved to me. Is that true?

candidate = candidate * int_threshold
candidate[torch.isnan(candidate)] = self.internal_candidate[torch.isnan(candidate)]
candidate[torch.isinf(candidate)] = self.internal_candidate[torch.isinf(candidate)]
beta *= self.kappa
return best_candidate

def optimize(self, x):
x_view = self.input_view_shape_impl(x)

init = self.hqo_init_op(x_view).detach()
best_candidate = self.parameter_search(init, x_view)

# Save for evaluation by other modules (e.g. zp) invoking local loss mode
self.internal_candidate = best_candidate.detach()
torch.cuda.empty_cache()
return best_candidate

def forward(self, x):
if not self.local_loss_mode:
with torch.no_grad():
return self.optimize(x)
else:
# This is invoked for the zero-point whenever scale is being optimized first
if self.internal_candidate is None:
x = self.input_view_shape_impl(x)
self.internal_candidate = self.hqo_init_op(x).detach()
return self.internal_candidate


class HalfQuadraticOptimizerZeroPoint(torch.nn.Module):
# References:
# https://mobiusml.github.io/hqq_blog/
# https://github.com/mobiusml/hqq?tab=readme-ov-file

def __init__(
self,
proxy_module,
keepdim: bool,
hqo_init_op_zp: torch.nn.Module,
inner_stats_input_view_shape_impl: torch.nn.Module,
stats_reduce_dim: Optional[int] = None,
hqo_beta_zp: float = 1e0,
hqo_kappa_zp: float = 1.01,
hqo_lp_norm_zp: float = .5,
hqo_iters_zp: int = 1000):
super(HalfQuadraticOptimizerZeroPoint, self).__init__()
self.hqo_init_op_zp = hqo_init_op_zp
self.input_view_shape_impl = inner_stats_input_view_shape_impl
self.proxy_forward = proxy_module.forward
self.set_local_loss_mode = lambda enabled: _set_local_loss_mode(proxy_module, enabled)
self.internal_candidate = None
self.stats_reduce_dim = stats_reduce_dim
self.local_loss_mode: bool = False
self.beta = hqo_beta_zp
self.kappa = hqo_kappa_zp
self.lp_norm = hqo_lp_norm_zp
self.hqo_iters = hqo_iters_zp
self.keepdim = keepdim

def parameter_search(self, xl, x):
best_loss = torch.tensor(float('inf'), device=x.device, dtype=x.dtype)
candidate = xl
best_candidate = candidate
with torch.no_grad():
for i in range(0, self.hqo_iters):
self.internal_candidate = candidate
self.set_local_loss_mode(True)
quant_tensor = self.proxy_forward(x).detach()
self.set_local_loss_mode(False)
qt_value = self.input_view_shape_impl(quant_tensor.value)
qt_scale = self.input_view_shape_impl(quant_tensor.scale)
qt_int = self.input_view_shape_impl(quant_tensor.int())
loss = torch.abs(qt_value - x).mean()
best_candidate = torch.where(loss < best_loss, candidate, best_candidate)
if loss >= best_loss:
break
best_loss = torch.min(loss, best_loss)
W_e = shrink_lp_op(x - qt_value, self.beta, self.lp_norm)

val = self.input_view_shape_impl((x - W_e) - qt_int * qt_scale)

if self.stats_reduce_dim is None:
candidate = torch.mean(val)
else:
candidate = torch.mean(val, dim=self.stats_reduce_dim, keepdim=self.keepdim)
self.beta *= self.kappa
return best_candidate

def optimize(self, x):
x_view = self.input_view_shape_impl(x)

init = self.hqo_init_op_zp(x_view).detach()

best_candidate = self.parameter_search(init, x)

# Save for evaluation by other modules (e.g. zp) invoking local loss mode
self.internal_candidate = best_candidate.detach()
torch.cuda.empty_cache()
return best_candidate

def forward(self, x):
if not self.local_loss_mode:
with torch.no_grad():
return self.optimize(x)
else:
# This is invoked for the zero-point whenever scale is being optimized first
if self.internal_candidate is None:
x = self.input_view_shape_impl(x)
self.internal_candidate = self.hqo_init_op_zp(x).detach()
return self.internal_candidate


def masked_median(x, mask, dim=None, keepdim=False):
"""Compute the median of tensor x along dim, ignoring values where mask is False.
x and mask need to be broadcastable.

Args:
x (Tensor): Tensor to compute median of.
mask (BoolTensor): Same shape as x with True where x is valid and False
where x should be masked. Mask should not be all False in any column of
dimension dim to avoid NaNs from zero division.
dim (int, optional): Dimension to take median of. Defaults to 0.

Returns:
Tensor: Same shape as x, except dimension dim reduced.
"""
# uncomment this assert for safety but might impact performance
# assert (
# mask.sum(dim=dim).ne(0).all()
# ), "mask should not be all False in any column, causes zero division"
x_nan = x.float().masked_fill(~mask, float("nan"))
if dim is None:
x_median = x_nan.nanmedian()
else:
x_median, _ = x_nan.nanmedian(dim=dim, keepdim=keepdim)
return x_median


# Shrinking operator
def shrink_lp_op(x: Tensor, beta: float, lp_norm: float) -> Tensor:
if lp_norm == 1:
return torch.sign(x) * torch.nn.functional.relu(torch.abs(x) - 1.0 / beta)
else:
return torch.sign(x) * torch.nn.functional.relu(
torch.abs(x) - (1.0 / beta) * torch.pow(torch.abs(x), lp_norm - 1))
2 changes: 1 addition & 1 deletion src/brevitas/core/zero_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def state_dict(self, destination=None, prefix='', keep_vars=False):
output_dict = super(ParameterFromStatsFromParameterZeroPoint, self).state_dict(
destination=destination, prefix=prefix, keep_vars=keep_vars)
# Avoid saving the init value
if not self.init_done:
if not self.init_done and not config._FULL_STATE_DICT:
del output_dict[prefix + 'value']
return output_dict

Expand Down
43 changes: 41 additions & 2 deletions src/brevitas/quant/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
from brevitas.core.stats import MSE
from brevitas.core.stats import NegativeMinOrZero
from brevitas.core.stats import NegativePercentileOrZero
from brevitas.core.stats.stats_op import HalfQuadraticOptimizerScale
from brevitas.core.stats.stats_op import HalfQuadraticOptimizerZeroPoint
from brevitas.core.utils import SingleArgStatelessBuffer
from brevitas.core.zero_point import ParameterFromRuntimeZeroPoint
from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint
Expand Down Expand Up @@ -458,7 +460,7 @@ class MSEAsymmetricScaleSubInjector(MSESubInjectorBase):
stats_impl = MSE
stats_reduce_dim = (this << 1).stats_reduce_dim
device = (this << 1).device
type = (this << 1).type
dtype = (this << 1).dtype


class MSEZeroPointSubInjector(MSESubInjectorBase):
Expand All @@ -470,7 +472,7 @@ class MSEZeroPointSubInjector(MSESubInjectorBase):
stats_impl = MSE
stats_reduce_dim = (this << 1).stats_reduce_dim
device = (this << 1).device
type = (this << 1).type
dtype = (this << 1).dtype


class MSEAsymmetricScale(ExtendedInjector):
Expand Down Expand Up @@ -520,3 +522,40 @@ class MSEWeightZeroPoint(MSEZeroPoint):

class MSEActZeroPoint(MSEZeroPoint):
zero_point_impl = ParameterFromRuntimeZeroPoint


class HQOZeroPoint(ExtendedInjector):

hqo_init_op_zp = NegativeMinOrZero
inner_stats_input_view_shape_impl = this.zero_point_stats_input_view_shape_impl
stats_impl_zp = HalfQuadraticOptimizerZeroPoint

@value
def zero_point_stats_impl():
return this.stats_impl_zp


class HQOScale(ExtendedInjector):
scaling_impl_type = ScalingImplType.PARAMETER_FROM_STATS
inner_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl
stats_impl_scale = HalfQuadraticOptimizerScale

@value
def scaling_stats_impl():
return this.stats_impl_scale


class HQOAsymmetricScale(HQOScale):
hqo_init_op_scale = AbsMinMax


class HQOSymmetricScale(HQOScale):
hqo_init_op_scale = AbsMax


class HQOActZeroPoint(HQOZeroPoint):
zero_point_impl = ParameterFromRuntimeZeroPoint


class HQOWeightZeroPoint(HQOZeroPoint):
zero_point_impl = ParameterFromStatsFromParameterZeroPoint
25 changes: 25 additions & 0 deletions src/brevitas/quant/scaled_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from brevitas.core.function_wrapper import TensorClamp
from brevitas.quant.base import *
from brevitas.quant.base import HQOSymmetricScale
from brevitas.quant.solver.act import ActQuantSolver
from brevitas.quant.solver.bias import BiasQuantSolver
from brevitas.quant.solver.trunc import TruncQuantSolver
Expand Down Expand Up @@ -443,3 +444,27 @@ class Int8AccumulatorAwareZeroCenterWeightQuant(AccumulatorAwareZeroCenterWeight
>>> conv.quant_weight()
"""
bit_width = 8


class Int8WeightPerTensorFloatHQO(HQOSymmetricScale, Int8WeightPerTensorFloat):
"""
8-bit narrow per-tensor signed int weight quantizer with per-tensor floating-point scale factor computed
from HQO local loss.

Examples:
>>> from brevitas.nn import QuantLinear
>>> fc = QuantLinear(10, 5, bias=False, weight_quant=Int8WeightPerTensorFloatHQO)
"""
pass


class Int8WeightPerChannelFloatHQO(HQOSymmetricScale, Int8WeightPerChannelFloat):
"""
8-bit narrow per-tensor signed int weight quantizer with per-tensor floating-point scale factor computed
from HQO local loss.

Examples:
>>> from brevitas.nn import QuantLinear
>>> fc = QuantLinear(10, 5, bias=False, weight_quant=Int8WeightPerChannelFloatHQO)
"""
pass
Loading
Loading