Skip to content

Commit 86cfda9

Browse files
committed
Add jaxtyping as an optional dependency, make sure typ is jaxtyping and not numpy
1 parent edea8f8 commit 86cfda9

File tree

2 files changed

+2
-1
lines changed

2 files changed

+2
-1
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ numpy = [
3838
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.11' and (extra == 'numpy' or extra == 'all')", optional = true },
3939
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.12' and (extra == 'numpy' or extra == 'all')", optional = true },
4040
]
41+
jaxtyping = { version = "*", markers = "extra == 'jaxtyping' or extra == 'all'", optional = true }
4142
orjson = { version = "*", markers = "extra == 'orjson' or extra == 'all'", optional = true }
4243
plum-dispatch = ">=2,<2.3"
4344
beartype = ">=0.18.4"

serde/numpy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def is_numpy_jaxtyping(typ) -> bool:
7878
origin = get_origin(typ)
7979
if origin is not None:
8080
typ = origin
81-
return issubclass(typ, np.ndarray)
81+
return typ is not np.ndarray and issubclass(typ, np.ndarray)
8282
except TypeError:
8383
return False
8484

0 commit comments

Comments
 (0)