fix(export+models): Enhance support for dictionary-based model input signatures in TensorFlow and JAX #20842
base: master
Conversation
…cases)
- Improves input structure validation in Model and Functional classes
- Adds strict validation with clear error messages for mismatches
Codecov Report
Attention: Patch coverage is …
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #20842      +/-   ##
==========================================
- Coverage   82.26%   82.20%   -0.06%
==========================================
  Files         561      561
  Lines       52693    53035     +342
  Branches     8146     8228      +82
==========================================
+ Hits        43347    43600     +253
- Misses       7344     7391      +47
- Partials     2002     2044      +42
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry.
- Make the linter happy :)
- Added shape logging for preprocessed and crossed features
- Added debug messages in Functional model processing (to be removed)
TODO: Remove debugging prints.
Ping: @mattdangerw @fchollet.
Thanks for the contribution! Left some comments.
Thanks for the review! I'll fix it as soon as possible.
Please take a look at the changes and let me know whether this is the intended behavior. Looking forward to your guidance, thanks!
Ping: @fchollet, looking forward to an update on this PR. Thanks!
Thanks for the update!
Please resolve merge conflicts. @jeffcarp, does the PR look good?
@fchollet
LGTM % some nits
Thanks!
@fchollet, just checking in on this PR. Since the changes are approved, is there anything else needed before merging? Looking forward to your feedback, thanks!
Thank you for your work on this. I can see that a lot of effort was put into this. However, I believe this can be simplified quite a bit.
There is a lot of logic around creating, maintaining, and restoring self._input_names. However, it is only actually needed in export_utils.get_input_signature, and it's only used for one purpose, which is to know the order of inputs. But the order of inputs is already known, right? It's in self._inputs_struct. So you can just use self._inputs_struct in export_utils.get_input_signature and remove self._input_names.
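As a rough illustration of that suggestion (assumed code, not the actual Keras implementation; make_input_spec is a stand-in helper invented here):

```python
import tensorflow as tf
from keras import tree


def make_input_spec(x):
    # Assumed stand-in helper: convert a KerasTensor-like input into a spec.
    return tf.TensorSpec(shape=x.shape, dtype=x.dtype)


def get_input_signature(model):
    # The outer list has one entry because a Functional model receives its
    # whole (possibly nested, possibly dict-valued) input structure as a
    # single positional argument, so the input order comes from
    # _inputs_struct and no separate _input_names bookkeeping is required.
    return [tree.map_structure(make_input_spec, model._inputs_struct)]
```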
Additionally, I don't think dicts constitute a special case in most cases. Layers can have 1 or more positional arguments as inputs, and each one of those is a nested structure that can have dicts at any level.
Out of curiosity, I added the following 2 tests in saved_model_test.py, and they don't pass:
def test_export_with_two_dict_inputs_functional(self):
    temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
    inputs1 = {
        "foo": layers.Input(shape=()),
        "bar": layers.Input(shape=()),
    }
    inputs2 = {
        "baz": layers.Input(shape=()),
    }
    outputs = layers.Add()([inputs1["foo"], inputs1["bar"], inputs2["baz"]])
    model = models.Model((inputs1, inputs2), outputs)
    ref_input = (
        {"foo": tf.constant([1.0]), "bar": tf.constant([2.0])},
        {"baz": tf.constant([2.0])},
    )
    ref_output = model(ref_input)
    model.export(temp_filepath, format="tf_saved_model")
    revived_model = tf.saved_model.load(temp_filepath)
    revived_output = revived_model.serve(ref_input)
    self.assertAllClose(ref_output, revived_output)

def test_export_with_two_dict_inputs_subclass(self):
    class TwoDictModel(models.Model):
        def call(self, inputs1, inputs2):
            return ops.add(
                ops.add(inputs1["foo"], inputs1["bar"]), inputs2["baz"]
            )

    temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
    model = TwoDictModel()
    ref_input1 = {"foo": tf.constant([1.0]), "bar": tf.constant([2.0])}
    ref_input2 = {"baz": tf.constant([2.0])}
    ref_output = model(ref_input1, ref_input2)
    model.export(temp_filepath, format="tf_saved_model")
    revived_model = tf.saved_model.load(temp_filepath)
    revived_output = revived_model.serve(ref_input1, ref_input2)
    self.assertAllClose(ref_output, revived_output)
This shows that dicts as second arguments are not supported.
# Create a simplified wrapper that handles both dict and
# positional args, similar to TensorFlow implementation.
def wrapped_fn(arg, **kwargs):
    return fn(arg)
I'm not following:
- why this is needed
- what TensorFlow implementation it follows
All it does is keep the first positional argument and drop everything else. What if you have multiple positional arguments?
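For contrast, a wrapper that does not drop arguments might look like this (a sketch only; fn is whatever callable is being wrapped):

```python
def wrapped_fn(*args, **kwargs):
    # Forward every positional and keyword argument rather than keeping
    # only the first, so models with multiple positional inputs still work.
    return fn(*args, **kwargs)
```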
jax2tf_kwargs["polymorphic_shapes"] = self._to_polymorphic_shape(
    input_signature
)
if isinstance(input_signature, dict):
This shouldn't be needed. input_signature should always be a list because the model can have multiple positional arguments. If the inputs are a dict, input_signature should be a list of one dict already. If you need to add the outer list here, it means that an incorrect input_signature was passed in the first place.
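As an illustration of that convention (an assumed example with arbitrary shapes and dtypes, not taken from the PR), there is one outer list entry per positional argument:

```python
import tensorflow as tf

# Single dict input: a list containing one dict of TensorSpecs.
single_dict_signature = [
    {
        "foo": tf.TensorSpec(shape=(None,), dtype=tf.float32, name="foo"),
        "bar": tf.TensorSpec(shape=(None,), dtype=tf.float32, name="bar"),
    }
]

# Two positional dict inputs (as in the tests above): two list entries.
two_dict_signature = [
    {
        "foo": tf.TensorSpec(shape=(None,), dtype=tf.float32, name="foo"),
        "bar": tf.TensorSpec(shape=(None,), dtype=tf.float32, name="bar"),
    },
    {"baz": tf.TensorSpec(shape=(None,), dtype=tf.float32, name="baz")},
]
```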
decorated_fn = tf.function(
    fn, input_signature=input_signature, autograph=False
)
if isinstance(input_signature, dict):
Same comment about why this shouldn't be needed. input_signature should always be a list because the model can have multiple positional arguments. If the inputs are a dict, input_signature should be a list of one dict already. If you need to add the outer list here, it means that an incorrect input_signature was passed in the first place.
if isinstance(input_signature, dict):
    # Create a simplified wrapper that handles both dict and
    # positional args.
    def wrapped_fn(arg, **kwargs):
Same comment about why this is needed and dropping all but the first positional argument.
    tensor_spec = x
else:
    return x
if isinstance(x, dict):
make_tf_tensor_spec is always called (and should always be called) within keras.tree.map_structure, so you'll never handle dicts here. Remove this case.
It is indeed not covered, showing it's not used:
https://app.codecov.io/gh/keras-team/keras/pull/20842?src=pr&el=tree&filepath=keras%2Fsrc%2Fexport%2Fexport_utils.py&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#4690cab8cef00d17f2a67990022886ff-R103
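A standalone sketch (not code from the PR) of why the mapped function only ever sees leaves:

```python
import keras

# map_structure applies the function to each leaf of the nested structure,
# so the mapped function receives the float values, never the enclosing dict.
specs = keras.tree.map_structure(
    lambda leaf: f"spec({leaf})",
    {"foo": 1.0, "bar": 2.0},
)
print(specs)  # {'foo': 'spec(1.0)', 'bar': 'spec(2.0)'}
```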
@@ -399,6 +413,15 @@ def quantize(self, mode, **kwargs):
    def build_from_config(self, config):
        if not config:
            return
        # Fetch the input structure from config if available.
Lines 416-424 are all dead code. "input_names" is never in the config because nothing adds it there. Remove.
@@ -204,6 +213,11 @@ def get_layer(self, name=None, index=None):
            return self.layers[index]

        if name is not None:
            # Check if the name matches any of the input names.
I don't think this is needed, the fallback should work for this use case.
Also, coverage shows it's never used:
https://app.codecov.io/gh/keras-team/keras/pull/20842?src=pr&el=tree&filepath=keras%2Fsrc%2Fmodels%2Fmodel.py&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#b343e88c9bde68f6f2b9d00f13fa225e-R218
@@ -156,6 +156,15 @@ def __init__(self, *args, **kwargs):
            functional.Functional.__init__(self, *args, **kwargs)
        else:
            Layer.__init__(self, *args, **kwargs)
            if args:
                inputs = args[0]
This assumes there is only one input argument. That's true for Functional models, but it is not necessarily true for subclassed models.
@@ -156,6 +156,15 @@ def __init__(self, *args, **kwargs):
            functional.Functional.__init__(self, *args, **kwargs)
        else:
            Layer.__init__(self, *args, **kwargs)
            if args:
                inputs = args[0]
                # Only set _input_names if not already initialized by Functional.
You're in an else branch, so you know for sure it's not a Functional model if you're here.
if isinstance(input_signature, list) and len(input_signature) > 1:
    input_signature = [input_signature]
if hasattr(model, "_input_names") and model._input_names:
    if isinstance(model._inputs_struct, dict):
Why would _inputs_struct not have the right order?
This is lovely! Thanks for sharing your insights, @hertschuh. I believe I was missing the fact that I could make much better use of the existing _inputs_struct.
This fix goes beyond the requirements of the issue and adds support for handling Keras models with dictionary-based inputs, particularly when exporting to the TF SavedModel format for both the TensorFlow and JAX backends. Previously, models with dictionary inputs would fail during export with ValueErrors related to input structure mismatches.
Key changes:
- Handles model._input_names for dictionary-based inputs in the Functional and Model classes
This PR aims to fix #20835, where models with dictionary inputs would fail to export properly to SavedModel format.
Example of fixed functionality:
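The original example did not survive the page capture; as a stand-in, here is a minimal sketch of the kind of usage this PR targets (based on the linked issue, not the author's original snippet; the export path is arbitrary):

```python
import tensorflow as tf
from keras import layers, models

# Functional model whose single positional input is a dict of tensors.
inputs = {
    "foo": layers.Input(shape=(1,), name="foo"),
    "bar": layers.Input(shape=(1,), name="bar"),
}
outputs = layers.Add()([inputs["foo"], inputs["bar"]])
model = models.Model(inputs, outputs)

# Before this fix, exporting a dict-input model could raise a ValueError
# about input structure mismatches; with the fix it should export cleanly.
model.export("/tmp/exported_model", format="tf_saved_model")

# The exported artifact is served through the standard SavedModel API.
revived = tf.saved_model.load("/tmp/exported_model")
print(revived.serve({"foo": tf.constant([[1.0]]), "bar": tf.constant([[2.0]])}))
```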