Query about orthogonal regularization implementation #90

Open
@TanYingHao

Description

Hi, I was looking through this code to reimplement it for a separate task, and I noticed that the orthogonal regularization is implemented by adding the gradient of the modified orthogonal regularization loss to each parameter's `.grad`. Shouldn't it be a subtraction for gradient descent? I'd appreciate any advice :)

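I am looking specifically at this code snippet in utils.py:

```python
# Flatten the weight to (out_features, fan_in) so each row is one filter
w = param.view(param.shape[0], -1)
# Closed-form gradient of the off-diagonal orthogonality penalty: 2 * ((W W^T) * (1 - I)) W
grad = (2 * torch.mm(torch.mm(w, w.t())
        * (1. - torch.eye(w.shape[0], device=w.device)), w))
param.grad.data += strength * grad.view(param.shape)
```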
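One way to check the sign is the minimal PyTorch sketch below. It assumes the snippet's closed form is the gradient of the penalty 0.5 * ||(W W^T) * (1 - I)||_F^2; the 0.5 scaling is an assumption chosen because its exact gradient is 2 * ((W W^T) * (1 - I)) W. The helper names `ortho_penalty`, `ortho_grad`, and `penalty_after_step`, along with the `strength` and `lr` values, are hypothetical and only for illustration. The sketch compares the closed form with autograd, then takes one plain SGD step with the term added to `.grad` and, separately, subtracted from it.

```python
import torch


def ortho_grad(w: torch.Tensor) -> torch.Tensor:
    """Closed form from the utils.py snippet: 2 * ((W W^T) * (1 - I)) @ W."""
    w = w.view(w.shape[0], -1)
    off_diag = 1.0 - torch.eye(w.shape[0], device=w.device, dtype=w.dtype)
    return 2 * torch.mm(torch.mm(w, w.t()) * off_diag, w)


def ortho_penalty(w: torch.Tensor) -> torch.Tensor:
    """Assumed penalty whose exact gradient is ortho_grad: 0.5 * ||(W W^T) * (1 - I)||_F^2."""
    w = w.view(w.shape[0], -1)
    off_diag = 1.0 - torch.eye(w.shape[0], device=w.device, dtype=w.dtype)
    return 0.5 * ((torch.mm(w, w.t()) * off_diag) ** 2).sum()


torch.manual_seed(0)
param = torch.nn.Parameter(torch.randn(4, 3, 2, 2, dtype=torch.float64))
strength = 1e-3  # hypothetical value, for illustration only

# 1) Autograd's gradient of the penalty vs. the closed form, sign included.
ortho_penalty(param).backward()
matches = torch.allclose(param.grad, ortho_grad(param.detach()).view(param.shape))
print("closed form == +d(penalty)/dW:", matches)


def penalty_after_step(sign: float) -> float:
    """One SGD step on a copy of param with sign * strength * ortho_grad in .grad."""
    p = torch.nn.Parameter(param.detach().clone())
    opt = torch.optim.SGD([p], lr=0.1)
    p.grad = torch.zeros_like(p)  # stand-in for the gradient of the main loss
    p.grad += sign * strength * ortho_grad(p.detach()).view(p.shape)
    opt.step()  # plain SGD applies p <- p - lr * p.grad
    with torch.no_grad():
        return ortho_penalty(p).item()


# 2) Penalty before the step, after '+=' (as in utils.py), and after '-='.
with torch.no_grad():
    before = ortho_penalty(param).item()
after_add = penalty_after_step(+1.0)
after_sub = penalty_after_step(-1.0)
print(f"before={before:.4f}  '+=': {after_add:.4f}  '-=': {after_sub:.4f}")
```

If the sketch holds, the closed form equals +d(penalty)/dW, so `+=` on `.grad` behaves like adding the penalty to the loss before `backward()`, and the subtraction happens inside `optimizer.step()`; the `-=` variant would push the penalty up. This reasoning assumes an optimizer that steps against `.grad`, as plain SGD does.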

Metadata

Assignees

No one assigned

Labels

No labels

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
