To compute the L2 norm in the cosine similarity, F.normalize is used. However, the last dimension (dim=-1) here is a padded dimension of size 1, so normalizing over dim=-1 just divides each element by its own absolute value, and every value collapses to 1 or -1. Computing the cosine similarity over dim=-1 is therefore incorrect; dim=-2 should be used instead so that the code matches the paper.
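A minimal sketch of the collapse (the tensor shape and values below are only illustrative, assuming a head dimension of 1 as described above):

import torch
import torch.nn.functional as F

# Toy tensor standing in for q, shape (B_, heads, N, head_dim) with head_dim == 1.
q = torch.randn(1, 1, 4, 1)

# Normalizing over the size-1 last dimension divides each element by its own
# absolute value, so every entry becomes +1 or -1.
print(F.normalize(q, dim=-1).reshape(-1))   # only +/-1

# Normalizing over dim=-2 keeps the relative magnitudes along the window axis.
print(F.normalize(q, dim=-2).reshape(-1))

Below is the forward pass from the repo, with two print statements added to show the same behavior on the actual q and k: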
def forward(self, x, mask=None):
    """
    Args:
        x: input features with shape of (num_windows*B, N, C)
        mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
    """
    B_, N, C = x.shape
    qkv_bias = None
    if self.q_bias is not None:
        qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
    qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
    qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)  # shape (3, B_, heads, N, 1)
    q, k, v = qkv[0], qkv[1], qkv[2]

    q_, k_ = F.normalize(q, dim=-1).detach(), F.normalize(k, dim=-1).detach()
    print('shape', q_.shape, k_.shape)                              # print shapes
    print('value', q_.numpy().reshape(-1), k_.numpy().reshape(-1))  # print normalized values (all -1 or 1)

    # cosine attention
    # dim=-1 is incorrect here, should be corrected to dim=-2
    attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
    logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01))).exp()
    attn = attn * logit_scale
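For clarity, the change proposed in this report is simply to normalize over dim=-2 in the cosine-attention line (a sketch of the suggestion, not a verified patch):

# Proposed here: normalize over dim=-2 instead of dim=-1
attn = (F.normalize(q, dim=-2) @ F.normalize(k, dim=-2).transpose(-2, -1))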
