Commit 7586733

wip: alignment context

Signed-off-by: Kyle Sayers <[email protected]>

1 parent 559ad81

5 files changed, +67 −24 lines


examples/quantization_w4a16/llama3_example.py

Lines changed: 4 additions & 1 deletion
@@ -1,3 +1,5 @@
+import torch
+from compressed_tensors import force_cpu_offload
 from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer

@@ -9,9 +11,10 @@

 model = AutoModelForCausalLM.from_pretrained(
     MODEL_ID,
-    device_map="auto",
+    # device_map="auto",
     torch_dtype="auto",
 )
+force_cpu_offload(model, execution_device=torch.device("cuda"))
 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

 # Select calibration dataset.
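
This swaps device_map="auto" for explicit CPU offloading: weights live on CPU and are moved to the GPU only while each module executes. force_cpu_offload comes from compressed_tensors; a minimal sketch of roughly equivalent behavior using accelerate's public cpu_offload helper (an assumption about force_cpu_offload's semantics, not a verified drop-in substitute; the model ID is taken from the example's context):

# Rough functional analogue using accelerate's public API (assumed behavior)
import torch
from accelerate import cpu_offload
from transformers import AutoModelForCausalLM

MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")

# attach align-device hooks: parameters stay on CPU and are copied to the
# CUDA execution device only for the duration of each module's forward pass
cpu_offload(model, execution_device=torch.device("cuda"))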

src/llmcompressor/pipelines/basic/pipeline.py

Lines changed: 2 additions & 0 deletions
@@ -37,6 +37,8 @@ def __call__(
         :param dataloader: loads data for calibration
         :param dataset_args: dataset arguments relevant to pipelines
         """
+        # TODO: warn about cpu offloading
+
         model_device = get_execution_device(model)

         LifecycleCallbacks.calibration_epoch_start()
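
One way the TODO above might be resolved, sketched with helpers that already appear elsewhere in this commit (has_offloaded_params from compressed_tensors, the loguru logger); this is hypothetical, not part of the commit:

from compressed_tensors import has_offloaded_params
from loguru import logger

def _warn_if_offloaded(model):
    # the basic pipeline onloads/offloads every module on every forward pass,
    # so dispatched (offloaded) models pay a large memory-movement cost
    if any(has_offloaded_params(module) for module in model.modules()):
        logger.warning(
            "Model has offloaded parameters; the basic pipeline will repeatedly "
            "move weights between CPU and GPU. Consider the sequential pipeline."
        )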

src/llmcompressor/pipelines/registry.py

Lines changed: 2 additions & 2 deletions
@@ -75,12 +75,12 @@ def _validate_infer_pipeline(modifiers: List[Modifier]) -> str:
         quant_modifier = active_qmods[0]
         config = quant_modifier.resolve_quantization_config()
         if config.requires_calibration_data():
-            return "basic"
+            return "sequential"
         else:
             return "datafree"

     if any(isinstance(modifier, SmoothQuantModifier) for modifier in modifiers):
-        return "basic"
+        return "sequential"

     return "datafree"
src/llmcompressor/pipelines/sequential/helpers.py

Lines changed: 34 additions & 1 deletion
@@ -5,6 +5,7 @@
 from typing import Any, Dict, List, Optional, Set

 import torch
+from accelerate.hooks import AlignDevicesHook
 from compressed_tensors import has_offloaded_params
 from compressed_tensors.quantization import find_name_or_class_matches
 from loguru import logger
@@ -23,7 +24,12 @@

 from .ast_helpers import autowrap_forwards

-__all__ = ["trace_subgraphs", "Subgraph", "get_targets_from_modifiers"]
+__all__ = [
+    "trace_subgraphs",
+    "Subgraph",
+    "get_targets_from_modifiers",
+    "keep_onload_context",
+]


 @dataclass
@@ -485,3 +491,30 @@ def is_ancestor(module: Module) -> bool:

     is_ancestor(model)
     return ancestors
+
+
+@contextlib.contextmanager
+def keep_onload_context():
+    original_pre_forward = AlignDevicesHook.pre_forward
+    onloaded_modules = dict()
+
+    # onload once and disable any future onloading/offloading steps
+    def keep_onload_pre_forward(self: AlignDevicesHook, module, *args, **kwargs):
+        ret = original_pre_forward(self, module, *args, **kwargs)
+        if module not in onloaded_modules:
+            onloaded_modules[module] = (self, self.offload)
+            self.offload = False
+        return ret
+
+    # use the patched pre_forward function within the context
+    with patch_attr(AlignDevicesHook, "pre_forward", keep_onload_pre_forward):
+        yield
+
+    # manually offload all modules that were onloaded
+    for module, (hook, offload) in onloaded_modules.items():
+        hook.offload = offload
+        hook.post_forward(module, None)
+
+
+# def is_cpu_offloaded():
+#
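
keep_onload_context leans on patch_attr, whose import sits outside this hunk. A minimal sketch of the semantics the helper needs (the real utility also handles attributes that do not yet exist on the object; this version assumes the attribute is present):

import contextlib

@contextlib.contextmanager
def patch_attr(obj, attr_name, new_value):
    # temporarily replace obj.attr_name, restoring the original on exit
    original = getattr(obj, attr_name)
    setattr(obj, attr_name, new_value)
    try:
        yield
    finally:
        setattr(obj, attr_name, original)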

src/llmcompressor/pipelines/sequential/pipeline.py

Lines changed: 25 additions & 20 deletions
@@ -1,16 +1,17 @@
 from typing import TYPE_CHECKING

 import torch
-import tqdm
 from compressed_tensors.utils import get_execution_device
 from torch.utils.data.dataloader import DataLoader
+from tqdm import tqdm

 from llmcompressor.core import LifecycleCallbacks, active_session
 from llmcompressor.modifiers.utils.hooks import HooksMixin
 from llmcompressor.pipelines.cache import IntermediatesCache
 from llmcompressor.pipelines.registry import CalibrationPipeline
 from llmcompressor.pipelines.sequential.helpers import (
     get_targets_from_modifiers,
+    keep_onload_context,
     trace_subgraphs,
 )
 from llmcompressor.utils.helpers import DisableQuantization, calibration_forward_context
@@ -51,6 +52,8 @@ def __call__(
         """
         session = active_session()

+        # TODO: warn about not cpu offloading
+
         # prepare to trace subgraphs
         modifiers = session.get_modifiers()
         sequential_targets = get_targets_from_modifiers(modifiers, model)
@@ -59,37 +62,39 @@ def __call__(
         # trace subgraphs
         sample_input = next(iter(dataloader))
         subgraphs = trace_subgraphs(model, sample_input, sequential_targets, ignore)
+        num_subgraphs = len(subgraphs)

         LifecycleCallbacks.calibration_epoch_start()

         with calibration_forward_context(model), DisableQuantization(model):
             # prepare intermediates cache
             model_device = get_execution_device(model)
-            intermediates = IntermediatesCache.from_dataloader(dataloader, model_device)
+            activations = IntermediatesCache.from_dataloader(dataloader, model_device)

-            num_subgraphs = len(subgraphs)
             for subgraph_index, subgraph in enumerate(subgraphs):
                 # prepare tqdm description texts
                 calib_desc = f"({subgraph_index + 1}/{num_subgraphs}): Calibrating"
                 prop_desc = f"({subgraph_index + 1}/{num_subgraphs}): Propagating"

-                # do a preliminary pass to trigger modifier hooks
-                for batch_idx in tqdm.tqdm(range(len(dataloader)), desc=calib_desc):
-                    inputs = intermediates.fetch(batch_idx, subgraph.input_names)
-                    subgraph.forward(model, **inputs)
-
-                LifecycleCallbacks.sequential_epoch_end()
-
-                # this pass does not trigger modifier hooks
-                # and is only used for capturing outputs from newly compressed modules
-                with HooksMixin.disable_hooks():
-                    for batch_idx in tqdm.tqdm(range(len(dataloader)), desc=prop_desc):
-                        inputs = intermediates.fetch(batch_idx, subgraph.input_names)
-                        output = subgraph.forward(model, **inputs)
-
-                        if subgraph_index < num_subgraphs - 1:
-                            intermediates.update(batch_idx, output)
-                            intermediates.delete(batch_idx, subgraph.consumed_names)
+                # reduce memory movement by keeping modules onloaded
+                with keep_onload_context():
+                    # do a preliminary pass to trigger modifier hooks
+                    for batch_idx in tqdm(range(len(dataloader)), desc=calib_desc):
+                        inputs = activations.fetch(batch_idx, subgraph.input_names)
+                        subgraph.forward(model, **inputs)
+
+                    LifecycleCallbacks.sequential_epoch_end()
+
+                    # this pass does not trigger modifier hooks
+                    # and is only used for capturing outputs of newly compressed modules
+                    with HooksMixin.disable_hooks():
+                        for batch_idx in tqdm(range(len(dataloader)), desc=prop_desc):
+                            inputs = activations.fetch(batch_idx, subgraph.input_names)
+                            output = subgraph.forward(model, **inputs)
+
+                            if subgraph_index < num_subgraphs - 1:
+                                activations.update(batch_idx, output)
+                                activations.delete(batch_idx, subgraph.consumed_names)

         # redundant, finish any remaining compression
         LifecycleCallbacks.calibration_epoch_end()
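
Taken together, wrapping both passes over a subgraph in keep_onload_context means each module's weights are onloaded once per subgraph rather than once per batch. A hypothetical standalone demonstration of the effect (assumes a CUDA device and uses accelerate's cpu_offload to attach the AlignDevicesHook that the context patches):

import torch
from accelerate import cpu_offload
from llmcompressor.pipelines.sequential.helpers import keep_onload_context

module = torch.nn.Linear(8, 8)
cpu_offload(module, execution_device=torch.device("cuda"))

x = torch.randn(1, 8, device="cuda")
with keep_onload_context():
    module(x)                    # first call onloads the weights
    print(module.weight.device)  # cuda: offloading is disabled inside the context
    module(x)                    # no cpu<->gpu weight transfer on repeat calls
print(module.weight.device)     # offloaded again on exit (shows meta; the CPU
                                 # copy lives in the hook's weights_map)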
