We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 9b23d6d · commit 6675590 (copy full SHA for 6675590)
timm/models/vision_transformer.py
@@ -332,7 +332,17 @@ def __init__(
332
])))
333
334
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
335
- x = x + torch.stack([attn(x, attn_mask=attn_mask) for attn in self.attns]).sum(dim=0)
+ if attn_mask is not None:
336
+ attn_out = []
337
+ for attn in self.attns:
338
+ x_attn = attn.norm(x)
339
+ x_attn = attn.attn(x_attn, attn_mask=attn_mask)
340
+ x_attn = attn.ls(x_attn)
341
+ x_attn = attn.drop_path(x_attn)
342
+ attn_out.append(x_attn)
343
+ x = x + torch.stack(attn_out).sum(dim=0)
344
+ else:
345
+ x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0)
346
x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0)
347
return x
348
0 commit comments