diff --git a/src/brevitas/core/scaling/runtime.py b/src/brevitas/core/scaling/runtime.py
index 09f891ed7..a94f8cd6e 100644
--- a/src/brevitas/core/scaling/runtime.py
+++ b/src/brevitas/core/scaling/runtime.py
@@ -60,8 +60,10 @@ def __init__(
 
     @brevitas.jit.script_method
     def forward(
-            self, ignored: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
-        stats = self.parameter_list_stats()
+            self,
+            x: Optional[torch.Tensor],
+            threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
+        stats = self.parameter_list_stats(x)
         if threshold is None:
             threshold = torch.ones(1).type_as(stats)
         return self.stats_scaling_impl(stats, threshold)
diff --git a/src/brevitas/core/scaling/standalone.py b/src/brevitas/core/scaling/standalone.py
index 13ead5afc..da13e84ff 100644
--- a/src/brevitas/core/scaling/standalone.py
+++ b/src/brevitas/core/scaling/standalone.py
@@ -241,9 +241,9 @@ def __init__(
         self.value = Parameter(torch.full(scaling_shape, 1.0, dtype=dtype, device=device))
 
     @brevitas.jit.script_method
-    def forward(self, ignored: Tensor, threshold: Optional[Tensor] = None) -> Tensor:
+    def forward(self, x: Tensor, threshold: Optional[Tensor] = None) -> Tensor:
         if threshold is None:
-            threshold = torch.ones(1).type_as(ignored)
+            threshold = torch.ones(1).type_as(x)
         if self.init_done:
             threshold = self.stats_scaling_impl.restrict_clamp_threshold(
                 self.restrict_threshold_pre(threshold))
@@ -251,7 +251,7 @@ def forward(self, ignored: Tensor, threshold: Optional[Tensor] = None) -> Tensor
             value = value / threshold
             return value
         else:
-            stats = self.parameter_list_stats()
+            stats = self.parameter_list_stats(x)
             # workaround to avoid find_ununsed_parameter=True in DDP
             stats = stats + 0. * self.value
             if self.local_loss_mode:
diff --git a/src/brevitas/core/stats/stats_wrapper.py b/src/brevitas/core/stats/stats_wrapper.py
index df3cec952..49bf62a82 100644
--- a/src/brevitas/core/stats/stats_wrapper.py
+++ b/src/brevitas/core/stats/stats_wrapper.py
@@ -13,6 +13,7 @@
 from brevitas.core.utils import inplace_tensor_mul
 
 from .view_wrapper import _ViewCatParameterWrapper
+from .view_wrapper import _ViewParameter
 from .view_wrapper import _ViewParameterWrapper
 
 DEFAULT_MOMENTUM = 0.1
@@ -96,8 +97,12 @@ def __init__(
 
         super(_ParameterListStats, self).__init__()
         self.stats_input_concat_dim = stats_input_concat_dim
-        self.first_tracked_param = _ViewParameterWrapper(
-            tracked_parameter_list[0], stats_input_view_shape_impl)
+        if len(tracked_parameter_list) >= 1:
+            self.first_tracked_param = _ViewParameterWrapper(
+                tracked_parameter_list[0], stats_input_view_shape_impl)
+        else:
+            self.first_tracked_param = _ViewParameter(stats_input_view_shape_impl)
+
         if len(tracked_parameter_list) > 1:
             extra_list = [
                 _ViewCatParameterWrapper(
@@ -109,10 +114,12 @@ def __init__(
         self.stats = _Stats(stats_impl, stats_output_shape)
 
     @brevitas.jit.script_method
-    def forward(self) -> torch.Tensor:
-        stats_input = self.first_tracked_param()
+    def forward(self, x: Optional[torch.Tensor] = None) -> torch.Tensor:
         if self.extra_tracked_params_list is not None:
+            stats_input = self.first_tracked_param(None)
             for extra_tracked_param in self.extra_tracked_params_list:
                 stats_input = extra_tracked_param(stats_input)
+        else:
+            stats_input = self.first_tracked_param(x)
         out = self.stats(stats_input)
         return out
diff --git a/src/brevitas/core/stats/view_wrapper.py b/src/brevitas/core/stats/view_wrapper.py
index acea542d9..98c6ab538 100644
--- a/src/brevitas/core/stats/view_wrapper.py
+++ b/src/brevitas/core/stats/view_wrapper.py
@@ -1,6 +1,8 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
+from typing import Optional
+
 import torch
 from torch import Tensor
 from torch.nn import Module
@@ -19,8 +21,12 @@ def __init__(self, parameter: Parameter, view_shape_impl: Module) -> None:
         self.view_shape_impl = view_shape_impl
 
     @brevitas.jit.script_method
-    def forward(self) -> Tensor:
-        return self.view_shape_impl(self.parameter)
+    def forward(self, x: Optional[Tensor]) -> Tensor:
+        if x is not None:
+            parameter = x
+        else:
+            parameter = self.parameter
+        return self.view_shape_impl(parameter)
 
     def _load_from_state_dict(
             self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
@@ -39,6 +45,17 @@ def state_dict(self, destination=None, prefix='', keep_vars=False):
         return output_dict
 
 
+class _ViewParameter(brevitas.jit.ScriptModule):
+
+    def __init__(self, view_shape_impl: Module) -> None:
+        super(_ViewParameter, self).__init__()
+        self.view_shape_impl = view_shape_impl
+
+    @brevitas.jit.script_method
+    def forward(self, x: Tensor) -> Tensor:
+        return self.view_shape_impl(x)
+
+
 class _ViewCatParameterWrapper(brevitas.jit.ScriptModule):
 
     __constants__ = ['cat_dim']
diff --git a/src/brevitas/core/zero_point.py b/src/brevitas/core/zero_point.py
index 3f80f1dd4..796940f4f 100644
--- a/src/brevitas/core/zero_point.py
+++ b/src/brevitas/core/zero_point.py
@@ -82,7 +82,7 @@ def __init__(
 
     @brevitas.jit.script_method
     def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor) -> torch.Tensor:
-        stats = self.parameter_list_stats()
+        stats = self.parameter_list_stats(x)
         return self.scale_shift_zero_point(-stats, scale, bit_width)
 
 
@@ -266,7 +266,7 @@ def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor) -> torch.Tensor:
             value = self.scale_shift_zero_point(value, scale, bit_width)
             return value
         else:
-            stats = self.parameter_list_stats()
+            stats = self.parameter_list_stats(x)
             # workaround to avoid find_ununsed_parameter=True in DDP
             stats = stats + 0. * self.value
             if self.local_loss_mode:
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index f40a367e1..75c08e7c0 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -285,6 +285,8 @@ def main(args):
 
     model = offload_model(model)
 
+    model(**calibration_loader[0])
+
     if args.act_calibration:
         print("Apply act calibration...")
         apply_calibration(model, calibration_loader)