[Fix] PyTorch MASTER implementation #941
Conversation
Codecov Report
@@            Coverage Diff             @@
##             main     #941      +/-   ##
==========================================
+ Coverage   94.71%   94.83%   +0.12%
==========================================
  Files         134      134
  Lines        5501     5539      +38
==========================================
+ Hits         5210     5253      +43
+ Misses        291      286       -5
Todo:
- 1 epoch on the 500K MJSynth split
Note: the last commit replaces our own TransformerDecoder implementation with nn.TransformerDecoderLayer and breaks something. Use 0f96095 as the working, already tested fallback!!
@frgfm I have some trouble getting it to work with nn.TransformerDecoderLayer. The target masking looks good and the loss does decrease, but there is no way to reach any matches 😅 The last working commit is the one without nn.TransformerDecoderLayer. What do you think: iterate to find a solution with PyTorch's transformer implementation (which would definitely be recommended), or use our own one (which works)? Update: I have tried to force the decoder (nn.TransformerDecoderLayer version) to "cheat" by dropping the masking, but that also results in a non-working model (stuck at ~2.8 loss, no matches), so I will revert to the working version.
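For reference, a minimal sketch (illustrative shapes and names, not the doctr code) of how a causal target mask is usually built and passed to PyTorch's nn.TransformerDecoder, which is the behaviour the target masking above is meant to reproduce:

```python
import torch
from torch import nn

d_model, nhead, seq_len, batch = 512, 8, 50, 2

decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
decoder = nn.TransformerDecoder(decoder_layer, num_layers=3)

tgt = torch.rand(batch, seq_len, d_model)     # embedded (placeholder) target sequence
memory = torch.rand(batch, seq_len, d_model)  # encoder output

# Causal mask: -inf above the diagonal, so position i can only attend to positions <= i
tgt_mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)

out = decoder(tgt, memory, tgt_mask=tgt_mask)
print(out.shape)  # torch.Size([2, 50, 512])
```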
Force-pushed from 7677e9a to e0933c7
LGTM
Thanks a lot @felixdittrich92 🙏
Some late comments for improvements :)
raise AssertionError("In training mode, you need to pass a value to 'target'")
tgt_mask = self.make_mask(gt)
# Compute source mask and target mask
source_mask, target_mask = self.make_source_and_target_mask(encoded, gt) |
The source_mask is needed whatever the mode (training vs inference), but the target_mask is only needed in training, right?
We need to avoid forgetting cases (being in training mode != being passed a target), especially for the evaluation loops of our reference scripts: the model will be in self.training = False but it will still be given a target so that it can compute the loss.
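To illustrate the point about evaluation loops, a tiny runnable sketch (toy model, assumed names, not the doctr implementation) where the loss branch keys off the presence of a target rather than off self.training:

```python
import torch
from torch import nn


class ToyRecognizer(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 2)

    def forward(self, x, target=None):
        logits = self.backbone(x)
        out = {"logits": logits}
        if target is not None:  # key check: presence of a target, not self.training
            out["loss"] = nn.functional.cross_entropy(logits, target)
        return out


model = ToyRecognizer().eval()              # eval loop: model.training is False
x, target = torch.rand(8, 4), torch.randint(0, 2, (8,))
print(model(x, target)["loss"].item())      # loss is still computed
print("loss" in model(x))                   # False: pure inference path
```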
Decoder only:
- train: we need the combined target mask to ensure the model cannot 'cheat'.
- inference: normally we would not need to pass any mask at inference, but since we feed a placeholder target we need the masking at inference time as well. Without the masks the model does not work correctly. I can't explain 100% why, but I think the masking is also needed at inference to force the model's predictions step by step, otherwise we end up with something like 'aaaaabbbb'.
(And I have tested the model trained for 1 epoch on the 500K MJSynth split in the whole pipeline on an unseen document with a lot of text and got ~70% correct, which was 🤯)
Will update this 👍 (also crnn / sar pytorch)
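A minimal sketch of the "combined" target mask described above (assumed convention: True = position may be attended to; the PAD value here is just a toy stand-in for the vocab_size + 2 padding index visible in the diff below):

```python
import torch


def make_target_mask(gt: torch.Tensor, pad_idx: int) -> torch.Tensor:
    # Combined target mask: causal part (no peeking at future positions)
    # AND padding part (PAD positions cannot be attended to).
    seq_len = gt.size(1)
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=gt.device))
    not_pad = (gt != pad_idx).unsqueeze(1)  # (batch, 1, seq_len)
    return causal.unsqueeze(0) & not_pad    # (batch, seq_len, seq_len)


gt = torch.tensor([[10, 3, 7, 12, 12]])  # toy target, 12 playing the role of PAD
print(make_target_mask(gt, pad_idx=12).int())
```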
@@ -206,17 +213,19 @@ def decode(self, encoded: torch.Tensor) -> torch.Tensor:
    b = encoded.size(0)

    # Padding symbol + SOS at the beginning
    ys = torch.full((b, self.max_length), self.vocab_size + 2, dtype=torch.long, device=encoded.device)
    ys[:, 0] = self.vocab_size + 1
    ys = torch.zeros((b, self.max_length), dtype=torch.long, device=encoded.device) * (self.vocab_size + 2)  # pad
?
Previously, the tensor was full of the value self.vocab_size + 2; now it's full of zeros (multiplying 0 by self.vocab_size + 2 will yield zero 😅).
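A tiny check of that point (5 is just a stand-in for self.vocab_size + 2): multiplying a zero tensor does nothing, whereas torch.full actually fills the tensor with the PAD value.

```python
import torch

pad = 5  # stand-in for self.vocab_size + 2
print(torch.zeros((1, 4), dtype=torch.long) * pad)  # tensor([[0, 0, 0, 0]])
print(torch.full((1, 4), pad, dtype=torch.long))    # tensor([[5, 5, 5, 5]])
```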
Oh f...k, got it 😅
In the end, it's really just about having the shape; it doesn't matter what values are in there.
I also assumed that it would only make sense to fill it with the PAD value, but that's irrelevant at inference time.
But to make it make some sense, I'll change it 👍
This PR (still in progress):
Any feedback is welcome 🤗
Some notes at the moment:
- different from the paper, the input size is set to our default (32, 128)
- for the moment: a plain transformer implementation, not the nn modules
- training and inference work fine 😅
@frgfm feel free to grab this and take a look 👍
Issue:
#802