Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(export+models): Enhance support for dictionary-based model input signatures in TensorFlow and JAX #20842

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 34 additions & 10 deletions keras/src/backend/jax/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,28 +50,52 @@ def add_endpoint(self, name, fn, input_signature=None, **kwargs):
jax2tf_kwargs["native_serialization"] = (
self._check_device_compatible()
)

# When input_signature is a dict, we need to
# adjust polymorphic shapes.
if "polymorphic_shapes" not in jax2tf_kwargs:
jax2tf_kwargs["polymorphic_shapes"] = self._to_polymorphic_shape(
input_signature
)
if isinstance(input_signature, dict):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This shouldn't be needed. input_signature should always be a list because the model can have multiple positional arguments. If the inputs are a dict, input_signature should be a list of one dict already. If you need to add the outer list here, it means that an incorrect input_signature was passed in the first place.

# Wrap the shape in a list to match the input structure
jax2tf_kwargs["polymorphic_shapes"] = [
self._to_polymorphic_shape(input_signature)
]
else:
jax2tf_kwargs["polymorphic_shapes"] = (
self._to_polymorphic_shape(input_signature)
)

# Note: we truncate the number of parameters to what is specified by
# `input_signature`.
fn_signature = inspect.signature(fn)
fn_parameters = list(fn_signature.parameters.values())

if isinstance(input_signature, dict):
# Create a simplified wrapper that handles both dict and
# positional args, similar to TensorFlow implementation.
def wrapped_fn(arg, **kwargs):
return fn(arg)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not following

  • why this is needed
  • what tensorflow implementation it follows

All it does is keep the first position argument and drop everything else. What if you have multiple positional arguments?


target_fn = wrapped_fn
target_signature = [input_signature]
target_params = [fn_parameters[0]]
else:
# Original code path for non-dict input signatures.
target_fn = fn
target_signature = input_signature
target_params = fn_parameters[0 : len(input_signature)]

if is_static:
from jax.experimental import jax2tf

jax_fn = jax2tf.convert(fn, **jax2tf_kwargs)
jax_fn = jax2tf.convert(target_fn, **jax2tf_kwargs)
jax_fn.__signature__ = inspect.Signature(
parameters=fn_parameters[0 : len(input_signature)],
parameters=target_params,
return_annotation=fn_signature.return_annotation,
)

decorated_fn = tf.function(
jax_fn,
input_signature=input_signature,
input_signature=target_signature,
autograph=False,
)
else:
Expand All @@ -85,7 +109,7 @@ def add_endpoint(self, name, fn, input_signature=None, **kwargs):
def stateless_fn(variables, *args, **kwargs):
state_mapping = zip(self._backend_variables, variables)
with StatelessScope(state_mapping=state_mapping) as scope:
output = fn(*args, **kwargs)
output = target_fn(*args, **kwargs)

# Gather updated non-trainable variables
non_trainable_variables = []
Expand All @@ -95,7 +119,7 @@ def stateless_fn(variables, *args, **kwargs):
return output, non_trainable_variables

jax2tf_stateless_fn = self._convert_jax2tf_function(
stateless_fn, input_signature, jax2tf_kwargs=jax2tf_kwargs
stateless_fn, target_signature, jax2tf_kwargs=jax2tf_kwargs
)

def stateful_fn(*args, **kwargs):
Expand All @@ -113,13 +137,13 @@ def stateful_fn(*args, **kwargs):
return output

stateful_fn.__signature__ = inspect.Signature(
parameters=fn_parameters[0 : len(input_signature)],
parameters=target_params,
return_annotation=fn_signature.return_annotation,
)

decorated_fn = tf.function(
stateful_fn,
input_signature=input_signature,
input_signature=target_signature,
autograph=False,
)
return decorated_fn
Expand Down
18 changes: 15 additions & 3 deletions keras/src/backend/tensorflow/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,19 @@ def _track_layer(self, layer):
self._tf_trackable.non_trainable_variables += non_trainable_variables

def add_endpoint(self, name, fn, input_signature=None, **kwargs):
decorated_fn = tf.function(
fn, input_signature=input_signature, autograph=False
)
if isinstance(input_signature, dict):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment about why this shouldn't be needed. input_signature should always be a list because the model can have multiple positional arguments. If the inputs are a dict, input_signature should be a list of one dict already. If you need to add the outer list here, it means that an incorrect input_signature was passed in the first place.

# Create a simplified wrapper that handles both dict and
# positional args.
def wrapped_fn(arg, **kwargs):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment about why this is needed and dropping all but the first positional argument.

return fn(arg)

decorated_fn = tf.function(
wrapped_fn,
input_signature=[input_signature],
autograph=False,
)
else:
decorated_fn = tf.function(
fn, input_signature=input_signature, autograph=False
)
return decorated_fn
49 changes: 41 additions & 8 deletions keras/src/export/export_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,20 @@ def get_input_signature(model):
"before export."
)
if isinstance(model, (models.Functional, models.Sequential)):
input_signature = tree.map_structure(make_input_spec, model.inputs)
if isinstance(input_signature, list) and len(input_signature) > 1:
input_signature = [input_signature]
if hasattr(model, "_input_names") and model._input_names:
if isinstance(model._inputs_struct, dict):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why would _inputs_struct not have the right order?

# Create dictionary input signature while
# preserving order.
input_signature = {
name: make_input_spec(tensor)
for name, tensor in zip(model._input_names, model.inputs)
}
else:
input_signature = tree.map_structure(
make_input_spec, model.inputs
)
else:
input_signature = tree.map_structure(make_input_spec, model.inputs)
else:
input_signature = _infer_input_signature_from_model(model)
if not input_signature or not model._called:
Expand Down Expand Up @@ -86,13 +97,35 @@ def make_input_spec(x):

def make_tf_tensor_spec(x):
if isinstance(x, tf.TensorSpec):
tensor_spec = x
else:
return x
if isinstance(x, dict):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make_tf_tensor_spec is always called (and should always be called) within keras.tree.map_structure, so you'll never handles dicts here. Remove this case.

It is indeed not covered, showing it's not used:
https://app.codecov.io/gh/keras-team/keras/pull/20842?src=pr&el=tree&filepath=keras%2Fsrc%2Fexport%2Fexport_utils.py&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#4690cab8cef00d17f2a67990022886ff-R103

# Build a dict of TensorSpecs, preserving each entry's name.
return {
name: tf.TensorSpec(
shape=input_spec.shape,
dtype=input_spec.dtype,
name=name,
)
for name, spec in x.items()
for input_spec in [make_input_spec(spec)]
}
elif isinstance(x, layers.InputSpec):
input_spec = make_input_spec(x)
tensor_spec = tf.TensorSpec(
input_spec.shape, dtype=input_spec.dtype, name=input_spec.name
return tf.TensorSpec(
shape=input_spec.shape, dtype=input_spec.dtype, name=input_spec.name
)
else:
if hasattr(x, "shape") and hasattr(x, "dtype"):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

turn else: if into an elif

input_spec = make_input_spec(x)
return tf.TensorSpec(
shape=input_spec.shape,
dtype=input_spec.dtype,
name=getattr(input_spec, "name", None),
)
raise TypeError(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add an else here to make the if / elif / else flow more obvious.

f"Unsupported x={x} of the type ({type(x)}). Supported types are: "
"`keras.InputSpec`, `keras.KerasTensor` and backend tensor."
)
return tensor_spec


def convert_spec_to_tensor(spec, replace_none_number=None):
Expand Down
15 changes: 15 additions & 0 deletions keras/src/export/saved_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,6 +1003,21 @@ def test_model_export_method(self, model_type):
# Test with a different batch size
revived_model.serve(tf.random.normal((6, 10)))

def test_export_with_dict_input(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
inputs = {
"foo": layers.Input(shape=()),
"bar": layers.Input(shape=()),
}
outputs = layers.Add()([inputs["foo"], inputs["bar"]])
model = models.Model(inputs, outputs)
ref_input = {"foo": tf.constant([1.0]), "bar": tf.constant([2.0])}
ref_output = model(ref_input)
model.export(temp_filepath, format="tf_saved_model")
revived_model = tf.saved_model.load(temp_filepath)
revived_output = revived_model.serve(ref_input)
self.assertAllClose(ref_output, revived_output)

def test_model_combined_with_tf_preprocessing(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")

Expand Down
144 changes: 112 additions & 32 deletions keras/src/models/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ def __new__(cls, *args, **kwargs):
@tracking.no_automatic_dependency_tracking
def __init__(self, inputs, outputs, name=None, **kwargs):
if isinstance(inputs, dict):
# This implementation relies on the insertion-order guarantee of
# dictionary keys, which is part of the language specification
# from Python 3.7 onward.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm confused about this comment Assumption from Python < 3.7. I thought we were assuming exactly the opposite.

self._input_names = list(inputs.keys())
self._inputs_struct = inputs
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is already done by Function.__init__ line 144. In fact Function.__init__ will overwrite what you set here.

for k, v in inputs.items():
if isinstance(v, backend.KerasTensor) and k != v.name:
warnings.warn(
Expand All @@ -111,6 +116,9 @@ def __init__(self, inputs, outputs, name=None, **kwargs):
f"which has name '{v.name}'. Change the tensor name to "
f"'{k}' (via `Input(..., name='{k}')`)"
)
else:
self._input_names = None
self._inputs_struct = inputs
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is already done by Function.__init__ line 144. In fact Function.__init__ will overwrite what you set here.


trainable = kwargs.pop("trainable", None)
flat_inputs = tree.flatten(inputs)
Expand Down Expand Up @@ -267,7 +275,8 @@ def _adjust_input_rank(self, flat_inputs):
adjusted.append(ops.squeeze(x, axis=-1))
continue
if x_rank == ref_rank - 1:
if ref_shape[-1] == 1:
# Check if ref_shape's last dimension is None (variable) or 1.
if ref_shape[-1] is None or ref_shape[-1] == 1:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why was this change needed? It seems risky to assume that None will actually be 1 once concrete tensors are passed.

adjusted.append(ops.expand_dims(x, axis=-1))
continue
raise ValueError(
Expand All @@ -284,42 +293,113 @@ def _adjust_input_rank(self, flat_inputs):
return adjusted

def _standardize_inputs(self, inputs):
raise_exception = False
if isinstance(inputs, dict) and not isinstance(
self._inputs_struct, dict
if inputs is None:
raise ValueError(
"No inputs provided to the model (inputs is None)."
)

if self._inputs_struct is None:
warnings.warn(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead, this should be a raise an error in __init__ instead to make sure this situation never happens.

"Model's input structure (_inputs_struct) is None."
" This may lead to unexpected behavior.",
UserWarning,
)

# Additional check: if both inputs and _inputs_struct are
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lines 308 to 343 duplicates what keras.tree.assert_same_structure() does. Except that this code is not recursive (inputs can be nested) and keras.tree.assert_same_structure() is. So we should use it as much as possible.

# lists/tuples, warn if lengths differ.
if isinstance(inputs, (list, tuple)) and isinstance(
self._inputs_struct, (list, tuple)
):
# This is to avoid warning when we have
# reconcilable dict/list structs
if hasattr(self._inputs_struct, "__len__") and all(
isinstance(i, backend.KerasTensor) for i in self._inputs_struct
if len(inputs) != len(self._inputs_struct):
warnings.warn(
f"Number of inputs ({len(inputs)}) does not match"
f" the expected number ({len(self._inputs_struct)})"
" based on self._inputs_struct. This may lead to unexpected"
" behavior.",
UserWarning,
)

if isinstance(inputs, dict) and isinstance(self._inputs_struct, dict):
model_keys = set(self._inputs_struct.keys())
input_keys = set(inputs.keys())
missing = model_keys - input_keys
if missing:
raise ValueError(
f"Input keys don't match model input keys. "
f"Model expects: {list(self._inputs_struct.keys())}, "
f"but got: {list(inputs.keys())}"
)
extra = input_keys - model_keys
if extra:
self._maybe_warn_inputs_struct_mismatch(
inputs, raise_exception=False
)
filtered_inputs = {k: inputs[k] for k in model_keys}
converted_inputs = tree.map_structure(
ops.convert_to_tensor, filtered_inputs
)
# Flatten according to model's input structure to process ranks.
flat_inputs = tree.flatten(converted_inputs)
inputs = flat_inputs
else:
# Revert back to old logic for non-dict inputs.
raise_exception = False
if isinstance(inputs, dict) and not isinstance(
self._inputs_struct, dict
):
expected_keys = set(i.name for i in self._inputs_struct)
keys = set(inputs.keys())
if expected_keys.issubset(keys):
inputs = [inputs[i.name] for i in self._inputs_struct]
else:
raise_exception = True
elif isinstance(self._inputs_struct, backend.KerasTensor):
if self._inputs_struct.name in inputs:
inputs = [inputs[self._inputs_struct.name]]
if isinstance(self._inputs_struct, (list, tuple)):
all_kt = all(
isinstance(kt, backend.KerasTensor)
for kt in self._inputs_struct
)
if all_kt:
expected_names = [kt.name for kt in self._inputs_struct]
input_keys = set(inputs.keys())
if input_keys.issuperset(expected_names):
inputs = [inputs[name] for name in expected_names]
else:
missing = set(expected_names) - input_keys
raise ValueError(
f"Missing input keys: {missing}. "
f"Expected keys: {expected_names}, "
f"received keys: {list(inputs.keys())}"
)
else:
raise_exception = True
elif isinstance(self._inputs_struct, backend.KerasTensor):
if self._inputs_struct.name in inputs:
inputs = inputs[self._inputs_struct.name]
else:
raise_exception = True
else:
raise_exception = True
else:
raise_exception = True
if (
isinstance(self._inputs_struct, dict)
and not isinstance(inputs, dict)
and list(self._inputs_struct.keys())
!= sorted(self._inputs_struct.keys())
):
raise_exception = True
self._maybe_warn_inputs_struct_mismatch(
inputs, raise_exception=raise_exception
)
elif not isinstance(inputs, (list, tuple)) and isinstance(
self._inputs_struct, (list, tuple)
):
inputs = [inputs]

flat_inputs = tree.flatten(inputs)
flat_inputs = self._convert_inputs_to_tensors(flat_inputs)
return self._adjust_input_rank(flat_inputs)
self._maybe_warn_inputs_struct_mismatch(
inputs, raise_exception=raise_exception
)

if raise_exception:
raise ValueError(
f"The model's input structure doesn't match the input data."
f" Model expects {self._inputs_struct},"
f" but received {inputs}."
)

try:
inputs = tree.flatten(inputs)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I liked the use of the name flat_inputs here, it makes in clear what's going on and matches what's done in other parts of the code.

inputs = self._convert_inputs_to_tensors(inputs)
except:
raise ValueError(
f"The model's input structure doesn't match the input data."
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should never fail (unless you're passing objects that can't be converted to tensors).

And if it does fail, I think the original error message would be most useful. So no need to catch here.

Also, I believe the error message is incorrect. It's not an input structure problem if it fails.

f" Model expects {self._inputs_struct},"
f" but received {inputs}."
)

return self._adjust_input_rank(inputs)

@property
def input(self):
Expand Down
Loading