From fcb9501cdd9e056dd040915deb3e0a6378821843 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Wed, 28 Aug 2024 12:21:31 -0700
Subject: [PATCH] add register tokens to the nested tensor 3d na vit example for researcher

---
 setup.py                               |  2 +-
 vit_pytorch/na_vit_nested_tensor_3d.py | 28 ++++++++++++++++++++------
 2 files changed, 23 insertions(+), 7 deletions(-)

diff --git a/setup.py b/setup.py
index cf6ea2d..e93a580 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.7.11',
+  version = '1.7.12',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description=long_description,
diff --git a/vit_pytorch/na_vit_nested_tensor_3d.py b/vit_pytorch/na_vit_nested_tensor_3d.py
index 8c77340..e40758c 100644
--- a/vit_pytorch/na_vit_nested_tensor_3d.py
+++ b/vit_pytorch/na_vit_nested_tensor_3d.py
@@ -163,6 +163,7 @@ def __init__(
         dim_head = 64,
         dropout = 0.,
         emb_dropout = 0.,
+        num_registers = 4,
         token_dropout_prob: float | None = None
     ):
         super().__init__()
@@ -193,9 +194,18 @@ def __init__(
             nn.LayerNorm(dim),
         )
 
-        self.pos_embed_frame = nn.Parameter(torch.randn(patch_frame_dim, dim))
-        self.pos_embed_height = nn.Parameter(torch.randn(patch_height_dim, dim))
-        self.pos_embed_width = nn.Parameter(torch.randn(patch_width_dim, dim))
+        self.pos_embed_frame = nn.Parameter(torch.zeros(patch_frame_dim, dim))
+        self.pos_embed_height = nn.Parameter(torch.zeros(patch_height_dim, dim))
+        self.pos_embed_width = nn.Parameter(torch.zeros(patch_width_dim, dim))
+
+        # register tokens
+
+        self.register_tokens = nn.Parameter(torch.zeros(num_registers, dim))
+
+        nn.init.normal_(self.pos_embed_frame, std = 0.02)
+        nn.init.normal_(self.pos_embed_height, std = 0.02)
+        nn.init.normal_(self.pos_embed_width, std = 0.02)
+        nn.init.normal_(self.register_tokens, std = 0.02)
 
         self.dropout = nn.Dropout(emb_dropout)
 
@@ -275,8 +285,6 @@ def forward(
 
         pos_embed = frame_embed + height_embed + width_embed
 
-        # use nested tensor for transformers and save on padding computation
-
         tokens = torch.cat(tokens)
 
         # linear projection to patch embeddings
@@ -287,7 +295,15 @@
 
         tokens = tokens + pos_embed
 
-        tokens = nested_tensor(tokens.split(seq_lens.tolist()), layout = torch.jagged, device = device)
+        # add register tokens
+
+        tokens = tokens.split(seq_lens.tolist())
+
+        tokens = [torch.cat((self.register_tokens, one_tokens)) for one_tokens in tokens]
+
+        # use nested tensor for transformers and save on padding computation
+
+        tokens = nested_tensor(tokens, layout = torch.jagged, device = device)
 
         # embedding dropout
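
The core of this patch is the register-token mechanic in the forward pass: a small set of learned tokens is prepended to every variable-length patch sequence before the sequences are packed into a jagged nested tensor. The standalone sketch below illustrates that mechanic outside the model; the dimensions, sequence lengths, and variable names are illustrative assumptions, not code from the repository.

# standalone sketch of the register-token mechanic introduced by the patch;
# sizes and names here are illustrative assumptions, not repository code

import torch
from torch import nn
from torch.nested import nested_tensor

dim = 512
num_registers = 4

# learned register tokens, initialized like the positional embeddings in the patch
register_tokens = nn.Parameter(torch.zeros(num_registers, dim))
nn.init.normal_(register_tokens, std = 0.02)

# stand-in for the flat patch tokens of 3 videos with different token counts
seq_lens = [7, 11, 5]
tokens = torch.randn(sum(seq_lens), dim)

# split the flat tensor back into one sequence per video
tokens = tokens.split(seq_lens)

# prepend the same learned register tokens to every sequence
tokens = [torch.cat((register_tokens, one_tokens)) for one_tokens in tokens]

# pack the variable-length sequences into a jagged nested tensor,
# saving the padding computation inside the transformer
tokens = nested_tensor(tokens, layout = torch.jagged)

# each sequence grew by num_registers: [11, 15, 9]
print([t.shape[0] for t in tokens.unbind()])

Prepending before packing means each sample carries its own copy of the same learned tokens, so every sequence length simply grows by num_registers.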