Feat (core): use runtime parameter for scale (#1037)
Giuseppe5 authored Nov 7, 2024
1 parent d7d88c6 commit 84cdfc3
Showing 6 changed files with 41 additions and 13 deletions.
6 changes: 4 additions & 2 deletions src/brevitas/core/scaling/runtime.py
@@ -60,8 +60,10 @@ def __init__(
 
     @brevitas.jit.script_method
     def forward(
-            self, ignored: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
-        stats = self.parameter_list_stats()
+            self,
+            x: Optional[torch.Tensor],
+            threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
+        stats = self.parameter_list_stats(x)
         if threshold is None:
             threshold = torch.ones(1).type_as(stats)
         return self.stats_scaling_impl(stats, threshold)
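The first argument was previously named `ignored` and never read; it is now forwarded to `parameter_list_stats`, so the statistics source can be the tensor seen at call time rather than only the parameters captured at construction. A self-contained toy stand-in for the new contract (absmax is a placeholder statistic, not the configurable `stats_impl` Brevitas actually uses):

    from typing import Optional

    import torch


    class ToyStatsScaling(torch.nn.Module):
        """Toy stand-in for the patched forward(); not Brevitas code."""

        def forward(self, x: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
            stats = x.abs().max()  # plays the role of parameter_list_stats(x)
            if threshold is None:
                threshold = torch.ones(1).type_as(stats)
            return stats / threshold  # plays the role of stats_scaling_impl(stats, threshold)


    scaling = ToyStatsScaling()
    w = torch.randn(64, 32)
    print(scaling(w))                     # scale derived from the runtime tensor
    print(scaling(w, torch.tensor(2.0)))  # optional threshold still divides the result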
6 changes: 3 additions & 3 deletions src/brevitas/core/scaling/standalone.py
@@ -241,17 +241,17 @@ def __init__(
         self.value = Parameter(torch.full(scaling_shape, 1.0, dtype=dtype, device=device))
 
     @brevitas.jit.script_method
-    def forward(self, ignored: Tensor, threshold: Optional[Tensor] = None) -> Tensor:
+    def forward(self, x: Tensor, threshold: Optional[Tensor] = None) -> Tensor:
         if threshold is None:
-            threshold = torch.ones(1).type_as(ignored)
+            threshold = torch.ones(1).type_as(x)
         if self.init_done:
             threshold = self.stats_scaling_impl.restrict_clamp_threshold(
                 self.restrict_threshold_pre(threshold))
             value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(self.value))
             value = value / threshold
             return value
         else:
-            stats = self.parameter_list_stats()
+            stats = self.parameter_list_stats(x)
             # workaround to avoid find_ununsed_parameter=True in DDP
             stats = stats + 0. * self.value
             if self.local_loss_mode:
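The `stats + 0. * self.value` line kept on the stats branch is worth noting: multiplying the learned parameter by zero leaves the statistics numerically unchanged but keeps the parameter in the autograd graph, so DistributedDataParallel does not report it as unused (the comment's `find_ununsed_parameter` refers to DDP's `find_unused_parameters` flag). A minimal reproduction of the trick:

    import torch

    value = torch.nn.Parameter(torch.ones(3))
    stats = torch.randn(3)

    out = (stats + 0. * value).sum()  # numerically identical to stats.sum()
    out.backward()
    print(value.grad)  # zeros rather than None: the parameter counts as "used"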
15 changes: 11 additions & 4 deletions src/brevitas/core/stats/stats_wrapper.py
@@ -13,6 +13,7 @@
 from brevitas.core.utils import inplace_tensor_mul
 
 from .view_wrapper import _ViewCatParameterWrapper
+from .view_wrapper import _ViewParameter
 from .view_wrapper import _ViewParameterWrapper
 
 DEFAULT_MOMENTUM = 0.1
@@ -96,8 +97,12 @@ def __init__(
         super(_ParameterListStats, self).__init__()
 
         self.stats_input_concat_dim = stats_input_concat_dim
-        self.first_tracked_param = _ViewParameterWrapper(
-            tracked_parameter_list[0], stats_input_view_shape_impl)
+        if len(tracked_parameter_list) >= 1:
+            self.first_tracked_param = _ViewParameterWrapper(
+                tracked_parameter_list[0], stats_input_view_shape_impl)
+        else:
+            self.first_tracked_param = _ViewParameter(stats_input_view_shape_impl)
+
         if len(tracked_parameter_list) > 1:
             extra_list = [
                 _ViewCatParameterWrapper(
@@ -109,10 +114,12 @@
         self.stats = _Stats(stats_impl, stats_output_shape)
 
     @brevitas.jit.script_method
-    def forward(self) -> torch.Tensor:
-        stats_input = self.first_tracked_param()
+    def forward(self, x: Optional[torch.Tensor] = None) -> torch.Tensor:
         if self.extra_tracked_params_list is not None:
+            stats_input = self.first_tracked_param(None)
             for extra_tracked_param in self.extra_tracked_params_list:
                 stats_input = extra_tracked_param(stats_input)
+        else:
+            stats_input = self.first_tracked_param(x)
         out = self.stats(stats_input)
         return out
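The dispatch added here gives `_ParameterListStats` two modes: with more than one tracked parameter the old behavior is kept (statistics over the concatenated parameter views, `x` ignored), while with one or zero tracked parameters the runtime tensor is handed to `first_tracked_param`, which may use it in place of, or in the absence of, a stored parameter (see `view_wrapper.py` below). A runnable toy restating the three cases, with absmax standing in for the configurable `stats_impl`:

    from typing import List, Optional

    import torch


    class ToyListStats(torch.nn.Module):
        """Toy restatement of the dispatch above; not the Brevitas class."""

        def __init__(self, tracked: List[torch.Tensor]) -> None:
            super().__init__()
            self.tracked = tracked

        def forward(self, x: Optional[torch.Tensor] = None) -> torch.Tensor:
            if len(self.tracked) > 1:  # extra_tracked_params_list path: x is ignored
                stats_input = torch.cat([t.reshape(-1) for t in self.tracked])
            elif len(self.tracked) == 1:  # _ViewParameterWrapper path: x wins if given
                stats_input = (self.tracked[0] if x is None else x).reshape(-1)
            else:  # _ViewParameter path: no stored parameter, x is required
                stats_input = x.reshape(-1)
            return stats_input.abs().max()


    p = torch.randn(4, 4)
    print(ToyListStats([p])())               # stats from the tracked tensor
    print(ToyListStats([p])(torch.ones(2)))  # runtime tensor takes precedence
    print(ToyListStats([])(torch.ones(2)))   # empty list: runtime tensor is mandatory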
21 changes: 19 additions & 2 deletions src/brevitas/core/stats/view_wrapper.py
@@ -1,6 +1,8 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
+from typing import Optional
+
 import torch
 from torch import Tensor
 from torch.nn import Module
@@ -19,8 +21,12 @@ def __init__(self, parameter: Parameter, view_shape_impl: Module) -> None:
         self.view_shape_impl = view_shape_impl
 
     @brevitas.jit.script_method
-    def forward(self) -> Tensor:
-        return self.view_shape_impl(self.parameter)
+    def forward(self, x: Optional[Tensor]) -> Tensor:
+        if x is not None:
+            parameter = x
+        else:
+            parameter = self.parameter
+        return self.view_shape_impl(parameter)
 
     def _load_from_state_dict(
             self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
@@ -39,6 +45,17 @@ def state_dict(self, destination=None, prefix='', keep_vars=False):
         return output_dict
 
 
+class _ViewParameter(brevitas.jit.ScriptModule):
+
+    def __init__(self, view_shape_impl: Module) -> None:
+        super(_ViewParameter, self).__init__()
+        self.view_shape_impl = view_shape_impl
+
+    @brevitas.jit.script_method
+    def forward(self, x: Tensor) -> Tensor:
+        return self.view_shape_impl(x)
+
+
 class _ViewCatParameterWrapper(brevitas.jit.ScriptModule):
     __constants__ = ['cat_dim']
 
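`_ViewParameterWrapper` now prefers a runtime tensor over its stored parameter, and the new `_ViewParameter` covers the case where there is no parameter to store at all. Assuming a Brevitas checkout containing this commit (both classes are private API, and `BREVITAS_JIT` disabled so a plain module works as `view_shape_impl`), the behavior can be exercised directly:

    import torch
    from torch.nn import Parameter

    from brevitas.core.stats.view_wrapper import _ViewParameter
    from brevitas.core.stats.view_wrapper import _ViewParameterWrapper


    class FlattenView(torch.nn.Module):
        """Stand-in view_shape_impl: flattens whatever it is given."""

        def forward(self, t: torch.Tensor) -> torch.Tensor:
            return t.reshape(-1)


    wrapper = _ViewParameterWrapper(Parameter(torch.randn(4, 4)), FlattenView())
    print(wrapper(None).shape)               # falls back to the stored parameter: [16]
    print(wrapper(torch.randn(2, 3)).shape)  # runtime tensor takes precedence: [6]

    view = _ViewParameter(FlattenView())  # no stored parameter at all
    print(view(torch.randn(5, 2)).shape)  # [10]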
4 changes: 2 additions & 2 deletions src/brevitas/core/zero_point.py
@@ -82,7 +82,7 @@ def __init__(
 
     @brevitas.jit.script_method
     def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor) -> torch.Tensor:
-        stats = self.parameter_list_stats()
+        stats = self.parameter_list_stats(x)
         return self.scale_shift_zero_point(-stats, scale, bit_width)
 
 
@@ -266,7 +266,7 @@ def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor) -> torch.Tensor:
             value = self.scale_shift_zero_point(value, scale, bit_width)
             return value
         else:
-            stats = self.parameter_list_stats()
+            stats = self.parameter_list_stats(x)
             # workaround to avoid find_ununsed_parameter=True in DDP
             stats = stats + 0. * self.value
             if self.local_loss_mode:
if self.local_loss_mode:
Expand Down
2 changes: 2 additions & 0 deletions src/brevitas_examples/llm/main.py
@@ -285,6 +285,8 @@ def main(args):
 
     model = offload_model(model)
 
+    model(**calibration_loader[0])
+
     if args.act_calibration:
         print("Apply act calibration...")
         apply_calibration(model, calibration_loader)
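Each `calibration_loader` entry is a dict of model inputs, so the added line runs a single forward pass over the first calibration sample, plausibly so the scale paths that now depend on the runtime input observe real data before calibration starts. A self-contained illustration of the call shape (toy model and field names are illustrative):

    import torch


    class ToyModel(torch.nn.Module):

        def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
            return input_ids * attention_mask


    model = ToyModel()
    calibration_loader = [{
        "input_ids": torch.tensor([[101, 2054, 102]]),
        "attention_mask": torch.ones(1, 3, dtype=torch.long),
    }]

    model(**calibration_loader[0])  # the dict is unpacked into keyword arguments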
