Skip to content

cloudpickle + weakref + index_variadic #332

@samuelstevens

Description

@samuelstevens

Cloudpickle (latest, 3.1.1 at time of writing) fails to pickle jaxtyped functions because of a weakref introduced in jaxtyping 0.2.35.

I have an MWE using uv with inline script metadata:

# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "beartype",
#     "cloudpickle==3.1.1",
#     "jaxtyping==0.2.35",
#     "numpy",
# ]
# ///
import beartype
import cloudpickle
import numpy as np
from jaxtyping import Float, jaxtyped


@jaxtyped(typechecker=beartype.beartype)  # wrapped by jaxtyped, with beartype as the typechecker
def typechecked_fn(x: Float[np.ndarray, " d"]) -> float:
    return np.sum(x).item()  # sum of all elements, converted to a Python scalar


def main() -> None:  # round-trip typechecked_fn through cloudpickle, then call the copy
    dumped = cloudpickle.dumps(typechecked_fn)  # fails here on jaxtyping 0.2.35: TypeError (weakref)
    print(dumped)
    fn = cloudpickle.loads(dumped)
    print(fn(np.array([1.0, 2.0])))


if __name__ == "__main__":
    main()

When I run this with `uv run scratch.py`, I get:

Traceback (most recent call last):
  File "/users/PAS1576/samuelstevens/projects/saev/scratch.py", line 29, in <module>
    main()
  File "/users/PAS1576/samuelstevens/projects/saev/scratch.py", line 22, in main
    dumped = cloudpickle.dumps(typechecked_fn)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/users/PAS1576/samuelstevens/.cache/uv/environments-v2/scratch-f0adbe6995ced415/lib/python3.12/site-packages/cloudpickle/cloudpickle.py", line 1537, in dumps
    cp.dump(obj)
  File "/users/PAS1576/samuelstevens/.cache/uv/environments-v2/scratch-f0adbe6995ced415/lib/python3.12/site-packages/cloudpickle/cloudpickle.py", line 1303, in dump
    return super().dump(obj)
           ^^^^^^^^^^^^^^^^^
TypeError: cannot pickle 'weakref.ReferenceType' object

When I update jaxtyping to 0.3.2 (in the script metadata):

# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "beartype",
#     "cloudpickle==3.1.1",
#     "jaxtyping==0.3.2",
#     "numpy",
# ]
# ///
import beartype
import cloudpickle
import numpy as np
from jaxtyping import Float, jaxtyped


@jaxtyped(typechecker=beartype.beartype)  # wrapped by jaxtyped, with beartype as the typechecker
def typechecked_fn(x: Float[np.ndarray, " d"]) -> float:
    return np.sum(x).item()  # sum of all elements, converted to a Python scalar


def main() -> None:  # round-trip typechecked_fn through cloudpickle, then call the copy
    dumped = cloudpickle.dumps(typechecked_fn)  # still fails here on jaxtyping 0.3.2: TypeError (weakref)
    print(dumped)
    fn = cloudpickle.loads(dumped)
    print(fn(np.array([1.0, 2.0])))


if __name__ == "__main__":
    main()

I get the same error:

Traceback (most recent call last):
  File "/users/PAS1576/samuelstevens/projects/saev/scratch.py", line 29, in <module>
    main()
  File "/users/PAS1576/samuelstevens/projects/saev/scratch.py", line 22, in main
    dumped = cloudpickle.dumps(typechecked_fn)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/users/PAS1576/samuelstevens/.cache/uv/environments-v2/scratch-f0adbe6995ced415/lib/python3.12/site-packages/cloudpickle/cloudpickle.py", line 1537, in dumps
    cp.dump(obj)
  File "/users/PAS1576/samuelstevens/.cache/uv/environments-v2/scratch-f0adbe6995ced415/lib/python3.12/site-packages/cloudpickle/cloudpickle.py", line 1303, in dump
    return super().dump(obj)
           ^^^^^^^^^^^^^^^^^
TypeError: cannot pickle 'weakref.ReferenceType' object

This MWE works on jaxtyping 0.2.34, but my other script fails on 0.2.34 with the same error as #198.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions