Commit b786029

fix the dimension per head to be independent of dim and heads, to make sure users do not have it be too small to learn anything

Author: lucidrains
Committed: Dec 17, 2020
1 parent: 9624181

Showing 3 changed files with 10 additions and 9 deletions.
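
For context, a minimal standalone sketch (not part of the commit) of how the width of each attention head is determined before and after this change, using the dim and heads values from the README example:

dim, heads = 1024, 8

# Before: to_qkv was nn.Linear(dim, dim * 3), so each head implicitly received
# dim // heads features. With a small dim and many heads (e.g. dim = 512, heads = 16)
# that is only 32 features per head, which can be too small to learn anything.
old_dim_head = dim // heads            # 128

# After: the per-head width is an explicit argument with a default of 64, and the
# attention projections are sized by inner_dim rather than dim.
dim_head = 64
inner_dim = dim_head * heads           # 512; no longer tied to dim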
README.md (1 addition, 1 deletion)

@@ -24,7 +24,7 @@ v = ViT(
     num_classes = 1000,
     dim = 1024,
     depth = 6,
-    heads = 8,
+    heads = 16,
     mlp_dim = 2048,
     dropout = 0.1,
     emb_dropout = 0.1
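
With this change, dim_head can also be set explicitly when constructing the model. A usage sketch under assumed image_size and patch_size values (those two arguments are not shown in this hunk):

import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,      # assumed for illustration
    patch_size = 32,       # assumed for illustration
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dim_head = 64,         # new keyword; defaults to 64 if omitted
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)
preds = v(img)             # (1, 1000)

Note that with heads = 16 and dim_head = 64, inner_dim = 1024 = dim, so this configuration reproduces the projection sizes the old code used for the README example.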
setup.py (1 addition, 1 deletion)

@@ -3,7 +3,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.3.0',
+  version = '0.4.0',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
vit_pytorch/vit_pytorch.py (8 additions, 7 deletions)

@@ -34,14 +34,15 @@ def forward(self, x):
         return self.net(x)

 class Attention(nn.Module):
-    def __init__(self, dim, heads = 8, dropout = 0.):
+    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         super().__init__()
+        inner_dim = dim_head * heads
         self.heads = heads
         self.scale = dim ** -0.5

-        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
+        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
         self.to_out = nn.Sequential(
-            nn.Linear(dim, dim),
+            nn.Linear(inner_dim, dim),
             nn.Dropout(dropout)
         )

@@ -68,12 +69,12 @@ def forward(self, x, mask = None):
         return out

 class Transformer(nn.Module):
-    def __init__(self, dim, depth, heads, mlp_dim, dropout):
+    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):
         super().__init__()
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                Residual(PreNorm(dim, Attention(dim, heads = heads, dropout = dropout))),
+                Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
                 Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
             ]))
     def forward(self, x, mask = None):

@@ -83,7 +84,7 @@ def forward(self, x, mask = None):
         return x

 class ViT(nn.Module):
-    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dropout = 0., emb_dropout = 0.):
+    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
         super().__init__()
         assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
         num_patches = (image_size // patch_size) ** 2

@@ -97,7 +98,7 @@ def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
         self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
         self.dropout = nn.Dropout(emb_dropout)

-        self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)
+        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

         self.to_cls_token = nn.Identity()
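
To make the effect of inner_dim concrete, a small shape-check sketch that mirrors just the changed projection lines (not the repository's full Attention module, whose forward pass is not shown in this diff):

import torch
import torch.nn as nn

dim, heads, dim_head = 1024, 16, 64
inner_dim = dim_head * heads                    # 1024 for this configuration

# The two projections as sized after this commit.
to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
to_out = nn.Linear(inner_dim, dim)

x = torch.randn(1, 65, dim)                     # assumed: 1 cls token + 64 patch tokens
q, k, v = to_qkv(x).chunk(3, dim = -1)          # each (1, 65, inner_dim)
q = q.reshape(1, 65, heads, dim_head).transpose(1, 2)
print(q.shape)                                  # torch.Size([1, 16, 65, 64])
print(to_out(v).shape)                          # torch.Size([1, 65, 1024])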
