-
Notifications
You must be signed in to change notification settings - Fork 19.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix(export+models): Enhance support for dictionary-based model input signatures in TensorFlow and JAX #20842
base: master
Are you sure you want to change the base?
Changes from all commits
5a84fc5
d903b41
9e38bd8
1d969c3
299b4ba
0c65c02
52d7fb3
18adae3
5178e44
5de08b4
b74a295
4ed5607
0659804
c32c02c
9e7f3c8
72371d7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -50,28 +50,52 @@ def add_endpoint(self, name, fn, input_signature=None, **kwargs): | |
jax2tf_kwargs["native_serialization"] = ( | ||
self._check_device_compatible() | ||
) | ||
|
||
# When input_signature is a dict, we need to | ||
# adjust polymorphic shapes. | ||
if "polymorphic_shapes" not in jax2tf_kwargs: | ||
jax2tf_kwargs["polymorphic_shapes"] = self._to_polymorphic_shape( | ||
input_signature | ||
) | ||
if isinstance(input_signature, dict): | ||
# Wrap the shape in a list to match the input structure | ||
jax2tf_kwargs["polymorphic_shapes"] = [ | ||
self._to_polymorphic_shape(input_signature) | ||
] | ||
else: | ||
jax2tf_kwargs["polymorphic_shapes"] = ( | ||
self._to_polymorphic_shape(input_signature) | ||
) | ||
|
||
# Note: we truncate the number of parameters to what is specified by | ||
# `input_signature`. | ||
fn_signature = inspect.signature(fn) | ||
fn_parameters = list(fn_signature.parameters.values()) | ||
|
||
if isinstance(input_signature, dict): | ||
# Create a simplified wrapper that handles both dict and | ||
# positional args, similar to TensorFlow implementation. | ||
def wrapped_fn(arg, **kwargs): | ||
return fn(arg) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not following
All it does is keep the first position argument and drop everything else. What if you have multiple positional arguments? |
||
|
||
target_fn = wrapped_fn | ||
target_signature = [input_signature] | ||
target_params = [fn_parameters[0]] | ||
else: | ||
# Original code path for non-dict input signatures. | ||
target_fn = fn | ||
target_signature = input_signature | ||
target_params = fn_parameters[0 : len(input_signature)] | ||
|
||
if is_static: | ||
from jax.experimental import jax2tf | ||
|
||
jax_fn = jax2tf.convert(fn, **jax2tf_kwargs) | ||
jax_fn = jax2tf.convert(target_fn, **jax2tf_kwargs) | ||
jax_fn.__signature__ = inspect.Signature( | ||
parameters=fn_parameters[0 : len(input_signature)], | ||
parameters=target_params, | ||
return_annotation=fn_signature.return_annotation, | ||
) | ||
|
||
decorated_fn = tf.function( | ||
jax_fn, | ||
input_signature=input_signature, | ||
input_signature=target_signature, | ||
autograph=False, | ||
) | ||
else: | ||
|
@@ -85,7 +109,7 @@ def add_endpoint(self, name, fn, input_signature=None, **kwargs): | |
def stateless_fn(variables, *args, **kwargs): | ||
state_mapping = zip(self._backend_variables, variables) | ||
with StatelessScope(state_mapping=state_mapping) as scope: | ||
output = fn(*args, **kwargs) | ||
output = target_fn(*args, **kwargs) | ||
|
||
# Gather updated non-trainable variables | ||
non_trainable_variables = [] | ||
|
@@ -95,7 +119,7 @@ def stateless_fn(variables, *args, **kwargs): | |
return output, non_trainable_variables | ||
|
||
jax2tf_stateless_fn = self._convert_jax2tf_function( | ||
stateless_fn, input_signature, jax2tf_kwargs=jax2tf_kwargs | ||
stateless_fn, target_signature, jax2tf_kwargs=jax2tf_kwargs | ||
) | ||
|
||
def stateful_fn(*args, **kwargs): | ||
|
@@ -113,13 +137,13 @@ def stateful_fn(*args, **kwargs): | |
return output | ||
|
||
stateful_fn.__signature__ = inspect.Signature( | ||
parameters=fn_parameters[0 : len(input_signature)], | ||
parameters=target_params, | ||
return_annotation=fn_signature.return_annotation, | ||
) | ||
|
||
decorated_fn = tf.function( | ||
stateful_fn, | ||
input_signature=input_signature, | ||
input_signature=target_signature, | ||
autograph=False, | ||
) | ||
return decorated_fn | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,7 +13,19 @@ def _track_layer(self, layer): | |
self._tf_trackable.non_trainable_variables += non_trainable_variables | ||
|
||
def add_endpoint(self, name, fn, input_signature=None, **kwargs): | ||
decorated_fn = tf.function( | ||
fn, input_signature=input_signature, autograph=False | ||
) | ||
if isinstance(input_signature, dict): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same comment about why this shouldn't be needed. |
||
# Create a simplified wrapper that handles both dict and | ||
# positional args. | ||
def wrapped_fn(arg, **kwargs): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same comment about why this is needed and dropping all but the first positional argument. |
||
return fn(arg) | ||
|
||
decorated_fn = tf.function( | ||
wrapped_fn, | ||
input_signature=[input_signature], | ||
autograph=False, | ||
) | ||
else: | ||
decorated_fn = tf.function( | ||
fn, input_signature=input_signature, autograph=False | ||
) | ||
return decorated_fn |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,9 +18,20 @@ def get_input_signature(model): | |
"before export." | ||
) | ||
if isinstance(model, (models.Functional, models.Sequential)): | ||
input_signature = tree.map_structure(make_input_spec, model.inputs) | ||
if isinstance(input_signature, list) and len(input_signature) > 1: | ||
input_signature = [input_signature] | ||
if hasattr(model, "_input_names") and model._input_names: | ||
if isinstance(model._inputs_struct, dict): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why would |
||
# Create dictionary input signature while | ||
# preserving order. | ||
input_signature = { | ||
name: make_input_spec(tensor) | ||
for name, tensor in zip(model._input_names, model.inputs) | ||
} | ||
else: | ||
input_signature = tree.map_structure( | ||
make_input_spec, model.inputs | ||
) | ||
else: | ||
input_signature = tree.map_structure(make_input_spec, model.inputs) | ||
else: | ||
input_signature = _infer_input_signature_from_model(model) | ||
if not input_signature or not model._called: | ||
|
@@ -86,13 +97,35 @@ def make_input_spec(x): | |
|
||
def make_tf_tensor_spec(x): | ||
if isinstance(x, tf.TensorSpec): | ||
tensor_spec = x | ||
else: | ||
return x | ||
if isinstance(x, dict): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
It is indeed not covered, showing it's not used: |
||
# Convert each dict entry to a TensorSpec, preserving names. | ||
return { | ||
name: tf.TensorSpec( | ||
shape=input_spec.shape, | ||
dtype=input_spec.dtype, | ||
name=name, | ||
) | ||
for name, spec in x.items() | ||
for input_spec in [make_input_spec(spec)] | ||
} | ||
elif isinstance(x, layers.InputSpec): | ||
input_spec = make_input_spec(x) | ||
tensor_spec = tf.TensorSpec( | ||
input_spec.shape, dtype=input_spec.dtype, name=input_spec.name | ||
return tf.TensorSpec( | ||
shape=input_spec.shape, dtype=input_spec.dtype, name=input_spec.name | ||
) | ||
else: | ||
if hasattr(x, "shape") and hasattr(x, "dtype"): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. turn |
||
input_spec = make_input_spec(x) | ||
return tf.TensorSpec( | ||
shape=input_spec.shape, | ||
dtype=input_spec.dtype, | ||
name=getattr(input_spec, "name", None), | ||
) | ||
raise TypeError( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add an |
||
f"Unsupported x={x} of the type ({type(x)}). Supported types are: " | ||
"`keras.InputSpec`, `keras.KerasTensor` and backend tensor." | ||
) | ||
return tensor_spec | ||
|
||
|
||
def convert_spec_to_tensor(spec, replace_none_number=None): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -102,6 +102,11 @@ def __new__(cls, *args, **kwargs): | |
@tracking.no_automatic_dependency_tracking | ||
def __init__(self, inputs, outputs, name=None, **kwargs): | ||
if isinstance(inputs, dict): | ||
# This implementation relies on the deterministic insertion | ||
# order of dictionary keys, which is guaranteed in | ||
# Python 3.7 and later. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm confused about this comment |
||
self._input_names = list(inputs.keys()) | ||
harshaljanjani marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self._inputs_struct = inputs | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is already done by |
||
for k, v in inputs.items(): | ||
if isinstance(v, backend.KerasTensor) and k != v.name: | ||
warnings.warn( | ||
|
@@ -111,6 +116,9 @@ def __init__(self, inputs, outputs, name=None, **kwargs): | |
f"which has name '{v.name}'. Change the tensor name to " | ||
f"'{k}' (via `Input(..., name='{k}')`)" | ||
) | ||
else: | ||
self._input_names = None | ||
self._inputs_struct = inputs | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is already done by |
||
|
||
trainable = kwargs.pop("trainable", None) | ||
flat_inputs = tree.flatten(inputs) | ||
|
@@ -267,7 +275,8 @@ def _adjust_input_rank(self, flat_inputs): | |
adjusted.append(ops.squeeze(x, axis=-1)) | ||
continue | ||
if x_rank == ref_rank - 1: | ||
if ref_shape[-1] == 1: | ||
# Check if ref_shape's last dimension is None (variable) or 1. | ||
if ref_shape[-1] is None or ref_shape[-1] == 1: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why was this change needed? It seems risky to assume that |
||
adjusted.append(ops.expand_dims(x, axis=-1)) | ||
continue | ||
raise ValueError( | ||
|
@@ -284,42 +293,113 @@ def _adjust_input_rank(self, flat_inputs): | |
return adjusted | ||
|
||
def _standardize_inputs(self, inputs): | ||
raise_exception = False | ||
if isinstance(inputs, dict) and not isinstance( | ||
self._inputs_struct, dict | ||
if inputs is None: | ||
raise ValueError( | ||
"No inputs provided to the model (inputs is None)." | ||
) | ||
|
||
if self._inputs_struct is None: | ||
warnings.warn( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Instead, this should be a |
||
"Model's input structure (_inputs_struct) is None." | ||
" This may lead to unexpected behavior.", | ||
UserWarning, | ||
) | ||
|
||
# Additional check: if both inputs and _inputs_struct are | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Lines 308 to 343 duplicates what |
||
# lists/tuples, warn if lengths differ. | ||
if isinstance(inputs, (list, tuple)) and isinstance( | ||
self._inputs_struct, (list, tuple) | ||
): | ||
# This is to avoid warning | ||
# when we have reconcilable dict/list structs | ||
if hasattr(self._inputs_struct, "__len__") and all( | ||
isinstance(i, backend.KerasTensor) for i in self._inputs_struct | ||
if len(inputs) != len(self._inputs_struct): | ||
warnings.warn( | ||
f"Number of inputs ({len(inputs)}) does not match" | ||
f" the expected number ({len(self._inputs_struct)})" | ||
" based on self._inputs_struct. This may lead to unexpected" | ||
" behavior.", | ||
UserWarning, | ||
) | ||
|
||
if isinstance(inputs, dict) and isinstance(self._inputs_struct, dict): | ||
model_keys = set(self._inputs_struct.keys()) | ||
input_keys = set(inputs.keys()) | ||
missing = model_keys - input_keys | ||
if missing: | ||
raise ValueError( | ||
f"Input keys don't match model input keys. " | ||
f"Model expects: {list(self._inputs_struct.keys())}, " | ||
f"but got: {list(inputs.keys())}" | ||
) | ||
extra = input_keys - model_keys | ||
if extra: | ||
self._maybe_warn_inputs_struct_mismatch( | ||
inputs, raise_exception=False | ||
) | ||
filtered_inputs = {k: inputs[k] for k in model_keys} | ||
converted_inputs = tree.map_structure( | ||
ops.convert_to_tensor, filtered_inputs | ||
) | ||
# Flatten according to model's input structure to process ranks. | ||
flat_inputs = tree.flatten(converted_inputs) | ||
inputs = flat_inputs | ||
else: | ||
# Revert back to old logic for non-dict inputs. | ||
raise_exception = False | ||
if isinstance(inputs, dict) and not isinstance( | ||
self._inputs_struct, dict | ||
): | ||
expected_keys = set(i.name for i in self._inputs_struct) | ||
keys = set(inputs.keys()) | ||
if expected_keys.issubset(keys): | ||
inputs = [inputs[i.name] for i in self._inputs_struct] | ||
else: | ||
raise_exception = True | ||
elif isinstance(self._inputs_struct, backend.KerasTensor): | ||
if self._inputs_struct.name in inputs: | ||
inputs = [inputs[self._inputs_struct.name]] | ||
if isinstance(self._inputs_struct, (list, tuple)): | ||
all_kt = all( | ||
isinstance(kt, backend.KerasTensor) | ||
for kt in self._inputs_struct | ||
) | ||
if all_kt: | ||
expected_names = [kt.name for kt in self._inputs_struct] | ||
input_keys = set(inputs.keys()) | ||
if input_keys.issuperset(expected_names): | ||
inputs = [inputs[name] for name in expected_names] | ||
else: | ||
missing = set(expected_names) - input_keys | ||
raise ValueError( | ||
f"Missing input keys: {missing}. " | ||
f"Expected keys: {expected_names}, " | ||
f"received keys: {list(inputs.keys())}" | ||
) | ||
harshaljanjani marked this conversation as resolved.
Show resolved
Hide resolved
|
||
else: | ||
raise_exception = True | ||
elif isinstance(self._inputs_struct, backend.KerasTensor): | ||
if self._inputs_struct.name in inputs: | ||
inputs = inputs[self._inputs_struct.name] | ||
else: | ||
raise_exception = True | ||
else: | ||
raise_exception = True | ||
else: | ||
raise_exception = True | ||
if ( | ||
isinstance(self._inputs_struct, dict) | ||
and not isinstance(inputs, dict) | ||
and list(self._inputs_struct.keys()) | ||
!= sorted(self._inputs_struct.keys()) | ||
): | ||
raise_exception = True | ||
self._maybe_warn_inputs_struct_mismatch( | ||
inputs, raise_exception=raise_exception | ||
) | ||
elif not isinstance(inputs, (list, tuple)) and isinstance( | ||
self._inputs_struct, (list, tuple) | ||
): | ||
inputs = [inputs] | ||
|
||
flat_inputs = tree.flatten(inputs) | ||
flat_inputs = self._convert_inputs_to_tensors(flat_inputs) | ||
return self._adjust_input_rank(flat_inputs) | ||
self._maybe_warn_inputs_struct_mismatch( | ||
inputs, raise_exception=raise_exception | ||
) | ||
|
||
if raise_exception: | ||
raise ValueError( | ||
f"The model's input structure doesn't match the input data." | ||
f" Model expects {self._inputs_struct}," | ||
f" but received {inputs}." | ||
) | ||
|
||
try: | ||
inputs = tree.flatten(inputs) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I liked the use of the name |
||
inputs = self._convert_inputs_to_tensors(inputs) | ||
except: | ||
raise ValueError( | ||
f"The model's input structure doesn't match the input data." | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This should never fail (unless you're passing objects that can't be converted to tensors). And if it does fail, I think the original error message would be most useful. So no need to catch here. Also, I believe the error message is incorrect. It's not an input structure problem if it fails. |
||
f" Model expects {self._inputs_struct}," | ||
f" but received {inputs}." | ||
) | ||
|
||
return self._adjust_input_rank(inputs) | ||
|
||
@property | ||
def input(self): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This shouldn't be needed. `input_signature` should always be a list because the model can have multiple positional arguments. If the inputs are a dict, `input_signature` should be a list of one dict already. If you need to add the outer list here, it means that an incorrect `input_signature` was passed in the first place.