
Method for freezing pruned weights is inefficient #10

@guoyuntu


In 'main.py' (lines 257-262), the author uses the following code to freeze the pruned weights:

    for name, p in model.named_parameters():
        if 'weight' in name:
            tensor = p.data.cpu().numpy()
            grad_tensor = p.grad.data.cpu().numpy()
            grad_tensor = np.where(tensor < EPS, 0, grad_tensor)
            p.grad.data = torch.from_numpy(grad_tensor).to(device)

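Each call copies every weight and gradient tensor from the GPU to the CPU and then copies the gradients back, which puts a heavy burden on CPU-GPU transfers.
I recommend performing the freezing operation directly on the GPU; the following code does this: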

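    # Everything below stays on the model's device: no .cpu()/.numpy() round trip.
    for name, p in model.named_parameters():
        if 'weight' in name:
            tensor = p.data
            grad_tensor = p.grad
            # Zero the gradient wherever the weight has been pruned (|w| < EPS).
            grad_tensor = torch.where(tensor.abs() < EPS, torch.zeros_like(grad_tensor), grad_tensor)
            p.grad.data = grad_tensor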
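A minimal, self-contained sketch of the same on-device idea, assuming PyTorch. Instead of building a new tensor with `torch.where`, it zeroes the gradients in place with `Tensor.masked_fill_`. The helper name `freeze_pruned_grads` and the value `EPS = 1e-6` are hypothetical, chosen only for illustration (main.py defines its own EPS). Note that the two snippets above are not strictly equivalent. The original test `tensor < EPS` is also true for every negative weight, so it zeroes the gradients of unpruned negative weights as well, whereas `tensor.abs() < EPS` only matches weights that are (near) zero. The demo runs on CPU, but the same calls stay on whichever device holds the model.

```python
import torch

# Hypothetical threshold for illustration; main.py defines its own EPS.
EPS = 1e-6


def freeze_pruned_grads(model: torch.nn.Module, eps: float = EPS) -> None:
    """Zero the gradients of pruned (|w| < eps) weights, entirely on-device."""
    for name, p in model.named_parameters():
        if "weight" in name and p.grad is not None:
            # Boolean mask lives on the same device as the weights.
            pruned = p.detach().abs() < eps
            # In-place fill: no host round trip and no new gradient tensor.
            p.grad.masked_fill_(pruned, 0.0)


# --- Small demo on CPU; the same calls stay on the GPU for a CUDA model. ---
torch.manual_seed(0)
model = torch.nn.Linear(4, 3)
with torch.no_grad():
    model.weight[0, :2] = 0.0   # "prune" two weights
    model.weight[1, 0] = -0.5   # a surviving, negative weight

model(torch.randn(8, 4)).sum().backward()
freeze_pruned_grads(model)

# The original `tensor < EPS` test would also flag the negative weight at [1, 0].
naive_mask = model.weight.detach() < EPS
print(model.weight.grad[0, :2])  # gradients of the pruned weights (zeroed)
print(model.weight.grad[1, 0])   # gradient of the negative weight is kept
```

Because `masked_fill_` modifies `p.grad` in place, the gradient never leaves the device and `p.grad.data` does not need to be reassigned.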
