Description
In 'main.py', lines 257-262, the author uses the following code to freeze the pruned weights:

    for name, p in model.named_parameters():
        if 'weight' in name:
            tensor = p.data.cpu().numpy()            # GPU -> CPU copy of the weights
            grad_tensor = p.grad.data.cpu().numpy()  # GPU -> CPU copy of the gradients
            grad_tensor = np.where(tensor < EPS, 0, grad_tensor)
            p.grad.data = torch.from_numpy(grad_tensor).to(device)  # CPU -> GPU copy back
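Every time this loop runs, it copies each weight and gradient tensor from the GPU to the CPU and then copies the masked gradient back to the GPU. These transfers place a heavy burden on CPU-GPU I/O.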
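The condition in the original code also has a side effect beyond the transfer cost. The test tensor < EPS is true for every negative weight, not only for pruned (zero) ones. As a result, the gradients of surviving negative weights are zeroed too. The proposed version below compares the magnitude instead (tensor.abs() < EPS). The NumPy snippet here shows the difference on a toy row of weights. The EPS value and the arrays are made up for illustration and are not taken from the repository.

```python
import numpy as np

EPS = 1e-6  # hypothetical threshold; the repository defines its own EPS

# Toy weight row after pruning: entries 1 and 3 were pruned to exactly 0.
weights = np.array([0.5, 0.0, -0.3, 0.0, 0.2])
grads = np.array([0.1, 0.1, 0.1, 0.1, 0.1])

# Condition used in main.py: also true for the surviving negative weight -0.3.
frozen_signed = np.where(weights < EPS, 0, grads)

# Magnitude-based condition: true only for the pruned (zero) entries.
frozen_abs = np.where(np.abs(weights) < EPS, 0, grads)

print(frozen_signed.tolist())  # [0.1, 0.0, 0.0, 0.0, 0.1]
print(frozen_abs.tolist())     # [0.1, 0.0, 0.1, 0.0, 0.1]
```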
I recommend performing the freezing operation directly on the GPU, which the following code does:

    for name, p in model.named_parameters():
        if 'weight' in name:
            tensor = p.data
            grad_tensor = p.grad
            # Zero the gradient wherever |weight| < EPS; all tensors stay on the GPU
            grad_tensor = torch.where(tensor.abs() < EPS, torch.zeros_like(grad_tensor), grad_tensor)
            p.grad.data = grad_tensor
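Below is a self-contained sketch of the same on-device approach. It is wrapped in a function so it can be tried on either CPU or GPU. The function name freeze_pruned_grads and the EPS value are hypothetical. The sketch also makes two changes from the snippet above. First, it uses the in-place Tensor.masked_fill_ instead of torch.where, which gives the same values without allocating a zeros tensor. Second, it skips parameters whose .grad is None.

```python
import torch
import torch.nn as nn

EPS = 1e-6  # hypothetical threshold; main.py defines its own EPS


def freeze_pruned_grads(model: nn.Module, eps: float = EPS) -> None:
    """Zero the gradients of pruned (near-zero) weights on their own device."""
    for name, p in model.named_parameters():
        if "weight" in name and p.grad is not None:
            # Boolean mask of pruned entries, built on the same device as p.
            pruned = p.detach().abs() < eps
            # Modify the existing gradient in place: no GPU <-> CPU copies.
            p.grad.masked_fill_(pruned, 0.0)


# Usage on a toy layer (runs on CPU; move model and input to "cuda" to run on GPU).
torch.manual_seed(0)
model = nn.Linear(4, 3)
with torch.no_grad():
    model.weight[0, 0] = 0.0   # simulate a pruned weight
    model.weight[1, 1] = -0.5  # a negative weight that should stay trainable

model(torch.randn(8, 4)).sum().backward()
freeze_pruned_grads(model)

print(model.weight.grad[0, 0].item())  # 0.0 (frozen)
print(model.weight.grad[1, 1].item())  # nonzero (still trainable)
```

In a training loop, this function belongs between loss.backward() and optimizer.step(), because it only edits the gradients that step() consumes.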