-
Notifications
You must be signed in to change notification settings - Fork 80
Open
Description
Cloudpickle (latest, 3.1.1 at time of writing) fails to pickle `jaxtyped` functions because of a weakref that jaxtyping introduced in version 0.2.35.
Here is an MWE that uses uv with inline script metadata (PEP 723) to declare its dependencies:
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "beartype",
# "cloudpickle==3.1.1",
# "jaxtyping==0.2.35",
# "numpy",
# ]
# ///
import beartype
import cloudpickle
import numpy as np
from jaxtyping import Float, jaxtyped
# The report attributes the cloudpickle failure to a weakref held by the
# jaxtyped wrapper (introduced in jaxtyping 0.2.35).
@jaxtyped(typechecker=beartype.beartype)
def typechecked_fn(x: Float[np.ndarray, " d"]) -> float:
    """Return the sum of the elements of ``x`` as a Python float.

    The ``jaxtyped`` decorator is configured with ``beartype`` as its
    typechecker for the ``Float[np.ndarray, " d"]`` annotation on ``x``.
    """
    return np.sum(x).item()
def main():
    """Round-trip ``typechecked_fn`` through cloudpickle and call the copy.

    Prints the pickled bytes, then prints the result of calling the
    unpickled function on a small float array.
    """
    # In the reported run, this call raises
    # TypeError: cannot pickle 'weakref.ReferenceType' object.
    dumped = cloudpickle.dumps(typechecked_fn)
    print(dumped)
    fn = cloudpickle.loads(dumped)
    print(fn(np.array([1.0, 2.0])))
if __name__ == "__main__":
    main()

# When I run this with `uv run scratch.py`, I get:
Traceback (most recent call last):
File "/users/PAS1576/samuelstevens/projects/saev/scratch.py", line 29, in <module>
main()
File "/users/PAS1576/samuelstevens/projects/saev/scratch.py", line 22, in main
dumped = cloudpickle.dumps(typechecked_fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/users/PAS1576/samuelstevens/.cache/uv/environments-v2/scratch-f0adbe6995ced415/lib/python3.12/site-packages/cloudpickle/cloudpickle.py", line 1537, in dumps
cp.dump(obj)
File "/users/PAS1576/samuelstevens/.cache/uv/environments-v2/scratch-f0adbe6995ced415/lib/python3.12/site-packages/cloudpickle/cloudpickle.py", line 1303, in dump
return super().dump(obj)
^^^^^^^^^^^^^^^^^
TypeError: cannot pickle 'weakref.ReferenceType' object
When I update jaxtyping to 0.3.2 in the script metadata:
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "beartype",
# "cloudpickle==3.1.1",
# "jaxtyping==0.3.2",
# "numpy",
# ]
# ///
import beartype
import cloudpickle
import numpy as np
from jaxtyping import Float, jaxtyped
# Same function as the jaxtyping 0.2.35 example; the report says the
# jaxtyping 0.3.2 wrapper fails to pickle in the same way.
@jaxtyped(typechecker=beartype.beartype)
def typechecked_fn(x: Float[np.ndarray, " d"]) -> float:
    """Return the sum of the elements of ``x`` as a Python float.

    The ``jaxtyped`` decorator is configured with ``beartype`` as its
    typechecker for the ``Float[np.ndarray, " d"]`` annotation on ``x``.
    """
    return np.sum(x).item()
def main():
    """Round-trip ``typechecked_fn`` through cloudpickle and call the copy.

    Prints the pickled bytes, then prints the result of calling the
    unpickled function on a small float array.
    """
    # In the reported run, this call raises
    # TypeError: cannot pickle 'weakref.ReferenceType' object.
    dumped = cloudpickle.dumps(typechecked_fn)
    print(dumped)
    fn = cloudpickle.loads(dumped)
    print(fn(np.array([1.0, 2.0])))
if __name__ == "__main__":
    main()

# I get the same error:
Traceback (most recent call last):
File "/users/PAS1576/samuelstevens/projects/saev/scratch.py", line 29, in <module>
main()
File "/users/PAS1576/samuelstevens/projects/saev/scratch.py", line 22, in main
dumped = cloudpickle.dumps(typechecked_fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/users/PAS1576/samuelstevens/.cache/uv/environments-v2/scratch-f0adbe6995ced415/lib/python3.12/site-packages/cloudpickle/cloudpickle.py", line 1537, in dumps
cp.dump(obj)
File "/users/PAS1576/samuelstevens/.cache/uv/environments-v2/scratch-f0adbe6995ced415/lib/python3.12/site-packages/cloudpickle/cloudpickle.py", line 1303, in dump
return super().dump(obj)
^^^^^^^^^^^^^^^^^
TypeError: cannot pickle 'weakref.ReferenceType' object
This MWE works with jaxtyping 0.2.34, but on 0.2.34 my other script fails with the same error as #198.
Metadata
Metadata
Assignees
Labels
No labels