Skip to content

Commit 2a012c1

Browse files
committed
add jaxtyping test
1 parent 86cfda9 commit 2a012c1

File tree

1 file changed

+61
-0
lines changed

1 file changed

+61
-0
lines changed

tests/test_numpy.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import numpy as np
55
import numpy.typing as npt
6+
import jaxtyping
67
import pytest
78

89
import serde
@@ -89,6 +90,66 @@ class NumpyDate:
8990

9091
assert de(NumpyDate, se(date_test)) == date_test
9192

93+
@serde.serde(**opt)
94+
class NumpyJaxtyping:
95+
float_: jaxtyping.Float[np.ndarray, "2 2"] # noqa: F722
96+
float16: jaxtyping.Float16[np.ndarray, "2 2"] # noqa: F722
97+
float32: jaxtyping.Float32[np.ndarray, "2 2"] # noqa: F722
98+
float64: jaxtyping.Float64[np.ndarray, "2 2"] # noqa: F722
99+
inexact: jaxtyping.Inexact[np.ndarray, "2 2"] # noqa: F722
100+
int_: jaxtyping.Int[np.ndarray, "2 2"] # noqa: F722
101+
int8: jaxtyping.Int8[np.ndarray, "2 2"] # noqa: F722
102+
int16: jaxtyping.Int16[np.ndarray, "2 2"] # noqa: F722
103+
int32: jaxtyping.Int32[np.ndarray, "2 2"] # noqa: F722
104+
int64: jaxtyping.Int64[np.ndarray, "2 2"] # noqa: F722
105+
integer: jaxtyping.Integer[np.ndarray, "2 2"] # noqa: F722
106+
uint: jaxtyping.UInt[np.ndarray, "2 2"] # noqa: F722
107+
uint8: jaxtyping.UInt8[np.ndarray, "2 2"] # noqa: F722
108+
uint16: jaxtyping.UInt16[np.ndarray, "2 2"] # noqa: F722
109+
uint32: jaxtyping.UInt32[np.ndarray, "2 2"] # noqa: F722
110+
uint64: jaxtyping.UInt64[np.ndarray, "2 2"] # noqa: F722
111+
112+
def __eq__(self, other):
113+
return (
114+
(self.float_ == other.float_).all()
115+
and (self.float16 == other.float16).all()
116+
and (self.float32 == other.float32).all()
117+
and (self.float64 == other.float64).all()
118+
and (self.inexact == other.inexact).all()
119+
and (self.int_ == other.int_).all()
120+
and (self.int8 == other.int8).all()
121+
and (self.int16 == other.int16).all()
122+
and (self.int32 == other.int32).all()
123+
and (self.int64 == other.int64).all()
124+
and (self.integer == other.integer).all()
125+
and (self.uint == other.uint).all()
126+
and (self.uint8 == other.uint8).all()
127+
and (self.uint16 == other.uint16).all()
128+
and (self.uint32 == other.uint32).all()
129+
and (self.uint64 == other.uint64).all()
130+
)
131+
132+
jaxtyping_test = NumpyJaxtyping(
133+
float_=np.array([[1, 2], [3, 4]], dtype=np.float_),
134+
float16=np.array([[5, 6], [7, 8]], dtype=np.float16),
135+
float32=np.array([[9, 10], [11, 12]], dtype=np.float32),
136+
float64=np.array([[13, 14], [15, 16]], dtype=np.float64),
137+
inexact=np.array([[17, 18], [19, 20]], dtype=np.float_),
138+
int_=np.array([[21, 22], [23, 24]], dtype=np.int_),
139+
int8=np.array([[25, 26], [27, 28]], dtype=np.int8),
140+
int16=np.array([[29, 30], [31, 32]], dtype=np.int16),
141+
int32=np.array([[33, 34], [35, 36]], dtype=np.int32),
142+
int64=np.array([[37, 38], [39, 40]], dtype=np.int64),
143+
integer=np.array([[41, 42], [43, 44]], dtype=np.int_),
144+
uint=np.array([[45, 46], [47, 48]], dtype=np.uint),
145+
uint8=np.array([[49, 50], [51, 52]], dtype=np.uint8),
146+
uint16=np.array([[53, 54], [55, 56]], dtype=np.uint16),
147+
uint32=np.array([[57, 58], [59, 60]], dtype=np.uint32),
148+
uint64=np.array([[61, 62], [63, 64]], dtype=np.uint64),
149+
)
150+
151+
assert de(NumpyJaxtyping, se(jaxtyping_test)) == jaxtyping_test
152+
92153

93154
@pytest.mark.parametrize("opt", opt_case, ids=opt_case_ids())
94155
@pytest.mark.parametrize("se,de", format_json + format_msgpack)

0 commit comments

Comments
 (0)