Hi, I was looking through this code with an eye to reimplementing it for a separate task, and I noticed that orthogonal regularization is implemented by adding the gradient of the modified orthogonal regularization loss to each parameter's gradient buffer (`param.grad`). Shouldn't it be a subtraction for gradient descent? I'd appreciate any advice :)

I am looking specifically at this code snippet in `utils.py`:
```python
w = param.view(param.shape[0], -1)                    # flatten to (out, in)
grad = (2 * torch.mm(torch.mm(w, w.t())               # 2 * (W Wᵀ ⊙ (1 − I)) W
        * (1. - torch.eye(w.shape[0], device=w.device)), w))
param.grad.data += strength * grad.view(param.shape)  # accumulate into param.grad
```
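To make sure I am reading the math correctly, I wrote a small sanity check (my own sketch, not code from this repo) comparing the closed-form expression above against autograd's gradient of the off-diagonal Gram penalty `0.5 * ||W Wᵀ ⊙ (1 − I)||_F²`; the `0.5` is only there to match the factor of 2 in the snippet, since any constant would be absorbed into `strength` anyway:

```python
import torch

torch.manual_seed(0)
w = torch.randn(4, 7, requires_grad=True)      # stand-in for a flattened weight
mask = 1. - torch.eye(w.shape[0])              # zeros out the diagonal of W Wᵀ

# Penalty: 0.5 * squared Frobenius norm of the off-diagonal entries of W Wᵀ
loss = 0.5 * ((w @ w.t()) * mask).pow(2).sum()
auto_grad, = torch.autograd.grad(loss, w)

# Closed-form expression used in the snippet
manual_grad = 2 * torch.mm((w @ w.t()) * mask, w)

print(torch.allclose(auto_grad, manual_grad))  # prints True
```

So the quantity being accumulated really is the gradient of the penalty (not its negative), which is why I expected a minus sign somewhere.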