@@ -1,16 +1,17 @@
 from typing import TYPE_CHECKING
 
 import torch
-import tqdm
 from compressed_tensors.utils import get_execution_device
 from torch.utils.data.dataloader import DataLoader
+from tqdm import tqdm
 
 from llmcompressor.core import LifecycleCallbacks, active_session
 from llmcompressor.modifiers.utils.hooks import HooksMixin
 from llmcompressor.pipelines.cache import IntermediatesCache
 from llmcompressor.pipelines.registry import CalibrationPipeline
 from llmcompressor.pipelines.sequential.helpers import (
     get_targets_from_modifiers,
+    keep_onload_context,
     trace_subgraphs,
 )
 from llmcompressor.utils.helpers import DisableQuantization, calibration_forward_context
@@ -51,6 +52,8 @@ def __call__(
         """
         session = active_session()
 
+        # TODO: warn about not cpu offloading
+
         # prepare to trace subgraphs
         modifiers = session.get_modifiers()
         sequential_targets = get_targets_from_modifiers(modifiers, model)
@@ -59,37 +62,39 @@ def __call__(
         # trace subgraphs
         sample_input = next(iter(dataloader))
         subgraphs = trace_subgraphs(model, sample_input, sequential_targets, ignore)
+        num_subgraphs = len(subgraphs)
 
         LifecycleCallbacks.calibration_epoch_start()
 
         with calibration_forward_context(model), DisableQuantization(model):
             # prepare intermediates cache
             model_device = get_execution_device(model)
-            intermediates = IntermediatesCache.from_dataloader(dataloader, model_device)
+            activations = IntermediatesCache.from_dataloader(dataloader, model_device)
 
-            num_subgraphs = len(subgraphs)
             for subgraph_index, subgraph in enumerate(subgraphs):
                 # prepare tqdm description texts
                 calib_desc = f"({subgraph_index + 1}/{num_subgraphs}): Calibrating"
                 prop_desc = f"({subgraph_index + 1}/{num_subgraphs}): Propagating"
 
-                # do a preliminary pass to trigger modifier hooks
-                for batch_idx in tqdm.tqdm(range(len(dataloader)), desc=calib_desc):
-                    inputs = intermediates.fetch(batch_idx, subgraph.input_names)
-                    subgraph.forward(model, **inputs)
-
-                LifecycleCallbacks.sequential_epoch_end()
-
-                # this pass does not trigger modifier hooks
-                # and is only used for capturing outputs from newly compressed modules
-                with HooksMixin.disable_hooks():
-                    for batch_idx in tqdm.tqdm(range(len(dataloader)), desc=prop_desc):
-                        inputs = intermediates.fetch(batch_idx, subgraph.input_names)
-                        output = subgraph.forward(model, **inputs)
-
-                        if subgraph_index < num_subgraphs - 1:
-                            intermediates.update(batch_idx, output)
-                            intermediates.delete(batch_idx, subgraph.consumed_names)
+                # reduce memory movement by keeping modules onloaded
+                with keep_onload_context():
+                    # do a preliminary pass to trigger modifier hooks
+                    for batch_idx in tqdm(range(len(dataloader)), desc=calib_desc):
+                        inputs = activations.fetch(batch_idx, subgraph.input_names)
+                        subgraph.forward(model, **inputs)
+
+                    LifecycleCallbacks.sequential_epoch_end()
+
+                    # this pass does not trigger modifier hooks
+                    # and is only used for capturing outputs of newly compressed modules
+                    with HooksMixin.disable_hooks():
+                        for batch_idx in tqdm(range(len(dataloader)), desc=prop_desc):
+                            inputs = activations.fetch(batch_idx, subgraph.input_names)
+                            output = subgraph.forward(model, **inputs)
+
+                            if subgraph_index < num_subgraphs - 1:
+                                activations.update(batch_idx, output)
+                                activations.delete(batch_idx, subgraph.consumed_names)
 
         # redundant, finish any remaining compression
         LifecycleCallbacks.calibration_epoch_end()
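
Note on the new keep_onload_context() call: its implementation lives in llmcompressor.pipelines.sequential.helpers and is not part of this diff. As a rough, hypothetical sketch of the pattern it names (keep a module's weights on the execution device for the duration of a block instead of offloading them after every forward), assuming a made-up offload_after_forward flag rather than llmcompressor's actual mechanism:

from contextlib import contextmanager

import torch


@contextmanager
def keep_onload_sketch(modules: list[torch.nn.Module]):
    # Hypothetical flag that an offloading hook would check before moving
    # weights back to CPU after each forward; not llmcompressor's real API.
    previous = {m: getattr(m, "offload_after_forward", True) for m in modules}
    for m in modules:
        m.offload_after_forward = False
    try:
        yield
    finally:
        # Restore the original offloading behavior once the scope exits,
        # so weights are moved off the device again after the last pass.
        for m, prev in previous.items():
            m.offload_after_forward = prev

Wrapping both the calibration pass and the propagation pass in one such scope means each subgraph's weights are onloaded once per subgraph rather than once per batch, which is the memory-movement reduction the in-code comment refers to.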