go all the way with the normalized vit, fix some scales
lucidrains committed Oct 10, 2024
1 parent 1d1a63f commit 36ddc7a
Showing 2 changed files with 12 additions and 11 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -6,7 +6,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.8.2',
+  version = '1.8.4',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description=long_description,
21 changes: 11 additions & 10 deletions vit_pytorch/normalized_vit.py
@@ -179,18 +179,18 @@ def __init__(

         self.to_patch_embedding = nn.Sequential(
             Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1 = patch_size, p2 = patch_size),
-            nn.LayerNorm(patch_dim),
-            nn.Linear(patch_dim, dim),
-            nn.LayerNorm(dim),
+            NormLinear(patch_dim, dim, norm_dim_in = False),
         )

-        self.abs_pos_emb = nn.Embedding(num_patches, dim)
+        self.abs_pos_emb = NormLinear(dim, num_patches)

         residual_lerp_scale_init = default(residual_lerp_scale_init, 1. / depth)

         # layers

+        self.dim = dim
+        self.scale = dim ** 0.5

         self.layers = ModuleList([])
         self.residual_lerp_scales = nn.ParameterList([])
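NormLinear is defined elsewhere in normalized_vit.py and is not shown in this diff. As a rough, hedged sketch of the idea (the class body and the norm_dim_in semantics below are assumptions, not the repository's exact code), it behaves like a bias-free nn.Linear whose weight is kept at unit l2 norm along one dimension, so both the patch projection and the positional embedding table it replaces here live on the hypersphere:

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.parametrize import register_parametrization

class L2Norm(nn.Module):
    # reusable parametrization: project a tensor to unit l2 norm along `dim`
    def __init__(self, dim = -1):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        return F.normalize(t, dim = self.dim, p = 2)

class NormLinear(nn.Module):
    # bias-free linear layer whose weight is re-normalized on every access;
    # norm_dim_in chooses whether each row is normalized over the input
    # dimension (True) or each column over the output dimension (False)
    def __init__(self, dim, dim_out, norm_dim_in = True):
        super().__init__()
        self.linear = nn.Linear(dim, dim_out, bias = False)
        register_parametrization(self.linear, 'weight', L2Norm(dim = -1 if norm_dim_in else 0))

    @property
    def weight(self):
        return self.linear.weight

    def forward(self, x):
        return self.linear(x)

Under that assumption, indexing self.abs_pos_emb.weight, as the new forward pass does below, yields unit-norm positional embeddings of shape (num_patches, dim).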

@@ -201,8 +201,8 @@ def __init__(
             ]))

             self.residual_lerp_scales.append(nn.ParameterList([
-                nn.Parameter(torch.ones(dim) * residual_lerp_scale_init),
-                nn.Parameter(torch.ones(dim) * residual_lerp_scale_init),
+                nn.Parameter(torch.ones(dim) * residual_lerp_scale_init / self.scale),
+                nn.Parameter(torch.ones(dim) * residual_lerp_scale_init / self.scale),
             ]))

         self.logit_scale = nn.Parameter(torch.ones(num_classes))
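The stored lerp weights are now divided by self.scale (sqrt(dim)) at initialization, while the forward pass in the next hunk multiplies them back by self.scale, so the effective interpolation weight at init is unchanged; only the magnitude at which the parameter is stored and trained changes. A minimal sketch of that equivalence, using arbitrary example values for dim and depth:

import torch

dim, depth = 512, 12            # arbitrary example values
scale = dim ** 0.5
residual_lerp_scale_init = 1. / depth

# parameter as stored after this commit: divided by sqrt(dim)
alpha_stored = torch.ones(dim) * residual_lerp_scale_init / scale

# effective weight used in the forward pass: multiplied back by sqrt(dim)
alpha_effective = alpha_stored * scale

# identical (up to floating point) to the pre-commit initialization
assert torch.allclose(alpha_effective, torch.ones(dim) * residual_lerp_scale_init)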
@@ -225,22 +225,23 @@ def forward(self, images):

         tokens = self.to_patch_embedding(images)

-        pos_emb = self.abs_pos_emb(torch.arange(tokens.shape[-2], device = device))
+        seq_len = tokens.shape[-2]
+        pos_emb = self.abs_pos_emb.weight[torch.arange(seq_len, device = device)]

         tokens = l2norm(tokens + pos_emb)

         for (attn, ff), (attn_alpha, ff_alpha) in zip(self.layers, self.residual_lerp_scales):

             attn_out = l2norm(attn(tokens))
-            tokens = l2norm(tokens.lerp(attn_out, attn_alpha))
+            tokens = l2norm(tokens.lerp(attn_out, attn_alpha * self.scale))

             ff_out = l2norm(ff(tokens))
-            tokens = l2norm(tokens.lerp(ff_out, ff_alpha))
+            tokens = l2norm(tokens.lerp(ff_out, ff_alpha * self.scale))

         pooled = reduce(tokens, 'b n d -> b d', 'mean')

         logits = self.to_pred(pooled)
-        logits = logits * self.logit_scale * (self.dim ** 0.5)
+        logits = logits * self.logit_scale * self.scale

         return logits
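For reference, a minimal, self-contained sketch of the normalized residual step in the loop above, assuming l2norm is the usual F.normalize wrapper and using arbitrary shapes and values; every update interpolates toward the branch output and then renormalizes, so tokens stay on the unit hypersphere:

import torch
import torch.nn.functional as F

def l2norm(t, dim = -1):
    return F.normalize(t, dim = dim, p = 2)

dim, depth = 512, 12
scale = dim ** 0.5

tokens = l2norm(torch.randn(2, 64, dim))        # (batch, seq, dim), unit norm
attn_out = l2norm(torch.randn(2, 64, dim))      # stand-in for l2norm(attn(tokens))
attn_alpha = torch.ones(dim) / depth / scale    # stored lerp scale, as in __init__

# tokens = l2norm(tokens.lerp(attn_out, attn_alpha * self.scale)) from the diff
tokens = l2norm(tokens.lerp(attn_out, attn_alpha * scale))

print(tokens.norm(dim = -1))                    # ~1.0 for every token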

