@@ -132,7 +132,12 @@ def __init__(
         # Initialize Hessian matrix and counter. We need it in float32 to compute the inverse
         self.H = torch.zeros((self.groups, self.columns, self.columns),
                              device='cpu',
-                             dtype=torch.float32)
+                             dtype=torch.float32,
+                             pin_memory=torch.cuda.is_available())
+        self.B = torch.zeros((self.groups, self.columns, self.columns),
+                             device='cpu',
+                             dtype=torch.float32,
+                             pin_memory=torch.cuda.is_available())
         self.nsamples = 0

         assert torch_version >= version.parse('1.10'), "GPTQ requires torch 1.10 or higher"
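
For context, a minimal standalone sketch of the pinned-memory allocation this hunk introduces. It is not the Brevitas class itself; `groups` and `columns` are made-up sizes standing in for `self.groups` and `self.columns`.

```python
import torch

# Hypothetical sizes standing in for self.groups and self.columns.
groups, columns = 4, 512

# Page-locked (pinned) host allocation, mirroring the hunk above.
# pin_memory=torch.cuda.is_available() falls back to a regular pageable
# CPU tensor on machines without a GPU.
H = torch.zeros((groups, columns, columns),
                device='cpu',
                dtype=torch.float32,
                pin_memory=torch.cuda.is_available())

print(H.is_pinned())  # True on a CUDA machine, False otherwise
```

Pinned memory keeps the buffer out of the OS paging pool, which is what allows CUDA to copy to and from it quickly.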
@@ -184,7 +189,9 @@ def update_batch(self, module, input, current_layer):
         self.H *= self.nsamples / (self.nsamples + batch_size)
         self.nsamples += batch_size
         inp_processed = math.sqrt(2 / self.nsamples) * inp_processed.to(torch.float32)
-        self.H += (inp_processed.bmm(inp_processed.transpose(2, 1))).to(self.H.device)
+        # Optimize the GPU to CPU transfer using an in-place copy into pinned memory
+        self.B.copy_(inp_processed.bmm(inp_processed.transpose(2, 1)))
+        self.H += self.B
         # If we are executing GPTQ with group of parallel layers, we keep track of how many forward
         # we executed. Once we executed as many as the number of parallel_layers, we raise
         # StopFwdException
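
A minimal sketch of the staging pattern above, under the assumption that the batched outer products are computed on a CUDA device while the accumulator lives in pinned CPU memory; the shapes and variable names are illustrative only.

```python
import torch

if torch.cuda.is_available():
    groups, columns, tokens = 4, 512, 8  # illustrative shapes

    H = torch.zeros((groups, columns, columns), dtype=torch.float32,
                    device='cpu', pin_memory=True)
    B = torch.zeros((groups, columns, columns), dtype=torch.float32,
                    device='cpu', pin_memory=True)  # reusable staging buffer

    inp_processed = torch.randn((groups, columns, tokens), device='cuda')

    # Old pattern: .to(H.device) allocates a fresh pageable host tensor per batch.
    # H += inp_processed.bmm(inp_processed.transpose(2, 1)).to(H.device)

    # New pattern: copy the GPU result in place into the preallocated pinned
    # buffer, then accumulate on the CPU. No per-batch host allocation, and
    # the device-to-host copy goes through page-locked memory.
    B.copy_(inp_processed.bmm(inp_processed.transpose(2, 1)))
    H += B
```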
@@ -255,7 +262,7 @@ def single_layer_update(self, percdamp=.01):
                 f'Increasing the number of samples might fix this issue')
             return
         finally:
-            del self.H
+            del self.H, self.B

         for i1 in range(0, self.columns, self.blocksize):
             i2 = min(i1 + self.blocksize, self.columns)
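
A small note on the `finally` change: pinned (page-locked) host memory is a comparatively scarce resource, so the staging buffer is dropped together with the Hessian as soon as it is no longer needed. A hedged sketch with a hypothetical buffer name:

```python
import torch

if torch.cuda.is_available():
    staging = torch.zeros((1024, 1024), dtype=torch.float32, pin_memory=True)
    # ... use `staging` for GPU <-> CPU copies during calibration ...
    del staging  # the page-locked allocation is freed once no references remain
```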