
Cloudpickled jax.jit functions work but include some error messages #537

@bionicles


problems

cloudpickle works for jax.jit functions, but visually inspecting the pickled bytes shows a lurking error message (traceback text) inside them.
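Instead of inspecting the bytes by eye, the pickle opcode stream can be walked with the standard library's pickletools to list every pickled string that contains a marker such as TRACEBACK, together with its byte offset. This is a minimal sketch. The helper name find_embedded_strings is hypothetical, and the demo payload is a synthetic stand-in, not real cloudpickle output from jax.jit.

```python
import pickle
import pickletools


def find_embedded_strings(payload: bytes, marker: str) -> list[tuple[int, str]]:
    """Return (byte offset, string) for every string argument in the pickle
    opcode stream that contains `marker`. Hypothetical helper for illustration."""
    hits = []
    # genops yields (opcode, argument, position) for each opcode in the stream.
    for _opcode, arg, pos in pickletools.genops(payload):
        if isinstance(arg, str) and marker in arg:
            hits.append((pos, arg))
    return hits


if __name__ == "__main__":
    # Synthetic payload standing in for a real cloudpickle dump.
    demo = pickle.dumps({"note": "Set JAX_TRACEBACK_FILTERING=off to include these."})
    for offset, text in find_embedded_strings(demo, "TRACEBACK"):
        print(offset, text)
```

Applied to the issue's payload, the call would be find_embedded_strings(cloudpickle.dumps(jitted1), "TRACEBACK"). Walking opcodes rather than raw bytes reports each hit as a complete string instead of a fragment of surrounding bytes.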

challenges

I'm not sure whether this belongs in cloudpickle or in jax.

Is this my mistake? I was hoping we could just store the jaxpr as a UTF-8 string, since it's more human-readable, but I don't know how to regenerate a PjitFunction from a jaxpr.
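For reference, the jaxpr text comes from jax.make_jaxpr. As far as I know, JAX has no public API that parses that text back into a function. For turning a jitted function into bytes and back, JAX provides the jax.export module. This is a minimal sketch that assumes a JAX release with the public jax.export module (0.4.30 or later, to my knowledge). The result is an Exported object specialized to the input shape and dtype given at export time (a float32 scalar here), not a PjitFunction. It is an alternative to compare against, not a drop-in replacement for cloudpickle.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export  # assumes a JAX release that ships jax.export


def jnp_func(x):
    return jnp.sin(jnp.cos(x))


# Human-readable jaxpr text for a float32 scalar input (for reading, not reloading).
jaxpr_text = str(jax.make_jaxpr(jnp_func)(np.float32(0.3)))
print(jaxpr_text)

# Serialize the jitted function for a float32 scalar input signature.
exported = export.export(jax.jit(jnp_func))(jax.ShapeDtypeStruct((), np.float32))
blob = exported.serialize()  # a bytearray

# Rehydrate from the bytes alone and call it; jnp_func is not used here.
restored = export.deserialize(blob)
result = restored.call(np.float32(0.3))
print(result)
```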

opportunities

A fix could reduce the size of cloudpickled jax.jit functions.
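One way to put numbers on the size question is to compare the cloudpickle payload of the plain Python function with that of its jax.jit wrapper. This is a minimal sketch that assumes cloudpickle and jax are installed. The sizes depend on the installed JAX and cloudpickle versions, so no particular values are claimed here.

```python
import cloudpickle
import jax
import jax.numpy as jnp


def jnp_func(x):
    return jnp.sin(jnp.cos(x))


# Size of the plain Python function as pickled by cloudpickle.
plain_size = len(cloudpickle.dumps(jnp_func))
# Size of the jax.jit wrapper around the same function.
jitted_size = len(cloudpickle.dumps(jax.jit(jnp_func)))
print(f"plain function: {plain_size} bytes; jax.jit wrapper: {jitted_size} bytes")
```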

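```python
import cloudpickle
import jax
from rich import print as rprint  # assumed: rprint is rich's print


def test_jax_cloudpickle():
    def jnp_func(x):
        return jax.numpy.sin(jax.numpy.cos(x))

    jitted1 = jax.jit(jnp_func)
    # Drop the local name so jitted2 can't cheat by recompiling jnp_func within the session
    # (jitted1 still holds its own reference to the wrapped function).
    del jnp_func
    assert "jnp_func" not in locals(), "failed to remove jnp_func"
    jitted1_buf = cloudpickle.dumps(jitted1)
    rprint(jitted1_buf)
    jitted2 = cloudpickle.loads(jitted1_buf)
    assert jitted1(0.3) == jitted2(0.3), "weird"
    # Per this report, this check fails: the pickled bytes contain traceback-related text.
    assert b"TRACEBACK" not in jitted1_buf, "error message in cloudpickle of jax.jit"


test_jax_cloudpickle()
```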

Could JAX_TRACEBACK_FILTERING= be greppable?


Thank you for making cloudpickle!
