Skip to content

Commit 22d0b40

Browse files
authored
Merge pull request #649 from krylowicz/numpy-v2
Support `numpy>2.0.0`
2 parents e7c44a8 + e191403 commit 22d0b40

File tree

3 files changed

+15
-17
lines changed

3 files changed

+15
-17
lines changed

examples/type_numpy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class Foo:
1111
in_: numpy.int_
1212
inc: numpy.intc
1313
ui: numpy.uint
14-
fl: numpy.float_
14+
fl: numpy.float64
1515
st: numpy.str_
1616
nd: numpy.typing.NDArray[numpy.int_]
1717
ha: numpy.half
@@ -25,7 +25,7 @@ def main() -> None:
2525
in_=numpy.int_(42),
2626
inc=numpy.intc(42),
2727
ui=numpy.uint(42),
28-
fl=numpy.float_(3.14),
28+
fl=numpy.float64(3.14),
2929
st=numpy.str_("numpy str"),
3030
nd=numpy.array([1, 2, 3]),
3131
ha=numpy.half(3.14),

pyproject.toml

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,11 @@ tomli = { version = "*", markers = "extra == 'toml' or extra == 'all'", optional
3434
tomli-w = { version = "*", markers = "extra == 'toml' or extra == 'all'", optional = true }
3535
pyyaml = { version = "*", markers = "extra == 'yaml' or extra == 'all'", optional = true }
3636
numpy = [
37-
{ version = ">1.21.0,<2.0.0", markers = "python_version ~= '3.9.0' and (extra == 'numpy' or extra == 'all')", optional = true },
38-
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.10' and (extra == 'numpy' or extra == 'all')", optional = true },
39-
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.11' and (extra == 'numpy' or extra == 'all')", optional = true },
40-
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.12' and (extra == 'numpy' or extra == 'all')", optional = true },
41-
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.13' and (extra == 'numpy' or extra == 'all')", optional = true },
37+
{ version = ">1.21.0,<3.0.0", markers = "python_version ~= '3.9.0' and (extra == 'numpy' or extra == 'all')", optional = true },
38+
{ version = ">1.22.0,<3.0.0", markers = "python_version ~= '3.10' and (extra == 'numpy' or extra == 'all')", optional = true },
39+
{ version = ">1.22.0,<3.0.0", markers = "python_version ~= '3.11' and (extra == 'numpy' or extra == 'all')", optional = true },
40+
{ version = ">1.22.0,<3.0.0", markers = "python_version ~= '3.12' and (extra == 'numpy' or extra == 'all')", optional = true },
41+
{ version = ">1.22.0,<3.0.0", markers = "python_version ~= '3.13' and (extra == 'numpy' or extra == 'all')", optional = true },
4242
]
4343
jaxtyping = { version = "<0.3.0", markers = "extra == 'jaxtyping' or extra == 'all'", optional = true }
4444
orjson = { version = "*", markers = "extra == 'orjson' or extra == 'all'", optional = true }
@@ -52,11 +52,11 @@ tomli = { version = "*", markers = "python_version <= '3.11.0'" }
5252
tomli-w = "*"
5353
msgpack = "*"
5454
numpy = [
55-
{ version = ">1.21.0,<2.0.0", markers = "python_version ~= '3.9.0'" },
56-
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.10'" },
57-
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.11'" },
58-
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.12'" },
59-
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.13'" },
55+
{ version = ">1.21.0,<3.0.0", markers = "python_version ~= '3.9.0'" },
56+
{ version = ">1.22.0,<3.0.0", markers = "python_version ~= '3.10'" },
57+
{ version = ">1.22.0,<3.0.0", markers = "python_version ~= '3.11'" },
58+
{ version = ">1.22.0,<3.0.0", markers = "python_version ~= '3.12'" },
59+
{ version = ">1.22.0,<3.0.0", markers = "python_version ~= '3.13'" },
6060
]
6161
mypy = "==1.14.0"
6262
pytest = "*"
@@ -163,6 +163,7 @@ select = [
163163
"F", # pyflakes
164164
"C", # flake8-comprehensions
165165
"B", # flake8-bugbear
166+
"NPY201", # numpy2-deprecation
166167
]
167168
ignore = ["B904"]
168169
line-length = 100

tests/test_numpy.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,6 @@ class NumpyDate:
9292

9393
@serde.serde(**opt)
9494
class NumpyJaxtyping:
95-
float_: jaxtyping.Float[np.ndarray, "2 2"] # noqa: F722
9695
float16: jaxtyping.Float16[np.ndarray, "2 2"] # noqa: F722
9796
float32: jaxtyping.Float32[np.ndarray, "2 2"] # noqa: F722
9897
float64: jaxtyping.Float64[np.ndarray, "2 2"] # noqa: F722
@@ -111,8 +110,7 @@ class NumpyJaxtyping:
111110

112111
def __eq__(self, other):
113112
return (
114-
(self.float_ == other.float_).all()
115-
and (self.float16 == other.float16).all()
113+
(self.float16 == other.float16).all()
116114
and (self.float32 == other.float32).all()
117115
and (self.float64 == other.float64).all()
118116
and (self.inexact == other.inexact).all()
@@ -130,11 +128,10 @@ def __eq__(self, other):
130128
)
131129

132130
jaxtyping_test = NumpyJaxtyping(
133-
float_=np.array([[1, 2], [3, 4]], dtype=np.float_),
134131
float16=np.array([[5, 6], [7, 8]], dtype=np.float16),
135132
float32=np.array([[9, 10], [11, 12]], dtype=np.float32),
136133
float64=np.array([[13, 14], [15, 16]], dtype=np.float64),
137-
inexact=np.array([[17, 18], [19, 20]], dtype=np.float_),
134+
inexact=np.array([[17, 18], [19, 20]], dtype=np.float64),
138135
int_=np.array([[21, 22], [23, 24]], dtype=np.int_),
139136
int8=np.array([[25, 26], [27, 28]], dtype=np.int8),
140137
int16=np.array([[29, 30], [31, 32]], dtype=np.int16),

0 commit comments

Comments
 (0)