Commit 2d8d834

fix masking
1 parent 72262d4 commit 2d8d834

2 files changed (+12 −11)

doctr/models/recognition/master/tensorflow.py

Lines changed: 11 additions & 10 deletions
@@ -91,20 +91,21 @@ def make_source_and_target_mask(
         source: tf.Tensor,
         target: tf.Tensor
     ) -> Tuple[tf.Tensor, tf.Tensor]:
-        # NOTE: inverse from PyTorch implementation
-        # [0, 0, 0, ..., 1, 1, 1] -> 1 is masked
+        # [1, 1, 1, ..., 0, 0, 0] -> 0 is masked
         # (N, 1, 1, max_length)
-        target_pad_mask = tf.cast(tf.math.equal(target, self.vocab_size + 2), dtype=tf.uint8)
+        target_pad_mask = tf.cast(tf.math.not_equal(target, self.vocab_size + 2), dtype=tf.uint8)
         target_pad_mask = target_pad_mask[:, tf.newaxis, tf.newaxis, :]
         target_length = target.shape[1]
-        # sub mask filled diagonal with 0 = see, 1 = masked (max_length, max_length)
-        target_sub_mask = 1 - tf.linalg.band_part(tf.ones((target_length, target_length)), -1, 0)
-        # source mask filled with zeros (max_length, positional_encoded_seq_len)
-        source_mask = tf.zeros((target_length, source.shape[1]))
+        # sub mask filled diagonal with 1 = see, 0 = masked (max_length, max_length)
+        target_sub_mask = tf.linalg.band_part(tf.ones((target_length, target_length)), -1, 0)
+        # source mask filled with ones (max_length, positional_encoded_seq_len)
+        source_mask = tf.ones((target_length, source.shape[1]))
         # combine the two masks into one (N, 1, max_length, max_length)
-        target_mask = tf.math.logical_and(
-            tf.cast(target_sub_mask, dtype=tf.bool),
-            tf.cast(target_pad_mask, dtype=tf.bool)
+        target_mask = tf.cast(
+            tf.math.logical_and(
+                tf.cast(target_sub_mask, dtype=tf.bool),
+                tf.cast(target_pad_mask, dtype=tf.bool)
+            ), dtype=tf.uint8
         )
         return source_mask, target_mask

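For readers following the change: the commit flips the mask convention from "1 = masked" to "1 = visible, 0 = masked". Below is a minimal standalone sketch of the new convention, not part of the commit itself, assuming a toy vocab_size of 3 (so the pad token id is vocab_size + 2 = 5):

import tensorflow as tf

# Assumed toy values: vocab_size = 3, so the padding token id is 3 + 2 = 5.
vocab_size = 3
target = tf.constant([[1, 2, 0, 5, 5]])  # (N=1, max_length=5); 5 = <pad>

# Padding mask: 1 where the token is real, 0 at padding positions.
target_pad_mask = tf.cast(tf.math.not_equal(target, vocab_size + 2), dtype=tf.uint8)
target_pad_mask = target_pad_mask[:, tf.newaxis, tf.newaxis, :]  # (N, 1, 1, max_length)

# Causal sub mask: lower triangle of ones, i.e. position i may see positions <= i.
length = target.shape[1]
target_sub_mask = tf.linalg.band_part(tf.ones((length, length)), -1, 0)

# A position is visible only if it is both causally visible and not padding.
target_mask = tf.cast(
    tf.math.logical_and(
        tf.cast(target_sub_mask, dtype=tf.bool),
        tf.cast(target_pad_mask, dtype=tf.bool),
    ),
    dtype=tf.uint8,
)
print(target_mask.shape)   # (1, 1, 5, 5)
print(target_mask[0, 0])   # last two columns are zeroed out by the padding mask

The broadcast of the (max_length, max_length) sub mask against the (N, 1, 1, max_length) padding mask yields the combined (N, 1, max_length, max_length) mask the comment in the diff describes.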
doctr/models/recognition/transformer/tensorflow.py

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ def scaled_dot_product_attention(

     scores = tf.matmul(query, key, transpose_b=True) / math.sqrt(query.shape[-1])
     if mask is not None:
-        scores += (tf.cast(mask, dtype=query.dtype) * -1e9)
+        scores = tf.where(mask == 0, -1e9, scores)
     p_attn = tf.nn.softmax(scores, axis=-1)
     return tf.matmul(p_attn, value), p_attn

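This hunk mirrors the flipped convention in the attention itself: instead of adding -1e9 where the old mask was 1, the scores are overwritten with -1e9 wherever the new mask is 0. A short illustrative sketch with assumed toy values (one query over four keys, last key masked; not taken from the commit):

import tensorflow as tf

scores = tf.constant([[2.0, 1.0, 0.5, 3.0]])
mask = tf.constant([[1, 1, 1, 0]], dtype=tf.uint8)  # new convention: 0 = masked

# Forcing masked scores to a large negative value drives their softmax
# weight to (near) zero.
masked_scores = tf.where(mask == 0, -1e9, scores)
p_attn = tf.nn.softmax(masked_scores, axis=-1)
print(p_attn.numpy())  # last attention weight is ~0

Using tf.where rather than additive masking also keeps the unmasked scores untouched, which avoids any dtype-dependent rounding from the `scores += ...` form it replaces.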