add register tokens to the nested tensor 3d na vit example for researcher
lucidrains committed Aug 28, 2024
1 parent c4651a3 commit fcb9501
Showing 2 changed files with 23 additions and 7 deletions.
setup.py: 2 changes (1 addition & 1 deletion)
@@ -6,7 +6,7 @@
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
- version = '1.7.11',
+ version = '1.7.12',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description=long_description,
vit_pytorch/na_vit_nested_tensor_3d.py: 28 changes (22 additions & 6 deletions)
@@ -163,6 +163,7 @@ def __init__(
dim_head = 64,
dropout = 0.,
emb_dropout = 0.,
+ num_registers = 4,
token_dropout_prob: float | None = None
):
super().__init__()
@@ -193,9 +194,18 @@ def __init__(
nn.LayerNorm(dim),
)

- self.pos_embed_frame = nn.Parameter(torch.randn(patch_frame_dim, dim))
- self.pos_embed_height = nn.Parameter(torch.randn(patch_height_dim, dim))
- self.pos_embed_width = nn.Parameter(torch.randn(patch_width_dim, dim))
+ self.pos_embed_frame = nn.Parameter(torch.zeros(patch_frame_dim, dim))
+ self.pos_embed_height = nn.Parameter(torch.zeros(patch_height_dim, dim))
+ self.pos_embed_width = nn.Parameter(torch.zeros(patch_width_dim, dim))

+ # register tokens
+
+ self.register_tokens = nn.Parameter(torch.zeros(num_registers, dim))
+
+ nn.init.normal_(self.pos_embed_frame, std = 0.02)
+ nn.init.normal_(self.pos_embed_height, std = 0.02)
+ nn.init.normal_(self.pos_embed_width, std = 0.02)
+ nn.init.normal_(self.register_tokens, std = 0.02)

self.dropout = nn.Dropout(emb_dropout)

@@ -275,8 +285,6 @@ def forward(

pos_embed = frame_embed + height_embed + width_embed

- # use nested tensor for transformers and save on padding computation

tokens = torch.cat(tokens)

# linear projection to patch embeddings
@@ -287,7 +295,15 @@

tokens = tokens + pos_embed

- tokens = nested_tensor(tokens.split(seq_lens.tolist()), layout = torch.jagged, device = device)
+ # add register tokens
+
+ tokens = tokens.split(seq_lens.tolist())
+
+ tokens = [torch.cat((self.register_tokens, one_tokens)) for one_tokens in tokens]
+
+ # use nested tensor for transformers and save on padding computation
+
+ tokens = nested_tensor(tokens, layout = torch.jagged, device = device)

# embedding dropout

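For quick reference, here is a minimal standalone sketch of the mechanic the forward-pass hunk introduces: the flat patch embeddings are split back into per-sample sequences, the learned register tokens are prepended to each sequence, and everything is repacked as a jagged nested tensor so the transformer skips padding. All names and sizes below are illustrative placeholders, not code from the repository, and it assumes a PyTorch version with jagged nested tensor support (2.1+).

# Illustrative only: `register_tokens`, `seq_lens`, `per_sample` and the sizes
# are made up to mirror the diff above, not taken from the file.
import torch
from torch.nested import nested_tensor

dim, num_registers = 64, 4

# learned register tokens (the diff initializes them as zeros, then normal_(std = 0.02))
register_tokens = torch.empty(num_registers, dim).normal_(std = 0.02)

# flat patch embeddings for a batch of 3 videos with different patch counts
seq_lens = torch.tensor([5, 9, 3])
tokens = torch.randn(int(seq_lens.sum()), dim)

# split per sample, prepend the registers, then pack into a jagged nested tensor
per_sample = tokens.split(seq_lens.tolist())
per_sample = [torch.cat((register_tokens, t)) for t in per_sample]
tokens = nested_tensor(per_sample, layout = torch.jagged)

print([t.shape[0] for t in tokens.unbind()])   # [9, 13, 7]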
