Skip to content

Commit 129e030

Browse files
authored
Merge pull request #576 from pablovela5620/jaxtyping
initial working jaxtyping serializing/deserializing
2 parents b29cef4 + 2a012c1 commit 129e030

File tree

6 files changed

+214
-24
lines changed

6 files changed

+214
-24
lines changed

examples/type_numpy_jaxtyping.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import numpy
2+
from jaxtyping import (
3+
Float,
4+
Float16,
5+
Float32,
6+
Float64,
7+
Inexact,
8+
Int,
9+
Int8,
10+
Int16,
11+
Int32,
12+
Int64,
13+
Integer,
14+
UInt,
15+
UInt8,
16+
UInt16,
17+
UInt32,
18+
UInt64,
19+
)
20+
from serde import serde
21+
from serde.json import from_json, to_json
22+
23+
24+
@serde
25+
class Foo:
26+
float_: Float[numpy.ndarray, "3 3"]
27+
float16: Float16[numpy.ndarray, "3 3"]
28+
float32: Float32[numpy.ndarray, "3 3"]
29+
float64: Float64[numpy.ndarray, "3 3"]
30+
inexact: Inexact[numpy.ndarray, "3 3"]
31+
int_: Int[numpy.ndarray, "3 3"]
32+
int8: Int8[numpy.ndarray, "3 3"]
33+
int16: Int16[numpy.ndarray, "3 3"]
34+
int32: Int32[numpy.ndarray, "3 3"]
35+
int64: Int64[numpy.ndarray, "3 3"]
36+
integer: Integer[numpy.ndarray, "3 3"]
37+
uint: UInt[numpy.ndarray, "3 3"]
38+
uint8: UInt8[numpy.ndarray, "3 3"]
39+
uint16: UInt16[numpy.ndarray, "3 3"]
40+
uint32: UInt32[numpy.ndarray, "3 3"]
41+
uint64: UInt64[numpy.ndarray, "3 3"]
42+
43+
44+
def main() -> None:
45+
foo = Foo(
46+
float_=numpy.zeros((3, 3), dtype=float),
47+
float16=numpy.zeros((3, 3), dtype=numpy.float16),
48+
float32=numpy.zeros((3, 3), dtype=numpy.float32),
49+
float64=numpy.zeros((3, 3), dtype=numpy.float64),
50+
inexact=numpy.zeros((3, 3), dtype=numpy.inexact),
51+
int_=numpy.zeros((3, 3), dtype=int),
52+
int8=numpy.zeros((3, 3), dtype=numpy.int8),
53+
int16=numpy.zeros((3, 3), dtype=numpy.int16),
54+
int32=numpy.zeros((3, 3), dtype=numpy.int32),
55+
int64=numpy.zeros((3, 3), dtype=numpy.int64),
56+
integer=numpy.zeros((3, 3), dtype=numpy.integer),
57+
uint=numpy.zeros((3, 3), dtype=numpy.uint),
58+
uint8=numpy.zeros((3, 3), dtype=numpy.uint8),
59+
uint16=numpy.zeros((3, 3), dtype=numpy.uint16),
60+
uint32=numpy.zeros((3, 3), dtype=numpy.uint32),
61+
uint64=numpy.zeros((3, 3), dtype=numpy.uint64),
62+
)
63+
64+
print(f"Into Json: {to_json(foo)}")
65+
66+
s = """
67+
{
68+
"float_": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
69+
"float16": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
70+
"float32": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
71+
"float64": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
72+
"inexact": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
73+
"int_": [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
74+
"int8": [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
75+
"int16": [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
76+
"int32": [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
77+
"int64": [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
78+
"integer": [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
79+
"uint": [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
80+
"uint8": [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
81+
"uint16": [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
82+
"uint32": [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
83+
"uint64": [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
84+
}
85+
"""
86+
print(f"From Json: {from_json(Foo, s)}")
87+
88+
89+
if __name__ == "__main__":
90+
main()

pyproject.toml

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,16 @@ packages = [
1111
{ include = "serde" },
1212
]
1313
classifiers=[
14-
"Development Status :: 4 - Beta",
15-
"Intended Audience :: Developers",
16-
"License :: OSI Approved :: MIT License",
17-
"Programming Language :: Python :: 3.9",
18-
"Programming Language :: Python :: 3.10",
19-
"Programming Language :: Python :: 3.11",
20-
"Programming Language :: Python :: 3.12",
21-
"Programming Language :: Python :: Implementation :: CPython",
22-
"Programming Language :: Python :: Implementation :: PyPy",
23-
]
14+
"Development Status :: 4 - Beta",
15+
"Intended Audience :: Developers",
16+
"License :: OSI Approved :: MIT License",
17+
"Programming Language :: Python :: 3.9",
18+
"Programming Language :: Python :: 3.10",
19+
"Programming Language :: Python :: 3.11",
20+
"Programming Language :: Python :: 3.12",
21+
"Programming Language :: Python :: Implementation :: CPython",
22+
"Programming Language :: Python :: Implementation :: PyPy",
23+
]
2424

2525
[tool.poetry.dependencies]
2626
python = "^3.9.0"
@@ -33,11 +33,12 @@ tomli = { version = "*", markers = "extra == 'toml' or extra == 'all'", optional
3333
tomli-w = { version = "*", markers = "extra == 'toml' or extra == 'all'", optional = true }
3434
pyyaml = { version = "*", markers = "extra == 'yaml' or extra == 'all'", optional = true }
3535
numpy = [
36-
{ version = ">1.21.0,<2.0.0", markers = "python_version ~= '3.9.0' and (extra == 'numpy' or extra == 'all')", optional = true },
37-
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.10' and (extra == 'numpy' or extra == 'all')", optional = true },
38-
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.11' and (extra == 'numpy' or extra == 'all')", optional = true },
39-
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.12' and (extra == 'numpy' or extra == 'all')", optional = true },
36+
{ version = ">1.21.0,<2.0.0", markers = "python_version ~= '3.9.0' and (extra == 'numpy' or extra == 'all')", optional = true },
37+
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.10' and (extra == 'numpy' or extra == 'all')", optional = true },
38+
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.11' and (extra == 'numpy' or extra == 'all')", optional = true },
39+
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.12' and (extra == 'numpy' or extra == 'all')", optional = true },
4040
]
41+
jaxtyping = { version = "*", markers = "extra == 'jaxtyping' or extra == 'all'", optional = true }
4142
orjson = { version = "*", markers = "extra == 'orjson' or extra == 'all'", optional = true }
4243
plum-dispatch = ">=2,<2.3"
4344
beartype = ">=0.18.4"
@@ -49,10 +50,10 @@ tomli = { version = "*", markers = "python_version <= '3.11.0'" }
4950
tomli-w = "*"
5051
msgpack = "*"
5152
numpy = [
52-
{ version = ">1.21.0,<2.0.0", markers = "python_version ~= '3.9.0'" },
53-
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.10'" },
54-
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.11'" },
55-
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.12'" },
53+
{ version = ">1.21.0,<2.0.0", markers = "python_version ~= '3.9.0'" },
54+
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.10'" },
55+
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.11'" },
56+
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.12'" },
5657
]
5758
mypy = "==1.10.1"
5859
pytest = "*"
@@ -68,6 +69,7 @@ types-PyYAML = "^6.0.9"
6869
msgpack-types = "^0.3"
6970
envclasses = "^0.3.1"
7071
jedi = "*"
72+
jaxtyping = "*"
7173

7274
[tool.poetry.extras]
7375
msgpack = ["msgpack"]
@@ -76,7 +78,8 @@ toml = ["tomli", "tomli-w"]
7678
yaml = ["pyyaml"]
7779
orjson = ["orjson"]
7880
sqlalchemy = ["sqlalchemy"]
79-
all = ["msgpack", "tomli", "tomli-w", "pyyaml", "numpy", "orjson", "sqlalchemy"]
81+
jaxtyping = ["jaxtyping"]
82+
all = ["msgpack", "tomli", "tomli-w", "pyyaml", "numpy", "orjson", "sqlalchemy", "jaxtyping"]
8083

8184
[build-system]
8285
requires = ["poetry-core>=1.0.0", "poetry-dynamic-versioning"]
@@ -145,16 +148,25 @@ exclude = [
145148
"tests/test_sqlalchemy.py",
146149
]
147150

151+
[[tool.mypy.overrides]]
152+
# to avoid complaints about generic type ndarray
153+
module = "examples.type_numpy_jaxtyping"
154+
ignore_errors = true
155+
148156
[tool.ruff]
149157
select = [
150-
"E", # pycodestyle errors
151-
"W", # pycodestyle warnings
152-
"F", # pyflakes
153-
"C", # flake8-comprehensions
154-
"B", # flake8-bugbear
158+
"E", # pycodestyle errors
159+
"W", # pycodestyle warnings
160+
"F", # pyflakes
161+
"C", # flake8-comprehensions
162+
"B", # flake8-bugbear
155163
]
156164
ignore = ["B904"]
157165
line-length = 100
158166

159167
[tool.ruff.lint.mccabe]
160168
max-complexity = 30
169+
170+
[tool.ruff.per-file-ignores]
171+
# https://docs.kidger.site/jaxtyping/faq/#flake8-or-ruff-are-throwing-an-error
172+
"examples/type_numpy_jaxtyping.py" = ["F722"]

serde/de.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,9 @@
8989
deserialize_numpy_array,
9090
deserialize_numpy_scalar,
9191
deserialize_numpy_array_direct,
92+
deserialize_numpy_jaxtyping_array,
9293
is_numpy_array,
94+
is_numpy_jaxtyping,
9395
is_numpy_scalar,
9496
)
9597

@@ -749,6 +751,9 @@ def render(self, arg: DeField[Any]) -> str:
749751
elif is_numpy_array(arg.type):
750752
self.import_numpy = True
751753
res = deserialize_numpy_array(arg)
754+
elif is_numpy_jaxtyping(arg.type):
755+
self.import_numpy = True
756+
res = deserialize_numpy_jaxtyping_array(arg)
752757
elif is_union(arg.type):
753758
res = self.union_func(arg)
754759
elif is_str_serializable(arg.type):

serde/numpy.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,15 @@ def is_numpy_array(typ) -> bool:
7373
typ = origin
7474
return typ is np.ndarray
7575

76+
def is_numpy_jaxtyping(typ) -> bool:
77+
try:
78+
origin = get_origin(typ)
79+
if origin is not None:
80+
typ = origin
81+
return typ is not np.ndarray and issubclass(typ, np.ndarray)
82+
except TypeError:
83+
return False
84+
7685
def serialize_numpy_array(arg) -> str:
7786
return f"{arg.varname}.tolist()"
7887

@@ -86,6 +95,10 @@ def deserialize_numpy_array(arg) -> str:
8695
dtype = fullname(arg[1][0].type)
8796
return f"numpy.array({arg.data}, dtype={dtype})"
8897

98+
def deserialize_numpy_jaxtyping_array(arg) -> str:
99+
dtype = f"numpy.{arg.type.dtypes[-1]}"
100+
return f"numpy.array({arg.data}, dtype={dtype})"
101+
89102
def deserialize_numpy_array_direct(typ: Any, arg: Any) -> Any:
90103
if is_bare_numpy_array(typ):
91104
return np.array(arg)
@@ -111,6 +124,9 @@ def deserialize_numpy_scalar(arg):
111124
def is_numpy_array(typ) -> bool:
112125
return False
113126

127+
def is_numpy_jaxtyping(typ) -> bool:
128+
return False
129+
114130
def serialize_numpy_array(arg) -> str:
115131
return ""
116132

@@ -120,5 +136,8 @@ def serialize_numpy_datetime(arg) -> str:
120136
def deserialize_numpy_array(arg) -> str:
121137
return ""
122138

139+
def deserialize_numpy_jaxtyping_array(arg) -> str:
140+
return ""
141+
123142
def deserialize_numpy_array_direct(typ: Any, arg: Any) -> Any:
124143
return arg

serde/se.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
)
7777
from .numpy import (
7878
is_numpy_array,
79+
is_numpy_jaxtyping,
7980
is_numpy_datetime,
8081
is_numpy_scalar,
8182
serialize_numpy_array,
@@ -751,6 +752,8 @@ def render(self, arg: SeField[Any]) -> str:
751752
res = serialize_numpy_scalar(arg)
752753
elif is_numpy_array(arg.type):
753754
res = serialize_numpy_array(arg)
755+
elif is_numpy_jaxtyping(arg.type):
756+
res = serialize_numpy_array(arg)
754757
elif is_primitive(arg.type):
755758
res = self.primitive(arg)
756759
elif is_union(arg.type):

tests/test_numpy.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import numpy as np
55
import numpy.typing as npt
6+
import jaxtyping
67
import pytest
78

89
import serde
@@ -89,6 +90,66 @@ class NumpyDate:
8990

9091
assert de(NumpyDate, se(date_test)) == date_test
9192

93+
@serde.serde(**opt)
94+
class NumpyJaxtyping:
95+
float_: jaxtyping.Float[np.ndarray, "2 2"] # noqa: F722
96+
float16: jaxtyping.Float16[np.ndarray, "2 2"] # noqa: F722
97+
float32: jaxtyping.Float32[np.ndarray, "2 2"] # noqa: F722
98+
float64: jaxtyping.Float64[np.ndarray, "2 2"] # noqa: F722
99+
inexact: jaxtyping.Inexact[np.ndarray, "2 2"] # noqa: F722
100+
int_: jaxtyping.Int[np.ndarray, "2 2"] # noqa: F722
101+
int8: jaxtyping.Int8[np.ndarray, "2 2"] # noqa: F722
102+
int16: jaxtyping.Int16[np.ndarray, "2 2"] # noqa: F722
103+
int32: jaxtyping.Int32[np.ndarray, "2 2"] # noqa: F722
104+
int64: jaxtyping.Int64[np.ndarray, "2 2"] # noqa: F722
105+
integer: jaxtyping.Integer[np.ndarray, "2 2"] # noqa: F722
106+
uint: jaxtyping.UInt[np.ndarray, "2 2"] # noqa: F722
107+
uint8: jaxtyping.UInt8[np.ndarray, "2 2"] # noqa: F722
108+
uint16: jaxtyping.UInt16[np.ndarray, "2 2"] # noqa: F722
109+
uint32: jaxtyping.UInt32[np.ndarray, "2 2"] # noqa: F722
110+
uint64: jaxtyping.UInt64[np.ndarray, "2 2"] # noqa: F722
111+
112+
def __eq__(self, other):
113+
return (
114+
(self.float_ == other.float_).all()
115+
and (self.float16 == other.float16).all()
116+
and (self.float32 == other.float32).all()
117+
and (self.float64 == other.float64).all()
118+
and (self.inexact == other.inexact).all()
119+
and (self.int_ == other.int_).all()
120+
and (self.int8 == other.int8).all()
121+
and (self.int16 == other.int16).all()
122+
and (self.int32 == other.int32).all()
123+
and (self.int64 == other.int64).all()
124+
and (self.integer == other.integer).all()
125+
and (self.uint == other.uint).all()
126+
and (self.uint8 == other.uint8).all()
127+
and (self.uint16 == other.uint16).all()
128+
and (self.uint32 == other.uint32).all()
129+
and (self.uint64 == other.uint64).all()
130+
)
131+
132+
jaxtyping_test = NumpyJaxtyping(
133+
float_=np.array([[1, 2], [3, 4]], dtype=np.float_),
134+
float16=np.array([[5, 6], [7, 8]], dtype=np.float16),
135+
float32=np.array([[9, 10], [11, 12]], dtype=np.float32),
136+
float64=np.array([[13, 14], [15, 16]], dtype=np.float64),
137+
inexact=np.array([[17, 18], [19, 20]], dtype=np.float_),
138+
int_=np.array([[21, 22], [23, 24]], dtype=np.int_),
139+
int8=np.array([[25, 26], [27, 28]], dtype=np.int8),
140+
int16=np.array([[29, 30], [31, 32]], dtype=np.int16),
141+
int32=np.array([[33, 34], [35, 36]], dtype=np.int32),
142+
int64=np.array([[37, 38], [39, 40]], dtype=np.int64),
143+
integer=np.array([[41, 42], [43, 44]], dtype=np.int_),
144+
uint=np.array([[45, 46], [47, 48]], dtype=np.uint),
145+
uint8=np.array([[49, 50], [51, 52]], dtype=np.uint8),
146+
uint16=np.array([[53, 54], [55, 56]], dtype=np.uint16),
147+
uint32=np.array([[57, 58], [59, 60]], dtype=np.uint32),
148+
uint64=np.array([[61, 62], [63, 64]], dtype=np.uint64),
149+
)
150+
151+
assert de(NumpyJaxtyping, se(jaxtyping_test)) == jaxtyping_test
152+
92153

93154
@pytest.mark.parametrize("opt", opt_case, ids=opt_case_ids())
94155
@pytest.mark.parametrize("se,de", format_json + format_msgpack)

0 commit comments

Comments
 (0)