@@ -91,20 +91,21 @@ def make_source_and_target_mask(
91
91
source : tf .Tensor ,
92
92
target : tf .Tensor
93
93
) -> Tuple [tf .Tensor , tf .Tensor ]:
94
- # NOTE: inverse from PyTorch implementation
95
- # [0, 0, 0, ..., 1, 1, 1] -> 1 is masked
94
+ # [1, 1, 1, ..., 0, 0, 0] -> 0 is masked
96
95
# (N, 1, 1, max_length)
97
- target_pad_mask = tf .cast (tf .math .equal (target , self .vocab_size + 2 ), dtype = tf .uint8 )
96
+ target_pad_mask = tf .cast (tf .math .not_equal (target , self .vocab_size + 2 ), dtype = tf .uint8 )
98
97
target_pad_mask = target_pad_mask [:, tf .newaxis , tf .newaxis , :]
99
98
target_length = target .shape [1 ]
100
- # sub mask filled diagonal with 0 = see 1 = masked (max_length, max_length)
101
- target_sub_mask = 1 - tf .linalg .band_part (tf .ones ((target_length , target_length )), - 1 , 0 )
102
- # source mask filled with zeros (max_length, positional_encoded_seq_len)
103
- source_mask = tf .zeros ((target_length , source .shape [1 ]))
99
+ # causal sub mask: lower triangle (including the diagonal) is 1; 1 = visible, 0 = masked (max_length, max_length)
100
+ target_sub_mask = tf .linalg .band_part (tf .ones ((target_length , target_length )), - 1 , 0 )
101
+ # source mask filled with ones (max_length, positional_encoded_seq_len)
102
+ source_mask = tf .ones ((target_length , source .shape [1 ]))
104
103
# combine the two masks into one (N, 1, max_length, max_length)
105
- target_mask = tf .math .logical_and (
106
- tf .cast (target_sub_mask , dtype = tf .bool ),
107
- tf .cast (target_pad_mask , dtype = tf .bool )
104
+ target_mask = tf .cast (
105
+ tf .math .logical_and (
106
+ tf .cast (target_sub_mask , dtype = tf .bool ),
107
+ tf .cast (target_pad_mask , dtype = tf .bool )
108
+ ), dtype = tf .uint8
108
109
)
109
110
return source_mask , target_mask
110
111
0 commit comments