Ensures mask kwarg is used for output mask propagation #21449
Merged
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This fix addresses a bug in the Layer class where an explicitly passed
mask
keyword argument was ignored during the computation of the output mask for subsequent layers.The Problem
Currently, when a layer is called with an explicit mask (e.g.,
layer(inputs, mask=my_mask)
), the framework correctly uses this mask for the layer's internalcall()
computation. However, when preparing to propagate a mask to the next layer, the framework incorrectly ignores this explicit mask. Instead, it only checks for a_keras_mask
attribute attached directly to the input tensor.If the input tensor has no such attribute, the output mask is computed as
None
, effectively breaking the mask propagation chain.Solution
The fix modifies how the
previous_mask
variable is determined within theLayer.__call__
method. The logic has been updated to prioritize the explicitly passed mask keyword argument.The new order of operations is:
kwargs
. If yes, use it asprevious_mask
._keras_mask
attribute on the first input tensor.This ensures that the mask actually used by the layer is the same one used to compute the mask for its output, restoring correct and consistent mask propagation behavior.