
fix(export+models): Enhance support for dictionary-based model input signatures in TensorFlow and JAX #20842

Open · wants to merge 16 commits into master

Conversation

@harshaljanjani (Contributor) commented Feb 2, 2025

This fix goes beyond the requirements of the issue and adds support for handling Keras models with dictionary-based inputs, particularly when exporting to the TFSavedModel format for both the TensorFlow and JAX backends. Previously, models with dictionary inputs would fail during export with ValueErrors related to input structure mismatches.

Key changes:

  • Added proper handling of model._input_names for dictionary-based inputs in Functional and Model classes
  • Enhanced input signature generation to properly handle dictionary input structures
  • Added test coverage for dictionary input model export
  • Added debug logging throughout the input signature handling flow

This PR aims to fix #20835 where models with dictionary inputs would fail to export properly to SavedModel format.

Example of fixed functionality:

import tensorflow as tf
from keras import layers, models

temp_filepath = "/tmp/exported_model"
inputs = {
    "foo": layers.Input(shape=()),
    "bar": layers.Input(shape=()),
}
outputs = layers.Add()([inputs["foo"], inputs["bar"]])
model = models.Model(inputs, outputs)
ref_input = {"foo": tf.constant([1.0]), "bar": tf.constant([2.0])}
ref_output = model(ref_input)
model.export(temp_filepath, format="tf_saved_model")
revived_model = tf.saved_model.load(temp_filepath)
revived_output = revived_model.serve(ref_input)


- Improves input structure validation in Model and Functional classes

- Adds strict validation with clear error messages for mismatches
@codecov-commenter commented Feb 3, 2025

Codecov Report

Attention: Patch coverage is 59.61538% with 42 lines in your changes missing coverage. Please review.

Project coverage is 82.20%. Comparing base (e045b6a) to head (72371d7).
Report is 20 commits behind head on master.

Files with missing lines | Patch % | Lines
keras/src/models/functional.py | 55.76% | 12 Missing and 11 partials ⚠️
keras/src/models/model.py | 21.05% | 10 Missing and 5 partials ⚠️
keras/src/export/export_utils.py | 71.42% | 2 Missing and 2 partials ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #20842      +/-   ##
==========================================
- Coverage   82.26%   82.20%   -0.06%     
==========================================
  Files         561      561              
  Lines       52693    53035     +342     
  Branches     8146     8228      +82     
==========================================
+ Hits        43347    43600     +253     
- Misses       7344     7391      +47     
- Partials     2002     2044      +42     
Flag | Coverage | Δ
keras | 82.02% <59.61%> | (-0.06%) ⬇️
keras-jax | 64.14% <54.80%> | (-0.11%) ⬇️
keras-numpy | 58.89% <23.07%> | (-0.10%) ⬇️
keras-openvino | 32.36% <1.92%> | (-0.19%) ⬇️
keras-tensorflow | 64.60% <46.15%> | (-0.24%) ⬇️
keras-torch | 64.16% <39.42%> | (-0.12%) ⬇️

Flags with carried forward coverage won't be shown.

- Added shape logging for preprocessed and crossed features

- Added debug messages in Functional model processing (To be removed)
@harshaljanjani (Contributor) commented Feb 3, 2025

TODO: Remove debugging prints.
Edit: Done and ready for review.

@harshaljanjani harshaljanjani marked this pull request as ready for review February 3, 2025 19:54
@harshaljanjani:
Ping: @mattdangerw @fchollet.

@jeffcarp (Member) left a comment:

Thanks for the contribution! Left some comments.

@harshaljanjani:

Thanks for the review! Will fix it at the earliest.

@harshaljanjani:

Please do look into the changes and let me know if it's the intended behavior. Looking forward to your guidance, thanks!

@harshaljanjani harshaljanjani changed the title fix(export+models): Enhance input signature handling for dictionary-based models fix(export+models): Enhance support for dictionary-based model input signatures in TensorFlow and JAX Feb 9, 2025
@harshaljanjani:

Ping: @fchollet, looking forward to an update on this PR. Thanks!

@fchollet (Collaborator):

Thanks for the update! Please resolve merge conflicts.

@jeffcarp does the PR look good?

@harshaljanjani:

@fchollet Resolved merge conflicts, thanks!

@jeffcarp (Member) left a comment:

LGTM % some nits

@google-ml-butler bot added the kokoro:force-run and ready to pull labels Feb 18, 2025
@google-ml-butler bot removed the ready to pull label Feb 19, 2025
@harshaljanjani:

@fchollet @jeffcarp Fixed! Thanks for reviewing the PR.

@jeffcarp (Member) left a comment:

Thanks!

@google-ml-butler bot added the kokoro:force-run and ready to pull labels Feb 19, 2025
@harshaljanjani commented Feb 22, 2025

@fchollet, just checking in on this PR. Since the changes are approved, is there anything else needed before merging? Looking forward to your feedback, thanks!

@hertschuh (Collaborator) left a comment:
Thank you for your work on this. I can see that a lot of effort was put into this. However, I believe this can be simplified quite a bit.

There is a lot of logic around creating, maintaining, and restoring self._input_names. However, it is only actually needed in export_utils.get_input_signature, and it's only used for one purpose: to know the order of inputs. But the order of inputs is already known, right? It's in self._inputs_struct. So you can just use self._inputs_struct in export_utils.get_input_signature and remove self._input_names.
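
The ordering argument above can be sketched with a toy stand-in (plain Python only, so it runs without Keras; `flatten` with sorted-key traversal mimics the behavior of `keras.tree.flatten`, and `inputs_struct` is a hypothetical placeholder for `self._inputs_struct`):

```python
def flatten(structure):
    """Minimal stand-in for keras.tree.flatten.

    Dicts are traversed in sorted-key order, the convention tree
    libraries use so that flattening is deterministic.
    """
    if isinstance(structure, dict):
        leaves = []
        for key in sorted(structure):
            leaves.extend(flatten(structure[key]))
        return leaves
    if isinstance(structure, (list, tuple)):
        leaves = []
        for item in structure:
            leaves.extend(flatten(item))
        return leaves
    return [structure]

# Hypothetical stand-in for a dict-based self._inputs_struct.
inputs_struct = {"foo": "tensor_foo", "bar": "tensor_bar"}

# The input order falls out of the structure itself; no separate
# _input_names bookkeeping is needed.
print(flatten(inputs_struct))  # ['tensor_bar', 'tensor_foo']
```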

Additionally, I don't think dicts constitute a special case in most cases. Layers can have 1 or more positional arguments as inputs, and each one of those is a nested structure that can have dicts at any level of nesting.
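
A rough illustration of this point, using a minimal `map_structure` stand-in (plain Python; these names are hypothetical, not the actual Keras internals):

```python
def map_structure(fn, structure):
    """Apply fn to every leaf of a nested dict/list/tuple structure."""
    if isinstance(structure, dict):
        return {k: map_structure(fn, v) for k, v in structure.items()}
    if isinstance(structure, (list, tuple)):
        return type(structure)(map_structure(fn, v) for v in structure)
    return fn(structure)

# Two positional arguments: a flat dict and a dict nested inside a tuple.
args_struct = ({"foo": 1, "bar": 2}, ({"baz": 3},))

# A per-leaf "spec" is built uniformly; no dict special-casing anywhere.
specs = map_structure(lambda leaf: f"spec({leaf})", args_struct)
print(specs)  # ({'foo': 'spec(1)', 'bar': 'spec(2)'}, ({'baz': 'spec(3)'},))
```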

Out of curiosity, I added the following 2 tests in saved_model_test.py, and they don't pass:

    def test_export_with_two_dict_inputs_functional(self):
        temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
        inputs1 = {
            "foo": layers.Input(shape=()),
            "bar": layers.Input(shape=()),
        }
        inputs2 = {
            "baz": layers.Input(shape=()),
        }
        outputs = layers.Add()([inputs1["foo"], inputs1["bar"], inputs2["baz"]])
        model = models.Model((inputs1, inputs2), outputs)
        ref_input = (
            {"foo": tf.constant([1.0]), "bar": tf.constant([2.0])},
            {"baz": tf.constant([2.0])},
        )
        ref_output = model(ref_input)
        model.export(temp_filepath, format="tf_saved_model")
        revived_model = tf.saved_model.load(temp_filepath)
        revived_output = revived_model.serve(ref_input)
        self.assertAllClose(ref_output, revived_output)

    def test_export_with_two_dict_inputs_subclass(self):
        class TwoDictModel(models.Model):
            def call(self, inputs1, inputs2):
                return ops.add(
                    ops.add(inputs1["foo"], inputs1["bar"]), inputs2["baz"]
                )

        temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
        model = TwoDictModel()
        ref_input1 = {"foo": tf.constant([1.0]), "bar": tf.constant([2.0])}
        ref_input2 = {"baz": tf.constant([2.0])}
        ref_output = model(ref_input1, ref_input2)
        model.export(temp_filepath, format="tf_saved_model")
        revived_model = tf.saved_model.load(temp_filepath)
        revived_output = revived_model.serve(ref_input1, ref_input2)
        self.assertAllClose(ref_output, revived_output)

This shows that dicts as second arguments are not supported.

# Create a simplified wrapper that handles both dict and
# positional args, similar to TensorFlow implementation.
def wrapped_fn(arg, **kwargs):
    return fn(arg)
@hertschuh:

I'm not following

  • why this is needed
  • what tensorflow implementation it follows

All it does is keep the first positional argument and drop everything else. What if you have multiple positional arguments?
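
The concern can be demonstrated with a toy `fn` (hypothetical, unrelated to the actual Keras internals): a `wrapped_fn(arg, **kwargs)` wrapper only ever forwards the first positional argument, whereas forwarding `*args` preserves all of them:

```python
def fn(*args):
    """Hypothetical target function: sums its positional arguments."""
    return sum(args)

def wrapped_fn_lossy(arg, **kwargs):
    # The wrapper under review: only one positional argument is accepted
    # and forwarded; anything else is lost (or raises a TypeError).
    return fn(arg)

def wrapped_fn(*args, **kwargs):
    # A forwarding wrapper that preserves all positional arguments.
    return fn(*args)

print(wrapped_fn_lossy(1))  # 1
print(wrapped_fn(1, 2, 3))  # 6
```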

jax2tf_kwargs["polymorphic_shapes"] = self._to_polymorphic_shape(
    input_signature
)
if isinstance(input_signature, dict):
@hertschuh:

This shouldn't be needed. input_signature should always be a list because the model can have multiple positional arguments. If the inputs are a dict, input_signature should be a list of one dict already. If you need to add the outer list here, it means that an incorrect input_signature was passed in the first place.
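
A sketch of the shapes described above, with strings standing in for `tf.TensorSpec` objects (illustrative only, not actual export code):

```python
# input_signature has one entry per positional argument of the model.

# Model called as model({"foo": ..., "bar": ...}): a list of ONE dict,
# never a bare dict.
sig_one_dict_arg = [{"foo": "TensorSpec(foo)", "bar": "TensorSpec(bar)"}]

# Model called as model(inputs1, inputs2): a list of TWO entries.
sig_two_dict_args = [
    {"foo": "TensorSpec(foo)", "bar": "TensorSpec(bar)"},
    {"baz": "TensorSpec(baz)"},
]

# In both cases the outer container is a list.
assert isinstance(sig_one_dict_arg, list) and len(sig_one_dict_arg) == 1
assert isinstance(sig_two_dict_args, list) and len(sig_two_dict_args) == 2
```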

decorated_fn = tf.function(
    fn, input_signature=input_signature, autograph=False
)
if isinstance(input_signature, dict):
@hertschuh:

Same comment about why this shouldn't be needed. input_signature should always be a list because the model can have multiple positional arguments. If the inputs are a dict, input_signature should be a list of one dict already. If you need to add the outer list here, it means that an incorrect input_signature was passed in the first place.

if isinstance(input_signature, dict):
    # Create a simplified wrapper that handles both dict and
    # positional args.
    def wrapped_fn(arg, **kwargs):
@hertschuh:

Same comment about why this is needed and dropping all but the first positional argument.

        tensor_spec = x
    else:
        return x
    if isinstance(x, dict):
@hertschuh:

make_tf_tensor_spec is always called (and should always be called) within keras.tree.map_structure, so you'll never handle dicts here. Remove this case.

It is indeed not covered, showing it's not used:
https://app.codecov.io/gh/keras-team/keras/pull/20842?src=pr&el=tree&filepath=keras%2Fsrc%2Fexport%2Fexport_utils.py&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#4690cab8cef00d17f2a67990022886ff-R103
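
A toy demonstration of why the dict branch is unreachable, using a minimal `map_structure` stand-in for `keras.tree.map_structure` (hypothetical names; only the structure-traversal behavior is assumed):

```python
def map_structure(fn, structure):
    """Minimal stand-in for keras.tree.map_structure: fn sees only leaves."""
    if isinstance(structure, dict):
        return {k: map_structure(fn, v) for k, v in structure.items()}
    if isinstance(structure, (list, tuple)):
        return type(structure)(map_structure(fn, v) for v in structure)
    return fn(structure)

seen = []

def make_spec(x):
    # Record what the leaf function actually receives.
    seen.append(type(x).__name__)
    return f"spec({x})"

map_structure(make_spec, {"a": 1, "b": [2, 3]})
print(seen)  # ['int', 'int', 'int'] -- a dict never reaches make_spec
```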

@@ -399,6 +413,15 @@ def quantize(self, mode, **kwargs):
    def build_from_config(self, config):
        if not config:
            return
        # Fetch the input structure from config if available.

@@ -204,6 +213,11 @@ def get_layer(self, name=None, index=None):
        return self.layers[index]

    if name is not None:
        # Check if the name matches any of the input names.

@@ -156,6 +156,15 @@ def __init__(self, *args, **kwargs):
        functional.Functional.__init__(self, *args, **kwargs)
    else:
        Layer.__init__(self, *args, **kwargs)
        if args:
            inputs = args[0]
@hertschuh:

This assumes there is only one input argument. That's true for Functional models, but not necessarily true for subclassed models.

@@ -156,6 +156,15 @@ def __init__(self, *args, **kwargs):
        functional.Functional.__init__(self, *args, **kwargs)
    else:
        Layer.__init__(self, *args, **kwargs)
        if args:
            inputs = args[0]
            # Only set _input_names if not already initialized by Functional.
@hertschuh:

You're in an else, you know for sure it's not a Functional model if you're here.

if isinstance(input_signature, list) and len(input_signature) > 1:
    input_signature = [input_signature]
if hasattr(model, "_input_names") and model._input_names:
    if isinstance(model._inputs_struct, dict):
@hertschuh:

Why would _inputs_struct not have the right order?

@harshaljanjani:

This is lovely! Thanks for sharing your insights, @hertschuh. I believe I was missing the fact that I could make much better use of the existing _inputs_struct. I'll try to understand the problem better based on your comments and make the necessary changes!

Labels: awaiting review, ready to pull (Ready to be merged into the codebase), size:M
Projects: none yet

Successfully merging this pull request may close: ValueErrors when calling Model.export() for TF SavedModel format on Keras Models with dict inputs
7 participants