Commit dce5e81

Implement QuantizationMixin (#1351)
## Purpose ##
* Abstract functionality which allows modifiers to act as quantization configs into a mixin called `QuantizationMixin`
* This gives #1279 an interface to properly infer which pipeline to use based on the recipe (if a recipe contains modifiers that require calibration, then use the "basic" or "sequential" pipelines)
* This enables future modifiers to act as quantization modifiers (in the same way that GPTQ does now)
* Related to #1354, where previous logic would attempt to add a QuantizedKVCache for dynamic kv_quant

## Changes ##
* Implement `QuantizationMixin`, which implements five public methods
  * Lifecycle methods (illustrated in the sketch after this message)
    * `initialize_quantization` is used to apply a config and attach observers to a model
      * Quantization is disabled so that modules aren't quantized before they're calibrated
    * `start_calibration` is used to initialize calibration hooks and status
      * Quantization is enabled, since we currently quantize as we calibrate, although this decision is somewhat arbitrary
    * `end_calibration` is used to remove calibration hooks and apply the frozen status
      * Quantization remains enabled, since we want future forward passes to simulate quantization
  * Recipe-related methods
    * `has_config` returns True if a config was specified; used for checking against duplicate configs in the recipe
    * `resolve_quantization_config` returns the quantization config specified by the modifier fields
* `QuantizationModifier` inherits from `QuantizationMixin`
* `GPTQModifier` inherits from `QuantizationMixin`
  * Unlike QMod, GPTQ disables quantization during calibration. As noted before, this is a somewhat arbitrary choice, but one which matches the current implementation
* Calibration utils
  * Replace `set_unset_kv_cache` with `initialize_quantized_kv_cache` and `freeze_module_quantization`
  * Treat the `QuantizedKVCache` as analogous to another observer
  * Pull setting the calibration status out of `update_weight_zp_scale`
    * This better matches the lifecycle detailed in the `QuantizationMixin` description
  * Implement `reset_quantization_status`, which is used to remove any existing quantization configs before the current config is applied by `initialize_quantization`

## Remove Support ##
* Remove support for recipes with multiple quantization modifiers active at the same time (a check for this will be added by #1279)
* Remove `num_calibration_steps`, `quantize`, `disable_quantization_observer_epoch`, and `min_tokens_per_module`
  * `num_calibration_steps` is already controlled by https://github.com/vllm-project/llm-compressor/blob/42b62f5283d0234b26623fe1f1bf02a77c6e4019/src/llmcompressor/datasets/utils.py#L106
  * `quantize` was implemented as a workaround for GPTQ's modifier builder. Similar functionality may be required to support SpinQuant + GPTQ, but such functionality should exist at a higher level
  * `disable_quantization_observer_epoch` seems to implement functionality where a model's observers are removed but quantization remains active. This functionality is maintained by setting an "end" epoch for qmod
  * `min_tokens_per_module` requires that the modifier have references to the calibration dataset, which is disallowed by #1279. This information is already printed in GPTQ's logs. If research still wants this tool specifically for `QuantizationModifier`, then it can be reimplemented to avoid using references to the calibration dataset

## Testing ##
* Updated tests to reflect the new mixin
* Ran a set of GPTQ and QuantizationModifier examples to completion
* CI tests pass

---------

Signed-off-by: Kyle Sayers <[email protected]>
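The lifecycle methods described above are easiest to follow in order. Below is a minimal sketch of that flow, not code from this PR: the toy model, the `W8A8` scheme string, and the calibration loop are assumptions made purely for illustration, and the exact method signatures may differ from the merged implementation.

```python
import torch

from llmcompressor.modifiers.quantization import QuantizationModifier

# Stand-in model and calibration data, chosen only to make the sketch self-contained
model = torch.nn.Sequential(torch.nn.Linear(16, 16))
calibration_batches = [torch.randn(4, 16) for _ in range(8)]

modifier = QuantizationModifier(targets="Linear", scheme="W8A8")

# initialize_quantization: apply the resolved config and attach observers;
# quantization stays disabled so modules aren't quantized before calibration
modifier.initialize_quantization(model)

# start_calibration: attach calibration hooks and set the calibration status;
# quantization is enabled, since calibration and quantization currently run together
modifier.start_calibration(model)
with torch.no_grad():
    for batch in calibration_batches:
        model(batch)

# end_calibration: remove calibration hooks and apply the frozen status;
# quantization remains enabled so later forward passes simulate quantization
modifier.end_calibration(model)
```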
1 parent e168e3a commit dce5e81

File tree

8 files changed: +436 −582 lines changed

src/llmcompressor/modifiers/quantization/calibration.py (+36 −31)
```diff
@@ -3,8 +3,8 @@
 import torch
 from compressed_tensors.quantization import (
     KVCacheScaleType,
+    QuantizationScheme,
     QuantizationStatus,
-    is_attention_module,
 )
 from compressed_tensors.quantization.lifecycle.forward import forward_quantize
 from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
@@ -14,6 +14,7 @@

 from llmcompressor.modifiers.quantization.cache import QuantizedKVParameterCache
 from llmcompressor.observers import Observer
+from llmcompressor.utils.helpers import getattr_chain

 __all__ = [
     "initialize_observer",
@@ -22,9 +23,10 @@
     "calibrate_output_hook",
     "calibrate_kv_cache_input_hook",
     "calibrate_kv_cache_output_hook",
-    "set_unset_kv_cache",
+    "initialize_quantized_kv_cache",
     "freeze_module_quantization",
     "apply_calibration_status",
+    "reset_quantization_status",
 ]


@@ -49,10 +51,6 @@ def initialize_observer(
         # no quantization scheme nothing to do
         return

-    # observers have a different lifecycle for kv_cache
-    if is_attention_module(module):
-        return
-
     quantization_args = getattr(quantization_scheme, arg_name, None)
     # dont need observers for dynamic
     if quantization_args is not None and not quantization_args.dynamic:
@@ -102,25 +100,15 @@ def update_weight_zp_scale(module: Module):
     :param quantize_weights_upfront: whether to automatically
         run weight quantization at the start of calibration
     """
-    if not getattr(module, "quantization_scheme", None):
-        # no quantization scheme nothing to do
+    if getattr_chain(module, "quantization_scheme.weights", None) is None:
         return

-    status = getattr(module, "quantization_status", None)
-    if not status:
-        # not set to initialize; no scales/zp to update
-        return
-    if status != QuantizationStatus.INITIALIZED:
+    if getattr(module, "quantization_status", None) != QuantizationStatus.CALIBRATION:
         logger.warning(
-            f"Attempting set module with status {status} to calibration mode. "
-            f"but status is not {QuantizationStatus.INITIALIZED} - you may "
-            "be calibrating an uninitialized module which may fail or attempting "
-            "to re-calibrate a frozen module"
+            "Attempting to calibrate weights of a module not in calibration mode"
         )

-    if module.quantization_scheme.weights is not None:
-        # set weight scale and zero_point up front, calibration data doesn't affect it
-        call_observer(module=module, base_name="weight")
+    call_observer(module=module, base_name="weight")


 def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
@@ -200,21 +188,26 @@ def calibrate_kv_cache_output_hook(module: Module, _args: Any, _output: torch.Te
     update_parameter_data(module, v_scale, KVCacheScaleType.VALUE.value)


-def set_unset_kv_cache(module: Module):
+def initialize_quantized_kv_cache(module: Module):
     """
-    Set or unset singleton QuantizedKVParameterCache for each
-    attn module when running kv_cache quantization.
+    Initialize a quantized kv_cache on a module (analogous to initializing an observer)
+    When a config specifying kv_cache quantization is applied to a model, the kv_cache
+    args are redefined as the output_activations targeting attention modules.
+
+    This function should be called on attention modules with output_activations
     """
-    if not hasattr(module, "quantization_scheme"):
+    scheme: Optional[QuantizationScheme] = getattr(module, "quantization_scheme", None)
+    existing_kv_cache = getattr(module, "kv_cache", None)
+
+    if (
+        scheme is None
+        or not is_kv_cache_quant_scheme(scheme)
+        or isinstance(existing_kv_cache, QuantizedKVParameterCache)
+    ):
         return

-    if is_kv_cache_quant_scheme(module.quantization_scheme):
-        output_args = module.quantization_scheme.output_activations
-        kv_cache = QuantizedKVParameterCache(output_args)
-        if hasattr(module, "kv_cache"):
-            delattr(module, "kv_cache")
-        else:
-            setattr(module, "kv_cache", kv_cache)
+    quantized_kv_cache = QuantizedKVParameterCache(scheme.output_activations)
+    setattr(module, "kv_cache", quantized_kv_cache)


 def apply_calibration_status(module: Module):
@@ -242,9 +235,21 @@ def freeze_module_quantization(module: Module):
         # nothing to do, already frozen
         return

+    # remove observers
     for name in ("input", "weight", "output"):
         obs_name = f"{name}_observer"
         if hasattr(module, obs_name):
             delattr(module, obs_name)

+    # remove quantized kv_cache
+    kv_cache = getattr(module, "kv_cache", None)
+    if isinstance(kv_cache, QuantizedKVParameterCache):
+        delattr(module, "kv_cache")
+
     module.quantization_status = QuantizationStatus.FROZEN
+
+
+def reset_quantization_status(model: Module):
+    for module in model.modules():
+        if hasattr(module, "quantization_status"):
+            delattr(module, "quantization_status")
```
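As a rough orientation for how the utilities in this diff fit together, here is a sketch of applying them module by module around a calibration pass. Only the function names defined above are taken from the commit; the `calibrate` wrapper and the `run_forward_passes` callback are hypothetical and not part of this PR.

```python
from torch.nn import Module

from llmcompressor.modifiers.quantization.calibration import (
    apply_calibration_status,
    freeze_module_quantization,
    initialize_quantized_kv_cache,
    update_weight_zp_scale,
)


def calibrate(model: Module, run_forward_passes) -> None:
    for module in model.modules():
        # attaches a QuantizedKVParameterCache only when the module's scheme
        # quantizes the kv_cache; a no-op for every other module
        initialize_quantized_kv_cache(module)
        # applies the calibration status to the module
        apply_calibration_status(module)
        # sets weight scale/zero_point up front; warns if the module is not
        # in calibration mode
        update_weight_zp_scale(module)

    run_forward_passes(model)  # user-supplied calibration forward passes

    for module in model.modules():
        # removes observers and any quantized kv_cache, then marks FROZEN
        freeze_module_quantization(module)


# reset_quantization_status(model) sits outside this flow: it strips any
# leftover quantization_status before a new config is applied by
# initialize_quantization.
```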
