@@ -335,33 +335,26 @@ def infer_conv_output_dim(self, in_channels, input_dim, out_channels):
335
335
336
336
def buffered_future_mask (self , tensor ): # DP
337
337
dim = tensor .size (0 )
338
+ delay = self .encoder_mask_future_delay
339
+ block_size = self .encoder_mask_block_size
338
340
339
- if (self . encoder_mask_future_delay >= dim - 1 ): # Full attention allowed, no need to check other conditions
341
+ if (delay >= dim - 1 ) or ( block_size >= dim ): # Full attention allowed, no need to check other conditions
340
342
self ._future_mask = torch .zeros ([dim , dim ])
341
343
else : # Start with mask that disallows looking into future
342
344
tri_mask = torch .triu (
343
345
utils .fill_with_neg_inf (torch .zeros ([dim , dim ])), 1
344
346
)
345
347
346
- delay = self .encoder_mask_future_delay
347
- block_size = self .encoder_mask_block_size
348
- block_count = dim // block_size
349
- block_pad = dim % block_size
350
- blocks = torch .full ((block_count , block_size , block_size ), 1 , dtype = torch .bool )
351
-
352
348
# Create additional masks that consider self.encoder_mask_future_delay and self.encoder_mask_block_size
353
- block_mask = torch .nn .functional .pad (input = torch .block_diag (* blocks ), pad = (0 , block_pad , 0 , block_pad ))
349
+ block_count = math .ceil (dim / block_size )
350
+ blocks = torch .full ((block_count , block_size , block_size ), 1 , dtype = torch .bool )
351
+ block_mask = torch .nn .functional .pad (input = torch .block_diag (* blocks ), pad = (0 , 0 , 0 , 0 ))[:dim ,:dim ]
354
352
delay_mask = torch .cat (
355
353
(
356
354
torch .full ((dim ,delay + 1 ), 1 , dtype = torch .bool ),
357
355
torch .zeros ( (dim ,dim - (delay + 1 )), dtype = torch .bool )
358
356
), 1
359
357
)
360
-
361
- # VA, covers edge case where dim is less than block_size and the block_mask logic is a dimension off
362
- if dim < block_size :
363
- block_mask = block_mask [:- 1 ]
364
-
365
358
corr_mask = torch .logical_or (block_mask , delay_mask )
366
359
367
360
self ._future_mask = tri_mask .masked_fill_ (corr_mask , 0 ) # Apply correction
0 commit comments