This code
```python
import jax
from jax import export
from jax import numpy as jnp
import jaxtyping as jt
import typeguard


@jt.jaxtyped(typechecker=typeguard.typechecked)
def f(
    x: jt.Float[jt.Array, "*#B"],
) -> jt.Float[jt.Array, "*#B"]:
    return x * jnp.sum(x) ** 2


dtype = jnp.float32
# Export with a symbolic (shape-polymorphic) dimension `b`.
x_shape = export.symbolic_shape("b")
export.export(jax.jit(f))(
    jax.ShapeDtypeStruct(x_shape, dtype)
)
```
fails with the following error:
```
jaxtyping.TypeCheckError: Type-check error whilst checking the return value of __main__.f.
Actual value: f32[b](jax)
Expected type: Float[Array, '*#B'].
----------------------
Called with parameters: {'x': f32[b](jax)}
Parameter annotations: (x: Float[Array, '*#B']) -> Any.
The current values for each jaxtyping axis annotation are as follows.
B=(b,)
```
The problem seems to be the combination `*#`: the check passes when either the `*` or the `#` modifier is dropped.
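For context on the annotation: `*` marks a variadic group of axes and `#` marks axes as broadcastable, so each axis in `*#B` may either match the size already bound to `B` or be of size 1. The sketch below is a simplified, dependency-free model of that per-axis rule. `dims_broadcast_compatible` is a hypothetical helper, not jaxtyping's implementation, and it assumes both shapes have the same number of axes.

```python
from typing import Sequence


def dims_broadcast_compatible(bound: Sequence[int], actual: Sequence[int]) -> bool:
    """Hypothetical model of a `*#B` check (not jaxtyping's code).

    `bound` is the shape previously bound to the variadic group `B`;
    `actual` is the shape being checked against it. Simplification:
    both shapes must have the same number of axes.
    """
    if len(bound) != len(actual):
        return False
    # `#`: each axis must equal the bound size or be size 1 (broadcastable).
    return all(a == s or a == 1 for s, a in zip(bound, actual))


# The issue's situation with a concrete size instead of a symbolic `b`:
# B=(5,) is bound from `x`, and the return value also has shape (5,).
print(dims_broadcast_compatible((5,), (5,)))  # True
print(dims_broadcast_compatible((5,), (1,)))  # True
print(dims_broadcast_compatible((5,), (4,)))  # False
```

Under this simplified model, the symbolic case compares `b` against itself (`B=(b,)` versus a return value of shape `(b,)`), so one would expect the return-value check to pass.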