Skip to content

Commit c9497ca

Browse files
committedOct 30, 2023
fixed bug
1 parent 7608496 commit c9497ca

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed
 

‎fairseq/modules/waitseg_multihead_attention.py

+23
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,29 @@ def build_waitseg_mask(self, attn_weights, training_lagging_seg, seg_prob):
388388
cur_seg_num = torch.cumsum(seg_prob.round(), dim=-1)
389389
return cur_seg_num > idx
390390

391+
def build_waitk_mask(self, attn_weights, training_lagging_seg):
    """Build a boolean wait-k mask over source positions for each target step.

    For target step t, the last reachable source index is
    ``training_lagging_seg - 1 + t``, clamped into ``[1, src_len]``;
    entries strictly beyond that frontier are True (presumably the
    positions to be masked out by the caller — confirm against usage).

    Args:
        attn_weights: tensor of shape (bsz, tgt_len, src_len); only its
            size and device are read.
        training_lagging_seg: the wait-k lagging value (int).

    Returns:
        Bool tensor of shape (bsz, tgt_len, src_len).
    """
    batch, tgt_len, src_len = attn_weights.size()
    start = training_lagging_seg - 1

    # Last allowed source index per target step, clamped into [1, src_len],
    # laid out as (batch, tgt_len, 1) for broadcasting.
    frontier = torch.arange(start, start + tgt_len, device=attn_weights.device)
    frontier = frontier.clamp(1, src_len).view(1, -1, 1).repeat(batch, 1, 1)

    # Source position indices laid out as (batch, 1, src_len).
    positions = torch.arange(src_len, device=attn_weights.device)
    positions = positions.view(1, 1, -1).repeat(batch, 1, 1)

    # Broadcast comparison -> (batch, tgt_len, src_len).
    return positions > frontier
413+
391414
@staticmethod
392415
def _append_prev_key_padding_mask(
393416
key_padding_mask: Optional[Tensor],

0 commit comments

Comments
 (0)
Please sign in to comment.