|
| 1 | +import numpy |
| 2 | +from jaxtyping import ( |
| 3 | + Float, |
| 4 | + Float16, |
| 5 | + Float32, |
| 6 | + Float64, |
| 7 | + Inexact, |
| 8 | + Int, |
| 9 | + Int8, |
| 10 | + Int16, |
| 11 | + Int32, |
| 12 | + Int64, |
| 13 | + Integer, |
| 14 | + UInt, |
| 15 | + UInt8, |
| 16 | + UInt16, |
| 17 | + UInt32, |
| 18 | + UInt64, |
| 19 | +) |
| 20 | +from serde import serde |
| 21 | +from serde.json import from_json, to_json |
| 22 | + |
| 23 | + |
| 24 | +@serde |
| 25 | +class Foo: |
| 26 | + float_: Float[numpy.ndarray, "3 3"] |
| 27 | + float16: Float16[numpy.ndarray, "3 3"] |
| 28 | + float32: Float32[numpy.ndarray, "3 3"] |
| 29 | + float64: Float64[numpy.ndarray, "3 3"] |
| 30 | + inexact: Inexact[numpy.ndarray, "3 3"] |
| 31 | + int_: Int[numpy.ndarray, "3 3"] |
| 32 | + int8: Int8[numpy.ndarray, "3 3"] |
| 33 | + int16: Int16[numpy.ndarray, "3 3"] |
| 34 | + int32: Int32[numpy.ndarray, "3 3"] |
| 35 | + int64: Int64[numpy.ndarray, "3 3"] |
| 36 | + integer: Integer[numpy.ndarray, "3 3"] |
| 37 | + uint: UInt[numpy.ndarray, "3 3"] |
| 38 | + uint8: UInt8[numpy.ndarray, "3 3"] |
| 39 | + uint16: UInt16[numpy.ndarray, "3 3"] |
| 40 | + uint32: UInt32[numpy.ndarray, "3 3"] |
| 41 | + uint64: UInt64[numpy.ndarray, "3 3"] |
| 42 | + |
| 43 | + |
| 44 | +def main() -> None: |
| 45 | + foo = Foo( |
| 46 | + float_=numpy.zeros((3, 3), dtype=float), |
| 47 | + float16=numpy.zeros((3, 3), dtype=numpy.float16), |
| 48 | + float32=numpy.zeros((3, 3), dtype=numpy.float32), |
| 49 | + float64=numpy.zeros((3, 3), dtype=numpy.float64), |
| 50 | + inexact=numpy.zeros((3, 3), dtype=numpy.inexact), |
| 51 | + int_=numpy.zeros((3, 3), dtype=int), |
| 52 | + int8=numpy.zeros((3, 3), dtype=numpy.int8), |
| 53 | + int16=numpy.zeros((3, 3), dtype=numpy.int16), |
| 54 | + int32=numpy.zeros((3, 3), dtype=numpy.int32), |
| 55 | + int64=numpy.zeros((3, 3), dtype=numpy.int64), |
| 56 | + integer=numpy.zeros((3, 3), dtype=numpy.integer), |
| 57 | + uint=numpy.zeros((3, 3), dtype=numpy.uint), |
| 58 | + uint8=numpy.zeros((3, 3), dtype=numpy.uint8), |
| 59 | + uint16=numpy.zeros((3, 3), dtype=numpy.uint16), |
| 60 | + uint32=numpy.zeros((3, 3), dtype=numpy.uint32), |
| 61 | + uint64=numpy.zeros((3, 3), dtype=numpy.uint64), |
| 62 | + ) |
| 63 | + |
| 64 | + print(f"Into Json: {to_json(foo)}") |
| 65 | + |
| 66 | + s = """ |
| 67 | + { |
| 68 | + "float_": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], |
| 69 | + "float16": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], |
| 70 | + "float32": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], |
| 71 | + "float64": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], |
| 72 | + "inexact": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], |
| 73 | + "int_": [[0, 0, 0], [0, 0, 0], [0, 0, 0]], |
| 74 | + "int8": [[0, 0, 0], [0, 0, 0], [0, 0, 0]], |
| 75 | + "int16": [[0, 0, 0], [0, 0, 0], [0, 0, 0]], |
| 76 | + "int32": [[0, 0, 0], [0, 0, 0], [0, 0, 0]], |
| 77 | + "int64": [[0, 0, 0], [0, 0, 0], [0, 0, 0]], |
| 78 | + "integer": [[0, 0, 0], [0, 0, 0], [0, 0, 0]], |
| 79 | + "uint": [[0, 0, 0], [0, 0, 0], [0, 0, 0]], |
| 80 | + "uint8": [[0, 0, 0], [0, 0, 0], [0, 0, 0]], |
| 81 | + "uint16": [[0, 0, 0], [0, 0, 0], [0, 0, 0]], |
| 82 | + "uint32": [[0, 0, 0], [0, 0, 0], [0, 0, 0]], |
| 83 | + "uint64": [[0, 0, 0], [0, 0, 0], [0, 0, 0]] |
| 84 | + } |
| 85 | + """ |
| 86 | + print(f"From Json: {from_json(Foo, s)}") |
| 87 | + |
| 88 | + |
| 89 | +if __name__ == "__main__": |
| 90 | + main() |
0 commit comments