Commit 10dcee3

Feat (gptq): optimizing CPU to GPU memory transfer (#1009)
1 parent: 9932b92

File tree: 1 file changed (+10 −3 lines)

src/brevitas/graph/gptq.py

Lines changed: 10 additions & 3 deletions
```diff
@@ -132,7 +132,12 @@ def __init__(
         # Initialize Hessian matrix and counter. We need it in float32 to compute the inverse
         self.H = torch.zeros((self.groups, self.columns, self.columns),
                              device='cpu',
-                             dtype=torch.float32)
+                             dtype=torch.float32,
+                             pin_memory=torch.cuda.is_available())
+        self.B = torch.zeros((self.groups, self.columns, self.columns),
+                             device='cpu',
+                             dtype=torch.float32,
+                             pin_memory=torch.cuda.is_available())
         self.nsamples = 0
 
         assert torch_version >= version.parse('1.10'), "GPTQ requires torch 1.10 or higher"
```
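The `pin_memory=torch.cuda.is_available()` guard matters here: page-locked host memory speeds up transfers between the CPU and a CUDA device, but requesting it typically fails on a CUDA-less build. A minimal sketch of the allocation pattern (shapes are illustrative, not taken from the commit):

```python
import torch

# Allocate the accumulator on the CPU, page-locked only if a GPU exists.
# Pinned memory can be DMA-transferred to/from the GPU without an extra
# staging copy; pin_memory=True on a CPU-only build would typically error.
pin = torch.cuda.is_available()
H = torch.zeros((4, 128, 128), dtype=torch.float32, device='cpu', pin_memory=pin)
print(H.is_pinned())  # True only when a pinned allocation was requested
```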
```diff
@@ -184,7 +189,9 @@ def update_batch(self, module, input, current_layer):
         self.H *= self.nsamples / (self.nsamples + batch_size)
         self.nsamples += batch_size
         inp_processed = math.sqrt(2 / self.nsamples) * inp_processed.to(torch.float32)
-        self.H += (inp_processed.bmm(inp_processed.transpose(2, 1))).to(self.H.device)
+        # optimizing CPU to GPU transfer using in-place copy to pinned memory
+        self.B.copy_(inp_processed.bmm(inp_processed.transpose(2, 1)))
+        self.H += self.B
         # If we are executing GPTQ with group of parallel layers, we keep track of how many forward
         # we executed. Once we executed as many as the number of parallel_layers, we raise
         # StopFwdException
```
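This hunk is the core of the optimization: the old `.to(self.H.device)` allocated a fresh pageable CPU tensor on every batch, while the new code reuses one preallocated pinned buffer (`self.B`) via an in-place `copy_`. A self-contained sketch of the pattern, with illustrative names that are not part of the Brevitas API:

```python
import torch

def accumulate_hessian(H: torch.Tensor, B: torch.Tensor, inp: torch.Tensor) -> None:
    """Illustrative sketch: accumulate inp @ inp^T into a CPU Hessian H
    through a preallocated pinned staging buffer B of the same shape."""
    prod = inp.bmm(inp.transpose(2, 1))  # batched outer product on the GPU
    B.copy_(prod)                        # device-to-host copy into pinned memory
    H += B                               # accumulate on the CPU

if torch.cuda.is_available():
    groups, columns, nsamples = 2, 8, 16
    H = torch.zeros((groups, columns, columns), dtype=torch.float32,
                    device='cpu', pin_memory=True)
    B = torch.zeros((groups, columns, columns), dtype=torch.float32,
                    device='cpu', pin_memory=True)
    inp = torch.randn((groups, columns, nsamples), device='cuda', dtype=torch.float32)
    accumulate_hessian(H, B, inp)
```

Reusing a single pinned buffer avoids both the per-batch allocation and the slower pageable-memory transfer path.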
```diff
@@ -255,7 +262,7 @@ def single_layer_update(self, percdamp=.01):
                 f'Increasing the number of samples might fix this issue')
             return
         finally:
-            del self.H
+            del self.H, self.B
 
         for i1 in range(0, self.columns, self.blocksize):
             i2 = min(i1 + self.blocksize, self.columns)
```
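Deleting `self.B` alongside `self.H` keeps the footprint bounded: pinned memory is page-locked and cannot be swapped by the OS, so it should be released as soon as the Hessian has been consumed. Pinned buffers also permit asynchronous host-to-device copies, which is presumably the "CPU to GPU" direction the commit title refers to; a minimal sketch of that usage (assumed, not shown in this diff):

```python
import torch

if torch.cuda.is_available():
    h_cpu = torch.zeros((2, 64, 64), dtype=torch.float32, pin_memory=True)
    # Only pinned host tensors can be copied asynchronously; the transfer
    # can overlap other GPU work until the stream is synchronized.
    h_gpu = h_cpu.to('cuda', non_blocking=True)
    torch.cuda.synchronize()  # ensure the copy finished before using h_gpu
```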
