
Commit b786029

fix the dimension per head to be independent of dim and heads, to make sure users do not have it be too small to learn anything
1 parent 9624181 commit b786029
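
Before this change, the attention module projected to `dim * 3` for queries, keys, and values, and the forward pass (not shown in this diff) splits that projection across the heads, so the effective width per head was roughly `dim // heads` and could become very small for a modest `dim` or a large head count. The new `dim_head` argument (default 64) decouples the per-head width from both. A rough sketch of the arithmetic, with made-up example values:

```python
# Illustrative only: the numbers are made-up example values, not from the repo.
dim, heads = 256, 16

# Before: to_qkv was nn.Linear(dim, dim * 3), so dim was split across the heads.
old_per_head = dim // heads        # 256 // 16 = 16 dims per head -- too small to learn much

# After: dim_head (default 64) fixes the per-head width; the projection widens to match.
dim_head = 64
inner_dim = dim_head * heads       # 64 * 16 = 1024, independent of dim
```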

File tree: 3 files changed, +10 -9 lines changed

  README.md
  setup.py
  vit_pytorch/vit_pytorch.py

README.md

Lines changed: 1 addition & 1 deletion

@@ -24,7 +24,7 @@ v = ViT(
     num_classes = 1000,
     dim = 1024,
     depth = 6,
-    heads = 8,
+    heads = 16,
     mlp_dim = 2048,
     dropout = 0.1,
     emb_dropout = 0.1
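
For reference, a usage sketch with the updated README values; `image_size` and `patch_size` are illustrative here (the hunk above starts below them), and `dim_head` is optional with a default of 64 per the constructor change further down:

```python
import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,      # illustrative value, not part of this hunk
    patch_size = 32,       # illustrative value, not part of this hunk
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dim_head = 64,         # new keyword: per-head width, independent of dim and heads
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)
preds = v(img)             # logits of shape (1, 1000)
```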

setup.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.3.0',
+  version = '0.4.0',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',

vit_pytorch/vit_pytorch.py

Lines changed: 8 additions & 7 deletions

@@ -34,14 +34,15 @@ def forward(self, x):
         return self.net(x)
 
 class Attention(nn.Module):
-    def __init__(self, dim, heads = 8, dropout = 0.):
+    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         super().__init__()
+        inner_dim = dim_head * heads
         self.heads = heads
         self.scale = dim ** -0.5
 
-        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
+        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
         self.to_out = nn.Sequential(
-            nn.Linear(dim, dim),
+            nn.Linear(inner_dim, dim),
             nn.Dropout(dropout)
         )
 
@@ -68,12 +69,12 @@ def forward(self, x, mask = None):
         return out
 
 class Transformer(nn.Module):
-    def __init__(self, dim, depth, heads, mlp_dim, dropout):
+    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):
         super().__init__()
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                Residual(PreNorm(dim, Attention(dim, heads = heads, dropout = dropout))),
+                Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
                 Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
             ]))
     def forward(self, x, mask = None):
@@ -83,7 +84,7 @@ def forward(self, x, mask = None):
         return x
 
 class ViT(nn.Module):
-    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dropout = 0., emb_dropout = 0.):
+    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
         super().__init__()
         assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
         num_patches = (image_size // patch_size) ** 2
@@ -97,7 +98,7 @@ def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, ml
         self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
         self.dropout = nn.Dropout(emb_dropout)
 
-        self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)
+        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
 
         self.to_cls_token = nn.Identity()
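
A minimal shape sketch of what `inner_dim` buys. The `nn.Linear(dim, inner_dim * 3, bias = False)` line mirrors the updated `Attention.__init__` above; the reshape into heads is assumed from standard multi-head attention rather than copied from the file:

```python
import torch
import torch.nn as nn

batch, n, dim = 2, 65, 256           # e.g. 64 patches + a cls token, with a small model dim
heads, dim_head = 16, 64
inner_dim = dim_head * heads         # 1024, independent of dim

to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)   # as in the updated Attention.__init__

x = torch.randn(batch, n, dim)
qkv = to_qkv(x).chunk(3, dim = -1)                      # three (batch, n, inner_dim) tensors
q, k, v = (t.reshape(batch, n, heads, dim_head).transpose(1, 2) for t in qkv)

print(q.shape)   # torch.Size([2, 16, 65, 64]): each head keeps 64 dims even though dim is only 256
```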
