Jaxtyping annotations don't work with pyserde #352

@Ruhrpottpatriot

I'm using jaxtyping to make sure that my tensors have the correct shape. However, I also need to send these tensors to an API, so I serialise them with pyserde and a custom global serialiser.
If a field is typed as a plain torch.Tensor, the custom serialiser works and I can serialise the class to JSON or any other format pyserde supports. However, when I annotate the tensor with jaxtyping, e.g. Float32[torch.Tensor, "4"], serialisation fails with the error:

serde.compat.SerdeError: Unsupported type: Tensor

Minimal (non-)working example; uncomment the jax field and the matching constructor argument to reproduce the error:

from typing import Annotated, Any
from plum import dispatch
import torch

from jaxtyping import Float32
import serde
from serde.json import to_json


class Serializer:
    @dispatch
    def serialize(self, value: torch.Tensor) -> Any:
        return {
            "__tensor__": True,
            "dtype": str(value.dtype),
            "shape": list(value.shape),
            "data": value.cpu().tolist(),
        }


serde.add_serializer(Serializer())


@serde.serde
class Foo:
    tensor_works: torch.Tensor

    annotated_works: Annotated[torch.Tensor, "4"]

    # jax: Float32[torch.Tensor, "4"]  # uncomment to reproduce the SerdeError


foo = Foo(
    tensor_works=torch.tensor([1000.0, 2000.0, 3000.0]),
    annotated_works=torch.tensor([0.0, 0.0, 0.0, 1.0], dtype=torch.float32),
    # jax=torch.tensor([100.0, 0.0, 0.0, 0.0], dtype=torch.float32),  # raises SerdeError
)

j = to_json(foo)
print(j)
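A plausible explanation for why annotated_works succeeds while the jaxtyping field fails (not confirmed against pyserde's source): Annotated[T, ...] keeps T recoverable through typing.get_origin and typing.get_args, so a library can strip the metadata and dispatch on T. A jaxtyping annotation, by contrast, appears to be a newly created class rather than an Annotated alias, so the same unwrapping does not apply. The sketch below illustrates the difference using only the standard library. FakeTensor, FakeJaxtypingAnnotation and unwrap are hypothetical names, not pyserde or jaxtyping APIs. The array_type attribute mirrors what recent jaxtyping array types seem to expose; treat it as an assumption and check it against your jaxtyping version.

```python
from typing import Annotated, Any, get_args, get_origin


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor, so the sketch runs without torch."""


class FakeJaxtypingAnnotation:
    """Hypothetical stand-in for a jaxtyping type such as Float32[torch.Tensor, "4"].

    Assumption: like recent jaxtyping array types, it is a separate class
    (not an Annotated alias) that records the wrapped type as `array_type`.
    """

    array_type = FakeTensor


def unwrap(tp: Any) -> Any:
    """Hypothetical helper: reduce an annotation to the type a serialiser would dispatch on."""
    # Annotated[T, meta] -> T (get_origin returns Annotated, the first arg is T).
    if get_origin(tp) is Annotated:
        return get_args(tp)[0]
    # jaxtyping-style class -> its wrapped array type, if the attribute exists.
    inner = getattr(tp, "array_type", None)
    if isinstance(tp, type) and inner is not None:
        return inner
    # Anything else is returned unchanged.
    return tp


if __name__ == "__main__":
    print(unwrap(Annotated[FakeTensor, "4"]) is FakeTensor)  # True
    # The jaxtyping-style class is not a subclass of the tensor type itself...
    print(issubclass(FakeJaxtypingAnnotation, FakeTensor))  # False
    # ...but it can be reduced to it if the wrapped type is looked up explicitly.
    print(unwrap(FakeJaxtypingAnnotation) is FakeTensor)  # True
```

If this explanation holds, a fix would need pyserde (or the custom serialiser's registration) to map jaxtyping classes back to their wrapped array type in a similar way.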