Status: Open
Labels: question (User queries)
Description
I'm using jaxtyping to make sure my tensors have the correct shape. However, I also need to serialise the tensors for an API, for which I'm using pyserde with a custom global serialiser.
If a field is annotated as a plain torch.Tensor, the custom serialiser works and I can serialise the class to JSON or any other format pyserde supports. Wrapping it in typing.Annotated[torch.Tensor, "4"] also works. However, when I annotate the tensor with jaxtyping, e.g. Float32[torch.Tensor, "4"], serialisation fails with:
serde.compat.SerdeError: Unsupported type: Tensor
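A possible reason the two annotations behave differently, sketched with the standard library only: typing.get_type_hints drops Annotated metadata unless include_extras=True is passed, so a serialiser that reads field types this way sees plain torch.Tensor for an Annotated field. A jaxtyping annotation such as Float32[torch.Tensor, "4"] appears to be a newly created class rather than an Annotated alias, so nothing is stripped and a lookup keyed on torch.Tensor misses it. Whether pyserde resolves field types exactly like this is an assumption; Tensor and shaped() below are hypothetical stand-ins for torch.Tensor and jaxtyping's Float32[...], not the real libraries.

```python
from typing import Annotated, get_type_hints


class Tensor:
    """Hypothetical stand-in for torch.Tensor, so the sketch runs without PyTorch."""


def shaped(array_type: type, dim_str: str) -> type:
    """Hypothetical stand-in for a jaxtyping-style annotation: builds a new
    class that remembers the wrapped array type but does not subclass it."""
    return type(
        f"Shaped[{array_type.__name__}, {dim_str!r}]",
        (),
        {"array_type": array_type, "dim_str": dim_str},
    )


class Foo:
    annotated_works: Annotated[Tensor, "4"]
    jax: shaped(Tensor, "4")


# include_extras defaults to False, so Annotated[...] metadata is stripped here.
hints = get_type_hints(Foo)
print(hints["annotated_works"] is Tensor)  # True: same as a plain Tensor field
print(hints["jax"] is Tensor)  # False: a distinct class, so a Tensor-keyed lookup misses it
print(hints["jax"].array_type is Tensor)  # True: the wrapped type is still recoverable
```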
Minimal (not) working example; uncomment the two `jax` lines to get the error:
```python
from typing import Annotated, Any

from plum import dispatch
import torch
from jaxtyping import Float32
import serde
from serde.json import to_json


class Serializer:
    # Custom global serialiser; plum picks the method from the argument's type.
    @dispatch
    def serialize(self, value: torch.Tensor) -> Any:
        return {
            "__tensor__": True,
            "dtype": str(value.dtype),
            "shape": list(value.shape),
            "data": value.cpu().tolist(),
        }


# Register the serialiser for every pyserde class.
serde.add_serializer(Serializer())


@serde.serde
class Foo:
    tensor_works: torch.Tensor
    annotated_works: Annotated[torch.Tensor, "4"]
    # jax: Float32[torch.Tensor, "4"]  # <- uncommenting this raises SerdeError


foo = Foo(
    tensor_works=torch.tensor([1000.0, 2000.0, 3000.0]),
    annotated_works=torch.tensor([0.0, 0.0, 0.0, 1.0], dtype=torch.float32),
    # jax=torch.tensor([100.0, 0.0, 0.0, 0.0], dtype=torch.float32),  # <- doesn't work
)

j = to_json(foo)
print(j)
```
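If the failure does come from matching the declared field type against the types the custom serialiser handles, one hypothetical workaround is to unwrap a jaxtyping-style class back to its array type before the lookup. The sketch models that lookup as a plain dict keyed by declared type; SERIALIZERS, find_serializer, Tensor and ShapedTensor are illustrative names, not pyserde or jaxtyping API, and it assumes the real jaxtyping class exposes the wrapped type as an `array_type` attribute.

```python
from typing import Any, Callable


class Tensor:
    """Hypothetical stand-in for torch.Tensor."""

    def __init__(self, data: list[float]) -> None:
        self.data = data


# Hypothetical stand-in for Float32[torch.Tensor, "4"]: a new class, not a Tensor subclass.
ShapedTensor = type("Shaped[Tensor, '4']", (), {"array_type": Tensor, "dim_str": "4"})

# Serialisers keyed by the exact declared field type.
SERIALIZERS: dict[type, Callable[[Any], Any]] = {
    Tensor: lambda t: {"__tensor__": True, "shape": [len(t.data)], "data": t.data},
}


def find_serializer(annotation: Any, unwrap: bool = True) -> Callable[[Any], Any]:
    """Return the serialiser for `annotation`, first unwrapping a
    jaxtyping-style class to the array type it wraps when `unwrap` is True."""
    if unwrap:
        annotation = getattr(annotation, "array_type", annotation)
    try:
        return SERIALIZERS[annotation]
    except KeyError:
        name = getattr(annotation, "__name__", annotation)
        raise TypeError(f"Unsupported type: {name}") from None


t = Tensor([100.0, 0.0, 0.0, 0.0])
print(find_serializer(ShapedTensor)(t))  # found via the unwrapped Tensor key
```

Without the unwrap step, find_serializer(ShapedTensor, unwrap=False) raises TypeError, mirroring the "Unsupported type" error above. Unwrapping by attribute keeps the lookup keyed on the array type only, so the serialiser itself never sees the declared shape.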