|
3 | 3 |
|
4 | 4 | import numpy as np |
5 | 5 | import numpy.typing as npt |
| 6 | +import jaxtyping |
6 | 7 | import pytest |
7 | 8 |
|
8 | 9 | import serde |
@@ -89,6 +90,66 @@ class NumpyDate: |
89 | 90 |
|
90 | 91 | assert de(NumpyDate, se(date_test)) == date_test |
91 | 92 |
|
| 93 | + @serde.serde(**opt) |
| 94 | + class NumpyJaxtyping: |
| 95 | + float_: jaxtyping.Float[np.ndarray, "2 2"] # noqa: F722 |
| 96 | + float16: jaxtyping.Float16[np.ndarray, "2 2"] # noqa: F722 |
| 97 | + float32: jaxtyping.Float32[np.ndarray, "2 2"] # noqa: F722 |
| 98 | + float64: jaxtyping.Float64[np.ndarray, "2 2"] # noqa: F722 |
| 99 | + inexact: jaxtyping.Inexact[np.ndarray, "2 2"] # noqa: F722 |
| 100 | + int_: jaxtyping.Int[np.ndarray, "2 2"] # noqa: F722 |
| 101 | + int8: jaxtyping.Int8[np.ndarray, "2 2"] # noqa: F722 |
| 102 | + int16: jaxtyping.Int16[np.ndarray, "2 2"] # noqa: F722 |
| 103 | + int32: jaxtyping.Int32[np.ndarray, "2 2"] # noqa: F722 |
| 104 | + int64: jaxtyping.Int64[np.ndarray, "2 2"] # noqa: F722 |
| 105 | + integer: jaxtyping.Integer[np.ndarray, "2 2"] # noqa: F722 |
| 106 | + uint: jaxtyping.UInt[np.ndarray, "2 2"] # noqa: F722 |
| 107 | + uint8: jaxtyping.UInt8[np.ndarray, "2 2"] # noqa: F722 |
| 108 | + uint16: jaxtyping.UInt16[np.ndarray, "2 2"] # noqa: F722 |
| 109 | + uint32: jaxtyping.UInt32[np.ndarray, "2 2"] # noqa: F722 |
| 110 | + uint64: jaxtyping.UInt64[np.ndarray, "2 2"] # noqa: F722 |
| 111 | + |
| 112 | + def __eq__(self, other): |
| 113 | + return ( |
| 114 | + (self.float_ == other.float_).all() |
| 115 | + and (self.float16 == other.float16).all() |
| 116 | + and (self.float32 == other.float32).all() |
| 117 | + and (self.float64 == other.float64).all() |
| 118 | + and (self.inexact == other.inexact).all() |
| 119 | + and (self.int_ == other.int_).all() |
| 120 | + and (self.int8 == other.int8).all() |
| 121 | + and (self.int16 == other.int16).all() |
| 122 | + and (self.int32 == other.int32).all() |
| 123 | + and (self.int64 == other.int64).all() |
| 124 | + and (self.integer == other.integer).all() |
| 125 | + and (self.uint == other.uint).all() |
| 126 | + and (self.uint8 == other.uint8).all() |
| 127 | + and (self.uint16 == other.uint16).all() |
| 128 | + and (self.uint32 == other.uint32).all() |
| 129 | + and (self.uint64 == other.uint64).all() |
| 130 | + ) |
| 131 | + |
| 132 | + jaxtyping_test = NumpyJaxtyping( |
| 133 | + float_=np.array([[1, 2], [3, 4]], dtype=np.float_), |
| 134 | + float16=np.array([[5, 6], [7, 8]], dtype=np.float16), |
| 135 | + float32=np.array([[9, 10], [11, 12]], dtype=np.float32), |
| 136 | + float64=np.array([[13, 14], [15, 16]], dtype=np.float64), |
| 137 | + inexact=np.array([[17, 18], [19, 20]], dtype=np.float_), |
| 138 | + int_=np.array([[21, 22], [23, 24]], dtype=np.int_), |
| 139 | + int8=np.array([[25, 26], [27, 28]], dtype=np.int8), |
| 140 | + int16=np.array([[29, 30], [31, 32]], dtype=np.int16), |
| 141 | + int32=np.array([[33, 34], [35, 36]], dtype=np.int32), |
| 142 | + int64=np.array([[37, 38], [39, 40]], dtype=np.int64), |
| 143 | + integer=np.array([[41, 42], [43, 44]], dtype=np.int_), |
| 144 | + uint=np.array([[45, 46], [47, 48]], dtype=np.uint), |
| 145 | + uint8=np.array([[49, 50], [51, 52]], dtype=np.uint8), |
| 146 | + uint16=np.array([[53, 54], [55, 56]], dtype=np.uint16), |
| 147 | + uint32=np.array([[57, 58], [59, 60]], dtype=np.uint32), |
| 148 | + uint64=np.array([[61, 62], [63, 64]], dtype=np.uint64), |
| 149 | + ) |
| 150 | + |
| 151 | + assert de(NumpyJaxtyping, se(jaxtyping_test)) == jaxtyping_test |
| 152 | + |
92 | 153 |
|
93 | 154 | @pytest.mark.parametrize("opt", opt_case, ids=opt_case_ids()) |
94 | 155 | @pytest.mark.parametrize("se,de", format_json + format_msgpack) |
|
0 commit comments