[Fix] Tensorflow MASTER implementation #949

Merged
merged 11 commits on Jul 1, 2022
4 changes: 2 additions & 2 deletions doctr/models/classification/magc_resnet/tensorflow.py
@@ -23,8 +23,8 @@

default_cfgs: Dict[str, Dict[str, Any]] = {
'magc_resnet31': {
'mean': (0.5, 0.5, 0.5),
'std': (1., 1., 1.),
'mean': (0.694, 0.695, 0.693),
'std': (0.299, 0.296, 0.301),
'input_shape': (32, 32, 3),
'classes': list(VOCABS['french']),
'url': None,
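
For context, these mean/std values are the per-channel normalization statistics applied to input images before they reach the backbone. A minimal sketch of how such statistics are typically used (the normalize helper below is illustrative only, not part of docTR's API):

import tensorflow as tf

# Per-channel statistics from the config above (assumed to be computed on [0, 1]-scaled training images)
MEAN = tf.constant((0.694, 0.695, 0.693), dtype=tf.float32)
STD = tf.constant((0.299, 0.296, 0.301), dtype=tf.float32)

def normalize(img: tf.Tensor) -> tf.Tensor:
    """Scale an image to [0, 1] and standardize it channel-wise."""
    img = tf.cast(img, tf.float32) / 255.0
    return (img - MEAN) / STD

dummy = tf.random.uniform((32, 32, 3), maxval=255, dtype=tf.float32)
print(normalize(dummy).shape)  # (32, 32, 3)
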
2 changes: 1 addition & 1 deletion doctr/models/recognition/master/pytorch.py
@@ -163,7 +163,7 @@ def forward(
return_preds: if True, decode logits

Returns:
A torch tensor, containing logits
A dictionary which may contain the loss, logits and predictions.
"""

# Encode
103 changes: 59 additions & 44 deletions doctr/models/recognition/master/tensorflow.py
@@ -13,7 +13,7 @@
from doctr.models.classification import magc_resnet31

from ...utils.tensorflow import load_pretrained_params
from ..transformer.tensorflow import Decoder, create_look_ahead_mask, create_padding_mask, positional_encoding
from ..transformer.tensorflow import Decoder, PositionalEncoding
from .base import _MASTER, _MASTERPostProcessor

__all__ = ['MASTER', 'master']
@@ -24,8 +24,8 @@
'mean': (0.694, 0.695, 0.693),
'std': (0.299, 0.296, 0.301),
'input_shape': (32, 128, 3),
'vocab': VOCABS['legacy_french'],
'url': 'https://github.com/mindee/doctr/releases/download/v0.3.0/master-bade6eae.zip',
'vocab': VOCABS['french'],
'url': None,
},
}

@@ -58,38 +58,56 @@ def __init__(
num_layers: int = 3,
max_length: int = 50,
dropout: float = 0.2,
input_shape: Tuple[int, int, int] = (32, 128, 3),
input_shape: Tuple[int, int, int] = (32, 128, 3), # different from the paper
cfg: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__()

self.vocab = vocab
self.max_length = max_length
self.d_model = d_model
self.vocab = vocab
self.cfg = cfg
self.vocab_size = len(vocab)

self.feat_extractor = feature_extractor
self.seq_embedding = layers.Embedding(self.vocab_size + 3, d_model) # 3 more classes: EOS/PAD/SOS
self.positional_encoding = PositionalEncoding(self.d_model, dropout, max_len=input_shape[0] * input_shape[1])

self.decoder = Decoder(
num_layers=num_layers,
d_model=d_model,
d_model=self.d_model,
num_heads=num_heads,
vocab_size=self.vocab_size + 3, # EOS, SOS, PAD
dff=dff,
vocab_size=self.vocab_size,
maximum_position_encoding=max_length,
dropout=dropout,
maximum_position_encoding=self.max_length,
)
self.feature_pe = positional_encoding(input_shape[0] * input_shape[1], d_model)
self.linear = layers.Dense(self.vocab_size + 3, kernel_initializer=tf.initializers.he_uniform())

self.linear = layers.Dense(self.vocab_size + 3, kernel_initializer=tf.initializers.he_uniform())
self.postprocessor = MASTERPostProcessor(vocab=self.vocab)

def make_mask(self, target: tf.Tensor) -> tf.Tensor:
look_ahead_mask = create_look_ahead_mask(tf.shape(target)[1])
target_padding_mask = create_padding_mask(target, self.vocab_size + 2) # Pad symbol
combined_mask = tf.maximum(target_padding_mask, look_ahead_mask)
return combined_mask
@tf.function
def make_source_and_target_mask(
self,
source: tf.Tensor,
target: tf.Tensor
) -> Tuple[tf.Tensor, tf.Tensor]:
# [1, 1, 1, ..., 0, 0, 0] -> 0 is masked
# (N, 1, 1, max_length)
target_pad_mask = tf.cast(tf.math.not_equal(target, self.vocab_size + 2), dtype=tf.uint8)
target_pad_mask = target_pad_mask[:, tf.newaxis, tf.newaxis, :]
target_length = target.shape[1]
# sub (look-ahead) mask: lower triangle filled with 1 = visible, 0 = masked, shape (max_length, max_length)
target_sub_mask = tf.linalg.band_part(tf.ones((target_length, target_length)), -1, 0)
# source mask filled with ones (max_length, positional_encoded_seq_len)
source_mask = tf.ones((target_length, source.shape[1]))
# combine the two masks into one (N, 1, max_length, max_length)
target_mask = tf.cast(
tf.math.logical_and(
tf.cast(target_sub_mask, dtype=tf.bool),
tf.cast(target_pad_mask, dtype=tf.bool)
), dtype=tf.uint8
)
return source_mask, target_mask
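
As a side note, the behaviour of the new mask construction can be checked in isolation. A small sketch with a toy padded sequence (the PAD index and shapes are illustrative; same convention as above, 1 = visible, 0 = masked):

import tensorflow as tf

PAD = 5  # stands in for vocab_size + 2
target = tf.constant([[4, 1, 2, PAD]])  # (N=1, T=4), last position is padding

# (N, 1, 1, T) padding mask: 0 wherever the token is PAD
pad_mask = tf.cast(tf.math.not_equal(target, PAD), tf.uint8)[:, tf.newaxis, tf.newaxis, :]

# (T, T) look-ahead mask: lower triangle of ones, so position i only attends to positions <= i
T = target.shape[1]
sub_mask = tf.linalg.band_part(tf.ones((T, T)), -1, 0)

# (N, 1, T, T) combined target mask, as in make_source_and_target_mask
target_mask = tf.cast(
    tf.math.logical_and(tf.cast(sub_mask, tf.bool), tf.cast(pad_mask, tf.bool)),
    tf.uint8,
)
print(target_mask[0, 0].numpy())
# [[1 0 0 0]
#  [1 1 0 0]
#  [1 1 1 0]
#  [1 1 1 0]]
# (the last column stays masked because that position is padding)

The source mask is simply filled with ones, i.e. the decoder may attend to every encoded position.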

@staticmethod
def compute_loss(
@@ -147,27 +165,26 @@ def call(

# Encode
feature = self.feat_extractor(x, **kwargs)
b, h, w, c = (tf.shape(feature)[i] for i in range(4))
b, h, w, c = feature.get_shape()
# (N, H, W, C) --> (N, H * W, C)
feature = tf.reshape(feature, shape=(b, h * w, c))
encoded = feature + tf.cast(self.feature_pe[:, :h * w, :], dtype=feature.dtype)
# add positional encoding to features
encoded = self.positional_encoding(feature, **kwargs)

out: Dict[str, tf.Tensor] = {}

if kwargs.get('training', False) and target is None:
raise ValueError('Need to provide labels during training')

if target is not None:
# Compute target: tensor of gts and sequence lengths
gt, seq_len = self.build_target(target)

if kwargs.get('training', False):
if target is None:
raise ValueError("In training mode, you need to pass a value to 'target'")
tgt_mask = self.make_mask(gt)
# Compute decoder masks
source_mask, target_mask = self.make_source_and_target_mask(encoded, gt)
# Compute logits
output = self.decoder(gt, encoded, tgt_mask, None, **kwargs)
output = self.decoder(gt, encoded, source_mask, target_mask, **kwargs)
logits = self.linear(output, **kwargs)

else:
# When not training, we still want to compute the logits with the decoder, even though
# we have access to gts (gts are needed to compute the loss, but not to run the decoder)
logits = self.decode(encoded, **kwargs)

if target is not None:
@@ -177,11 +194,11 @@
out['out_map'] = logits

if return_preds:
predictions = self.postprocessor(logits)
out['preds'] = predictions
out['preds'] = self.postprocessor(logits)

return out

@tf.function
def decode(self, encoded: tf.Tensor, **kwargs: Any) -> tf.Tensor:
"""Decode function for prediction

@@ -191,30 +208,30 @@ def decode(self, encoded: tf.Tensor, **kwargs: Any) -> tf.Tensor:
Return:
A Tuple of tf.Tensor: predictions, logits
"""
b = tf.shape(encoded)[0]
max_len = tf.constant(self.max_length, dtype=tf.int32)
b = encoded.shape[0]

start_symbol = tf.constant(self.vocab_size + 1, dtype=tf.int32) # SOS
padding_symbol = tf.constant(self.vocab_size + 2, dtype=tf.int32) # PAD

ys = tf.fill(dims=(b, max_len - 1), value=padding_symbol)
ys = tf.fill(dims=(b, self.max_length - 1), value=padding_symbol)
start_vector = tf.fill(dims=(b, 1), value=start_symbol)
ys = tf.concat([start_vector, ys], axis=-1)

logits = tf.zeros(shape=(b, max_len - 1, self.vocab_size + 3), dtype=encoded.dtype) # 3 symbols
# max_len = len + 2 (sos + eos)
# Final dimension includes EOS/SOS/PAD
for i in range(self.max_length - 1):
ys_mask = self.make_mask(ys)
output = self.decoder(ys, encoded, ys_mask, None, **kwargs)

source_mask, target_mask = self.make_source_and_target_mask(encoded, ys)
output = self.decoder(ys, encoded, source_mask, target_mask, **kwargs)
logits = self.linear(output, **kwargs)
prob = tf.nn.softmax(logits, axis=-1)
next_word = tf.argmax(prob, axis=-1, output_type=ys.dtype)
# ys.shape = B, T
i_mesh, j_mesh = tf.meshgrid(tf.range(b), tf.range(max_len), indexing='ij')
next_token = tf.argmax(prob, axis=-1, output_type=ys.dtype)
# update ys with the next token and ignore the first token (SOS)
i_mesh, j_mesh = tf.meshgrid(tf.range(b), tf.range(self.max_length), indexing='ij')
indices = tf.stack([i_mesh[:, i + 1], j_mesh[:, i + 1]], axis=1)

ys = tf.tensor_scatter_nd_update(ys, indices, next_word[:, i + 1])
ys = tf.tensor_scatter_nd_update(ys, indices, next_token[:, i])

# final_logits of shape (N, max_length - 1, vocab_size + 1) (without sos)
# Shape (N, max_length, vocab_size + 1)
return logits
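
For reference, the scatter update inside the decoding loop can be exercised on its own. A toy sketch (shapes and values are illustrative) of the indexing: the prediction for step i, next_token[:, i], is written into column i + 1 of ys, leaving the SOS token in column 0 untouched:

import tensorflow as tf

b, max_length = 2, 5
ys = tf.fill((b, max_length), 0)  # running sequences of token ids
next_token = tf.constant([[7, 8, 9, 1], [3, 4, 5, 6]])  # (b, max_length - 1), argmax of the logits
i = 1  # current decoding step

# one (row, column=i+1) index pair per batch element
i_mesh, j_mesh = tf.meshgrid(tf.range(b), tf.range(max_length), indexing='ij')
indices = tf.stack([i_mesh[:, i + 1], j_mesh[:, i + 1]], axis=1)
ys = tf.tensor_scatter_nd_update(ys, indices, next_token[:, i])
print(ys.numpy())
# [[0 0 8 0 0]
#  [0 0 4 0 0]]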


@@ -223,8 +240,6 @@ class MASTERPostProcessor(_MASTERPostProcessor):

Args:
vocab: string containing the ordered sequence of supported characters
ignore_case: if True, ignore case of letters
ignore_accents: if True, ignore accents of letters
"""

def __call__(
@@ -286,7 +301,7 @@ def master(pretrained: bool = False, **kwargs: Any) -> MASTER:
>>> import tensorflow as tf
>>> from doctr.models import master
>>> model = master(pretrained=False)
>>> input_tensor = tf.random.uniform(shape=[1, 48, 160, 3], maxval=1, dtype=tf.float32)
>>> input_tensor = tf.random.uniform(shape=[1, 32, 128, 3], maxval=1, dtype=tf.float32)
>>> out = model(input_tensor)

Args:
1 change: 1 addition & 0 deletions doctr/models/recognition/transformer/pytorch.py
@@ -160,4 +160,5 @@ def forward(
normed_output = self.layer_norm(output)
output = output + self.dropout(self.position_feed_forward[i](normed_output))

# (batch_size, seq_len, d_model)
return self.layer_norm(output)