@@ -34,14 +34,15 @@ def forward(self, x):
         return self.net(x)

 class Attention(nn.Module):
-    def __init__(self, dim, heads = 8, dropout = 0.):
+    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         super().__init__()
+        inner_dim = dim_head * heads
         self.heads = heads
         self.scale = dim ** -0.5

-        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
+        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
         self.to_out = nn.Sequential(
-            nn.Linear(dim, dim),
+            nn.Linear(inner_dim, dim),
             nn.Dropout(dropout)
         )
@@ -68,12 +69,12 @@ def forward(self, x, mask = None):
         return out

 class Transformer(nn.Module):
-    def __init__(self, dim, depth, heads, mlp_dim, dropout):
+    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):
         super().__init__()
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                Residual(PreNorm(dim, Attention(dim, heads = heads, dropout = dropout))),
+                Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
                 Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
             ]))
     def forward(self, x, mask = None):
@@ -83,7 +84,7 @@ def forward(self, x, mask = None):
         return x

 class ViT(nn.Module):
-    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dropout = 0., emb_dropout = 0.):
+    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
         super().__init__()
         assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
         num_patches = (image_size // patch_size) ** 2
@@ -97,7 +98,7 @@ def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, ml
         self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
         self.dropout = nn.Dropout(emb_dropout)

-        self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)
+        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

         self.to_cls_token = nn.Identity()
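The effect of this commit is to decouple the attention width from the model width: queries, keys and values are now projected to inner_dim = dim_head * heads instead of dim, and to_out maps back to dim. A minimal usage sketch of the new dim_head keyword follows; the hyperparameter values are illustrative only, and the top-level import path is assumed rather than taken from this commit:

import torch
from vit_pytorch import ViT  # import path assumed; adjust to wherever ViT is defined

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dim_head = 64,   # new in this commit: per-head width, so inner_dim = 64 * 16 = 1024
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)
preds = v(img)  # shape (1, 1000)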