
[Fix] PyTorch MASTER implementation #941


Merged
merged 7 commits on Jun 9, 2022

Conversation

felixdittrich92
Contributor

@felixdittrich92 felixdittrich92 commented Jun 2, 2022

This PR (still in progress):

  • fix decoding step
  • fix Decoder

Any feedback is welcome 🤗

Some notes at the moment:

  • different from the paper, the input size is set to our default (32, 128)
  • for the moment, the transformer is a plain implementation rather than torch.nn modules

but training and inference work fine 😅

@frgfm feel free to grab this and take a look 👍

Issue:
#802

@felixdittrich92 felixdittrich92 added this to the 0.5.2 milestone Jun 2, 2022
@felixdittrich92 felixdittrich92 added the type: bug, help wanted, critical, module: models, framework: pytorch and topic: text recognition labels Jun 2, 2022
@felixdittrich92 felixdittrich92 changed the title [Fix] PyTorch MASTER implementation [WIP][Fix] PyTorch MASTER implementation Jun 2, 2022
@codecov

codecov bot commented Jun 2, 2022

Codecov Report

Merging #941 (e0933c7) into main (75531c5) will increase coverage by 0.12%.
The diff coverage is 97.82%.

@@            Coverage Diff             @@
##             main     #941      +/-   ##
==========================================
+ Coverage   94.71%   94.83%   +0.12%     
==========================================
  Files         134      134              
  Lines        5501     5539      +38     
==========================================
+ Hits         5210     5253      +43     
+ Misses        291      286       -5     
Flag        Coverage Δ
unittests   94.83% <97.82%> (+0.12%) ⬆️

Flags with carried forward coverage won't be shown.

Impacted Files                                    Coverage Δ
doctr/models/recognition/master/pytorch.py        95.23% <92.59%> (+3.49%) ⬆️
doctr/models/recognition/transformer/pytorch.py   100.00% <100.00%> (ø)
doctr/transforms/functional/base.py               97.10% <0.00%> (+1.44%) ⬆️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 75531c5...e0933c7.

@felixdittrich92
Contributor Author

felixdittrich92 commented Jun 3, 2022

Todo:

  • figure out what is wrong with the masking (the part that currently makes it impossible to use nn.TransformerDecoderLayer); a minimal masking sketch follows the benchmark numbers below
    (Note: after checking the masking (for the nn version) and the pos_encoded_embeddings several times, I cannot explain what's different in their implementation -> revert to the working one)
  • train and benchmark with a toy set (500k MJSynth), evaluated on FUNSD/CORD
  • documentation

Results after 1 epoch on a 500K MJSynth split:
Epoch 1/2 - Validation loss: 0.28989 (Exact: 54.65% | Partial: 57.52%)
FUNSD: Validation loss: 1.78504 (Exact: 39.78% | Partial: 41.77%)
CORD: Validation loss: 2.07195 (Exact: 36.71% | Partial: 37.39%)
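
For reference, a minimal sketch of the masking convention nn.TransformerDecoderLayer expects (the shapes and hyperparameters below are illustrative only, not taken from this PR):

import torch
import torch.nn as nn

d_model, nhead, num_layers = 512, 8, 3
decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead)
decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

# Dummy inputs, shaped (seq_len, batch, d_model) since batch_first defaults to False
tgt = torch.randn(5, 2, d_model)      # embedded target sequence
memory = torch.randn(10, 2, d_model)  # encoder output

# Causal mask: position i may only attend to positions <= i.
# torch.nn modules expect an additive float mask (-inf above the diagonal)
# or a bool mask where True means "masked out".
seq_len = tgt.size(0)
tgt_mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)

out = decoder(tgt, memory, tgt_mask=tgt_mask)  # -> (5, 2, d_model)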

@felixdittrich92
Contributor Author

felixdittrich92 commented Jun 3, 2022

note: the last commit replaces our own TransformerDecoder implementation with nn.TransformerDecoderLayer and breaks something

use 0f96095 as the working and already tested fallback !!

@felixdittrich92
Contributor Author

felixdittrich92 commented Jun 5, 2022

@frgfm I have some trouble bringing it to work with nn.TransformerDecoderLayer: the target masking looks good and the loss does decrease, but there is no way to reach any matches 😅

last working commit without nn.TransformerDecoderLayer:
0f96095
(also used for the toy run + benchmark)

What do you think: iterate to find a solution with PyTorch's transformer implementation (which would definitely be recommended), or use our own one (which works)?

Update: I have tried to force the decoder (the nn.TransformerDecoderLayer version) to "cheat" by dropping the masking, but that also results in a non-working model (stuck at ~2.8 loss, no matches), so I will revert to the working version.
Note: 7677e9a is the last commit with nn.TransformerDecoderLayer
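
One possible culprit worth checking (an assumption, not something verified in this thread): custom decoder implementations often build a boolean mask where True means "may attend", while torch.nn attention modules interpret True as "masked out" (or expect an additive float mask), so a mask that is correct under one convention silently blocks the wrong positions under the other:

import torch

seq_len = 4
# Custom implementations often build a "keep" mask: True = position is visible
keep_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

# torch.nn modules expect the opposite polarity (True = blocked), or floats:
nn_bool_mask = ~keep_mask
nn_float_mask = torch.zeros(seq_len, seq_len).masked_fill(~keep_mask, float("-inf"))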

@felixdittrich92 felixdittrich92 changed the title [WIP][Fix] PyTorch MASTER implementation [Fix] PyTorch MASTER implementation Jun 7, 2022
@felixdittrich92 felixdittrich92 marked this pull request as ready for review June 7, 2022 10:44
Collaborator

@charlesmindee charlesmindee left a comment


LGTM

Collaborator

@frgfm frgfm left a comment


Thanks a lot @felixdittrich92 🙏
Late comments for improvements :)

     raise AssertionError("In training mode, you need to pass a value to 'target'")
-    tgt_mask = self.make_mask(gt)
+    # Compute source mask and target mask
+    source_mask, target_mask = self.make_source_and_target_mask(encoded, gt)
Collaborator


the source_mask is needed whatever the mode (training vs. inference), but the target_mask is only needed in training, right?

We need to avoid mixing up the two cases (being in training mode != being passed a target), especially for the evaluation loops of our reference scripts (the model will be in self.training = False, but it will be given a target so that it can compute the loss)
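
A rough sketch of how such a helper could compute both masks from the inputs alone (hypothetical names and pad index, not the actual doctr code), so that an evaluation loop with self.training == False and a given target still works:

import torch

def make_source_and_target_mask(encoded, target, pad_idx):
    # Both masks are derived from the inputs, never from model.training
    target_pad_mask = (target != pad_idx).unsqueeze(1)  # (N, 1, T): hide PAD positions
    seq_len = target.size(1)
    # Combine with the causal mask so position i only sees positions <= i
    causal_mask = torch.tril(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=target.device)
    )
    target_mask = target_pad_mask & causal_mask          # broadcasts to (N, T, T)
    # The encoded feature map carries no padding, so an all-True mask suffices
    source_mask = torch.ones(
        (target.size(0), 1, encoded.size(1)), dtype=torch.bool, device=target.device
    )
    return source_mask, target_mask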

Contributor Author

@felixdittrich92 felixdittrich92 Jun 23, 2022


decoder only:

train:
  • we need the combined target mask to ensure the model cannot 'cheat'

inference:
  • normally we wouldn't need to pass any mask at inference, but since we feed a placeholder target, we need the masking there as well. Without the masks the model does not work correctly. I can't explain 100% why, but my take is that the masking is also needed at inference to force the model's predictions step by step; otherwise we end up with something like 'aaaaabbbb' (see the decoding sketch below)
    (and I have tested the 500K-MJSynth-trained model (1 epoch) in the whole pipeline on an unseen document with a lot of text and got ~70% correct, which was 🤯)

Will update this 👍 (also the CRNN / SAR PyTorch models)
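
A minimal sketch of the greedy decoding loop being described here (decoder_step, sos_idx and pad_idx are illustrative placeholders, not the doctr API):

import torch

def greedy_decode(encoded, decoder_step, max_length, sos_idx, pad_idx):
    """Masked greedy decoding with a placeholder target sequence.

    `decoder_step(ys, encoded)` is a hypothetical callable returning logits
    of shape (N, T, vocab); it is expected to apply the causal target mask
    internally, exactly as in training.
    """
    b = encoded.size(0)
    # Placeholder target: PAD everywhere, SOS at position 0
    ys = torch.full((b, max_length), pad_idx, dtype=torch.long, device=encoded.device)
    ys[:, 0] = sos_idx
    for i in range(max_length - 1):
        # Without the causal mask, position i would attend to the PAD
        # placeholders at positions > i, and the output degenerates into
        # repeats like 'aaaaabbbb'.
        logits = decoder_step(ys, encoded)
        ys[:, i + 1] = logits[:, i].argmax(-1)  # feed the prediction back in
    return ys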

@@ -206,17 +213,19 @@ def decode(self, encoded: torch.Tensor) -> torch.Tensor:
     b = encoded.size(0)

     # Padding symbol + SOS at the beginning
-    ys = torch.full((b, self.max_length), self.vocab_size + 2, dtype=torch.long, device=encoded.device)
-    ys[:, 0] = self.vocab_size + 1
+    ys = torch.zeros((b, self.max_length), dtype=torch.long, device=encoded.device) * (self.vocab_size + 2)  # pad
Collaborator


?
previously, the tensor was full of the value self.vocab_size + 2
now, it's full of zeros (multiplying 0 by self.vocab_size + 2 will yield zero 😅)

Contributor Author


Oh f...k, got it 😅
In the end, it's really just about having the right shape; it doesn't matter which values are in there.
I also assumed that it would only make sense to fill with the PAD value, but that's irrelevant during inference.
But to make it make some sense, I'll change it 👍
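
For completeness, the difference the review caught, with toy numbers:

import torch

b, max_length, vocab_size = 2, 6, 10

# torch.zeros(...) * k is still all zeros, so the intended PAD fill is lost:
ys = torch.zeros((b, max_length), dtype=torch.long) * (vocab_size + 2)
assert (ys == 0).all()

# The previous (working) initialization: fill with PAD, then set SOS at position 0
ys = torch.full((b, max_length), vocab_size + 2, dtype=torch.long)  # PAD
ys[:, 0] = vocab_size + 1                                           # SOS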

@felixdittrich92 felixdittrich92 modified the milestones: 0.5.2, 0.6.0 Sep 26, 2022
@felixdittrich92 felixdittrich92 mentioned this pull request Sep 26, 2022