Commit 6675590

Fix ParallelThingsBlock w/ attn_mask
1 parent 9b23d6d commit 6675590

1 file changed: 11 additions, 1 deletion

timm/models/vision_transformer.py

@@ -332,7 +332,17 @@ def __init__(
             ])))
 
     def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-        x = x + torch.stack([attn(x, attn_mask=attn_mask) for attn in self.attns]).sum(dim=0)
+        if attn_mask is not None:
+            attn_out = []
+            for attn in self.attns:
+                x_attn = attn.norm(x)
+                x_attn = attn.attn(x_attn, attn_mask=attn_mask)
+                x_attn = attn.ls(x_attn)
+                x_attn = attn.drop_path(x_attn)
+                attn_out.append(x_attn)
+            x = x + torch.stack(attn_out).sum(dim=0)
+        else:
+            x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0)
         x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0)
         return x

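Why the fix is needed: each entry in self.attns is an nn.Sequential of (norm, attn, ls, drop_path), and nn.Sequential.forward only accepts a single positional input, so a keyword argument like attn_mask cannot be threaded through it. When a mask is supplied, the commit therefore unrolls each Sequential by attribute and passes attn_mask to the inner attention module directly. Below is a minimal, hypothetical sketch of that pattern; ToyAttention and the toy branch are stand-ins of my own, not timm's actual Attention or ParallelThingsBlock.

# Minimal sketch (hypothetical names), assuming a branch laid out like timm's:
# an nn.Sequential of norm -> attn -> ls -> drop_path.
from collections import OrderedDict
from typing import Optional

import torch
import torch.nn as nn


class ToyAttention(nn.Module):
    """Stand-in for an attention module that takes an optional attn_mask kwarg."""
    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # A real module would apply attn_mask; here we only demonstrate the plumbing.
        return self.proj(x)


dim = 8
branch = nn.Sequential(OrderedDict([
    ('norm', nn.LayerNorm(dim)),
    ('attn', ToyAttention(dim)),
    ('ls', nn.Identity()),         # stands in for LayerScale
    ('drop_path', nn.Identity()),  # stands in for DropPath
]))

x = torch.randn(2, 4, dim)
mask = torch.zeros(4, 4)

# branch(x, attn_mask=mask) raises TypeError: nn.Sequential does not forward
# extra keyword arguments to its children. Unrolling by attribute, as the
# commit does, threads the mask through explicitly:
out = branch.drop_path(branch.ls(branch.attn(branch.norm(x), attn_mask=mask)))
print(out.shape)  # torch.Size([2, 4, 8])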