
Pydantic for runtime type checking minimal example #354

@adonath

I was trying to use Pydantic as a runtime type checker and the following minimal example seems to work fine:

```python
from jaxtyping import Array, Float, jaxtyped
from pydantic import validate_call, ConfigDict
from jax import numpy as jnp


# arbitrary_types_allowed=True lets pydantic accept the jaxtyping array annotations
typechecker = validate_call(config=ConfigDict(arbitrary_types_allowed=True))


@jaxtyped(typechecker=typechecker)
def batch_outer_product(
    x: Float[Array, "b c1"], y: Float[Array, "b c2"]
) -> Float[Array, "b c1 c2"]:
    return x[:, :, None] * y[:, None, :]


# passes
x = jnp.ones((5, 3))
y = jnp.ones((5, 12))
result = batch_outer_product(x, y)
assert result.shape == (5, 3, 12)

# fails: y's leading axis (12) does not match b=5, which is bound by x
x = jnp.ones((5, 3))
y = jnp.ones((12, 5))
batch_outer_product(x, y)
```

The second call fails with:

```
jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of __main__.batch_outer_product.
The problem arose whilst typechecking parameter 'y'.
Actual value: f32[12,5](jax)
Expected type: <class 'Float[Array, 'b c2']'>.
----------------------
Called with parameters: {'x': f32[5,3](jax), 'y': f32[12,5](jax)}
Parameter annotations: (x: Float[Array, 'b c1'], y: Float[Array, 'b c2']) -> Any.
The current values for each jaxtyping axis annotation are as follows.
b=5
c1=3
```

This seems to be the expected behavior: `b=5` is bound by `x`, so the leading axis of `y` (12) does not match. I can spend some more time exploring other cases (dataclasses etc.) to check whether everything works, and then turn this into a docs example. Any thoughts, @patrick-kidger?
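Worth noting for a docs version: with `arbitrary_types_allowed=True`, pydantic validates an unknown type such as `Float[Array, "b c2"]` with a plain `isinstance` check, and jaxtyping performs its shape check inside `isinstance`, which is why the example above works. Two pydantic (v2) defaults may matter before recommending `validate_call` as a typechecker: it does not validate return values unless `validate_return=True` is passed, and it runs in lax mode, so a parameter annotated `int` will accept and convert the string `"3"` unless `strict=True` is set. The sketch below shows both with plain `int` annotations; the function names are hypothetical, and the final `typechecker` definition is an untested suggestion (how jaxtyping's own return-value check interacts with `validate_return` has not been verified here).

```python
from pydantic import ConfigDict, ValidationError, validate_call


# Default: arguments are validated, the return value is NOT.
@validate_call
def unchecked_return(n: int) -> int:
    return None  # wrong type on purpose; passes silently


# validate_return=True also checks the return value against the annotation.
@validate_call(validate_return=True)
def checked_return(n: int) -> int:
    return None  # wrong type on purpose; now rejected


# Default: lax mode, so compatible inputs are coerced.
@validate_call
def lax(n: int) -> int:
    return n


# strict=True disables coercion.
@validate_call(config=ConfigDict(strict=True))
def strict(n: int) -> int:
    return n


assert unchecked_return(1) is None
try:
    checked_return(1)
except ValidationError:
    print("return value rejected")

assert lax("3") == 3  # the string "3" was converted to the int 3
try:
    strict("3")
except ValidationError:
    print("string rejected in strict mode")

# Possible stricter variant of the typechecker from the example above
# (assumption: untested together with jaxtyping).
typechecker = validate_call(
    config=ConfigDict(arbitrary_types_allowed=True, strict=True),
    validate_return=True,
)
```

Strict mode is probably the closer match to how typeguard or beartype behave, since those only check and never convert arguments.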
