Open
Labels: question (User queries)
Description
How can I specify the exact shape of each leaf of a pytree, say a dict? And, going further, how can I specify the structure of the pytree itself?
For example, the function `f` takes a dict as input:
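```python
import jax.numpy as jnp
from beartype import beartype
from jaxtyping import Array, Float, jaxtyped

# Each key gets its own leaf annotation, and `b` must agree across leaves.
@jaxtyped(typechecker=beartype)
def f(state: Dict{'x': Float[Array, 'b 10'], 'y': Float[Array, 'b 1']}):
    ...

# Example valid input
valid_state = {
    'x': jnp.ones((3, 10)),  # b=3
    'y': jnp.zeros((3, 1)),  # b=3, consistent
}

# Example invalid input (wrong shape for 'x')
invalid_state = {
    'x': jnp.ones((3, 99)),  # shape is not 'b 10'
    'y': jnp.zeros((3, 1)),
}

f(valid_state)
try:
    f(invalid_state)
except Exception as e:
    print(f"\nError with invalid_state:\n{e}")
```

(The snippet above is not going to work: `Dict{...}` is not valid Python syntax; it only shows the kind of annotation I have in mind.)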
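I hope the type checker could check every leaf's shape as well as the structure of the pytree (here, that exactly the keys `'x'` and `'y'` are present).
`PyTree[Float[Array, 'b ...']]` is useful, but not fine-grained enough: it applies the same annotation to every leaf.
I think this feature would be quite natural to have; for example, in RL, a JAX environment's `step` function takes a complex state. There may already be a way or a workaround, but I failed to find one. Sorry if I have missed something obvious.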
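As a possible stopgap, a hand-rolled runtime check can cover both requirements (per-leaf shapes and exact dict structure). This is a minimal sketch, not jaxtyping API: `check_tree` and its spec format (a nested dict mirroring the state, whose leaves are space-separated dimension strings like `"b 10"`) are hypothetical. It only supports named dimensions and integer literals (no `...`, `*`, or dtype checks), requires dict keys to match exactly, and binds each named dimension on first use so that `b` must agree across all leaves.

```python
from collections.abc import Mapping
from typing import Any, Dict, Optional


def check_tree(
    tree: Any, spec: Any, dims: Optional[Dict[str, int]] = None, path: str = "state"
) -> Dict[str, int]:
    """Hypothetical helper: check per-leaf shapes and dict structure.

    `spec` mirrors `tree`: nested dicts with the same keys, whose leaves are
    shape strings such as "b 10". Named dims are bound on first use and must
    agree everywhere else; integer tokens must match exactly. Returns the
    dimension bindings, e.g. {"b": 3}. Raises TypeError on any mismatch.
    """
    if dims is None:
        dims = {}
    if isinstance(spec, Mapping):
        # Structure check: the tree must be a mapping with exactly the spec's keys.
        if not isinstance(tree, Mapping):
            raise TypeError(f"{path}: expected a dict, got {type(tree).__name__}")
        missing = set(spec) - set(tree)
        extra = set(tree) - set(spec)
        if missing or extra:
            raise TypeError(f"{path}: missing keys {missing}, unexpected keys {extra}")
        for key in spec:
            check_tree(tree[key], spec[key], dims, f"{path}[{key!r}]")
        return dims
    # Leaf check: compare the leaf's .shape against the dimension string.
    shape = getattr(tree, "shape", None)
    if shape is None:
        raise TypeError(f"{path}: expected an array-like leaf with a .shape")
    tokens = spec.split()
    if len(tokens) != len(shape):
        raise TypeError(f"{path}: expected rank {len(tokens)}, got shape {tuple(shape)}")
    for token, size in zip(tokens, shape):
        # Integer tokens are fixed sizes; named tokens bind on first use.
        expected = int(token) if token.isdigit() else dims.setdefault(token, size)
        if size != expected:
            raise TypeError(
                f"{path}: dim {token!r} should be {expected}, got shape {tuple(shape)}"
            )
    return dims
```

With `spec = {'x': 'b 10', 'y': 'b 1'}`, calling `check_tree(valid_state, spec)` on the example above returns `{'b': 3}`, while `check_tree(invalid_state, spec)` raises a `TypeError` because `'x'` has size 99 where 10 is expected. Calling it at the top of `f` gives a runtime check with dimensions shared across leaves, at the cost of the annotation no longer living in the signature.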