
Commit a150ef8
Author: Drew Penney (committed)
Change to block attention behavior --- should allow block attention on final items even if < block size
1 parent 4152d08 · commit a150ef8

2 files changed: +11 -23 lines

fairseq/models/speech_to_text/convtransformer.py (+6 -13)
@@ -335,33 +335,26 @@ def infer_conv_output_dim(self, in_channels, input_dim, out_channels):
 
     def buffered_future_mask(self, tensor): # DP
         dim = tensor.size(0)
+        delay = self.encoder_mask_future_delay
+        block_size = self.encoder_mask_block_size
 
-        if (self.encoder_mask_future_delay >= dim-1): # Full attention allowed, no need to check other conditions
+        if (delay >= dim-1) or (block_size >= dim): # Full attention allowed, no need to check other conditions
             self._future_mask = torch.zeros([dim, dim])
         else: # Start with mask that disallows looking into future
             tri_mask = torch.triu(
                 utils.fill_with_neg_inf(torch.zeros([dim, dim])), 1
             )
 
-            delay = self.encoder_mask_future_delay
-            block_size = self.encoder_mask_block_size
-            block_count = dim // block_size
-            block_pad = dim % block_size
-            blocks = torch.full((block_count, block_size, block_size), 1, dtype=torch.bool)
-
             # Create additional masks that consider self.encoder_mask_future_delay and self.encoder_mask_block_size
-            block_mask = torch.nn.functional.pad(input=torch.block_diag(*blocks), pad=(0, block_pad, 0, block_pad))
+            block_count = math.ceil(dim / block_size)
+            blocks = torch.full((block_count, block_size, block_size), 1, dtype=torch.bool)
+            block_mask = torch.nn.functional.pad(input=torch.block_diag(*blocks), pad=(0, 0, 0, 0))[:dim,:dim]
             delay_mask = torch.cat(
                 (
                     torch.full((dim,delay+1), 1, dtype=torch.bool),
                     torch.zeros((dim,dim-(delay+1)), dtype=torch.bool)
                 ), 1
             )
-
-            # VA, covers edge case where dim is less than block_size and the block_mask logic is a dimension off
-            if dim < block_size:
-                block_mask = block_mask[:-1]
-
             corr_mask = torch.logical_or(block_mask, delay_mask)
 
             self._future_mask = tri_mask.masked_fill_(corr_mask, 0) # Apply correction
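
For reference, a standalone sketch of the updated buffered_future_mask logic (assumptions: plain torch is used in place of fairseq's utils.fill_with_neg_inf, and the delay/block_size arguments stand in for the self.encoder_mask_* attributes set on the model):

import math
import torch

def build_encoder_mask(dim, delay, block_size):
    # Full attention: nothing to mask when the delay or block covers the whole sequence.
    if (delay >= dim - 1) or (block_size >= dim):
        return torch.zeros([dim, dim])
    # Causal mask that disallows looking into the future.
    tri_mask = torch.triu(torch.full((dim, dim), float("-inf")), 1)
    # Block-diagonal mask: ceil() gives the trailing partial block its own block,
    # and the [:dim, :dim] slice trims the overhang instead of zero-padding it away.
    block_count = math.ceil(dim / block_size)
    blocks = torch.full((block_count, block_size, block_size), 1, dtype=torch.bool)
    block_mask = torch.block_diag(*blocks)[:dim, :dim]
    # Delay mask: unmask the first delay+1 columns for every row.
    delay_mask = torch.cat(
        (
            torch.full((dim, delay + 1), 1, dtype=torch.bool),
            torch.zeros((dim, dim - (delay + 1)), dtype=torch.bool),
        ), 1
    )
    # Unmask (set to 0) every position allowed by either correction mask.
    corr_mask = torch.logical_or(block_mask, delay_mask)
    return tri_mask.masked_fill_(corr_mask, 0)

With dim=6, block_size=4, and delay=0, position 4 can now look ahead to position 5 within the trailing 2-wide block; the old floor-division version zero-padded that region away, so the final dim % block_size items got no within-block lookahead.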

tests/test_mask.py (+5 -10)
@@ -1,6 +1,7 @@
 
 import torch
 import argparse
+import math
 
 parser = argparse.ArgumentParser()
 parser.add_argument(
@@ -28,29 +29,23 @@
 #print(block_size)
 dim = 6
 
-if (delay >= dim-1): # Full attention allowed, no need to check other conditions
+if (delay >= dim-1) or (block_size >= dim): # Full attention allowed, no need to check other conditions
     future_mask = torch.zeros([dim, dim])
 else:
     tri_mask = torch.triu( # Start with mask that disallows looking into future
         torch.full((dim,dim), -999), 1
     )
+
     # Create additional masks that consider self.encoder_mask_future_delay and self.encoder_block_size
-    block_count = dim // block_size
-    block_pad = dim % block_size
+    block_count = math.ceil(dim / block_size)
     blocks = torch.full((block_count, block_size, block_size), 1, dtype=torch.bool)
-    block_mask = torch.nn.functional.pad(input=torch.block_diag(*blocks), pad=(0, block_pad, 0, block_pad))
-
+    block_mask = torch.nn.functional.pad(input=torch.block_diag(*blocks), pad=(0, 0, 0, 0))[:dim,:dim]
     delay_mask = torch.cat(
        (
            torch.full((dim,delay+1), 1, dtype=torch.bool),
            torch.zeros((dim,dim-(delay+1)), dtype=torch.bool)
        ), 1
    )
-
-    # VA, covers edge case where dim is less than block_size and the block_mask logic is a dimension off
-    if dim < block_size:
-        block_mask = block_mask[:-1]
-
     corr_mask = torch.logical_or(block_mask, delay_mask)
 
     print(f"Block mask:\n{block_mask}")
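
A quick before/after comparison of just the block mask, using the test's hard-coded dim = 6 and an assumed block_size = 4 (in the actual script the value comes from the argparse arguments):

import math
import torch

dim, block_size = 6, 4  # dim matches the test; block_size = 4 is an example value

# Old behavior: floor division drops the trailing partial block, and the padding
# leaves the last dim % block_size rows/columns entirely False.
old_blocks = torch.full((dim // block_size, block_size, block_size), 1, dtype=torch.bool)
pad = dim % block_size
old_mask = torch.nn.functional.pad(torch.block_diag(*old_blocks), (0, pad, 0, pad))

# New behavior: ceil() adds a block for the remainder and the slice trims the overhang,
# so the final positions keep a (smaller) block of their own.
new_blocks = torch.full((math.ceil(dim / block_size), block_size, block_size), 1, dtype=torch.bool)
new_mask = torch.block_diag(*new_blocks)[:dim, :dim]

print(old_mask[-1])  # tensor([False, False, False, False, False, False])
print(new_mask[-1])  # tensor([False, False, False, False,  True,  True])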
