problems
cloudpickle works for jax.jit functions, but a visual inspection of the cloudpickled bytes shows a lurking error message.
challenges
Not sure if this belongs in cloudpickle or jax.
Is this my bad? I was hoping we could just store the jaxpr as a UTF-8 string, since it's more human-readable, but I don't know how to regenerate a PjitFunction from a jaxpr.
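For reference, the human-readable jaxpr text mentioned above can be produced with jax.make_jaxpr. This is a minimal sketch, assuming a recent JAX release: it only prints the jaxpr text and its UTF-8 size. It does not rebuild a PjitFunction from that text, which is the open question here.

```python
import jax
import jax.numpy as jnp


def jnp_func(x):
    return jnp.sin(jnp.cos(x))


# make_jaxpr traces the function with an example input and returns a ClosedJaxpr.
closed_jaxpr = jax.make_jaxpr(jnp_func)(0.3)

# str() gives the pretty-printed, human-readable jaxpr.
jaxpr_text = str(closed_jaxpr)
jaxpr_bytes = jaxpr_text.encode("utf-8")

print(jaxpr_text)
print(f"{len(jaxpr_bytes)} bytes as UTF-8")
```

Comparing that byte count with len(cloudpickle.dumps(jax.jit(jnp_func))) would give a rough sense of how much smaller a text-based representation might be.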
opportunities
A fix could reduce the size of cloudpickled jax.jit functions.
import cloudpickle
import jax
from rich import print as rprint  # assumption: rprint is rich's print; plain print also works


def test_jax_cloudpickle():
    def jnp_func(x):
        return jax.numpy.sin(jax.numpy.cos(x))

    jitted1 = jax.jit(jnp_func)
    del jnp_func  # ensure jitted2 can't cheat by recompiling jnp_func within a session
    assert "jnp_func" not in locals(), "failed to remove jnp_func"

    jitted1_buf = cloudpickle.dumps(jitted1)
    rprint(jitted1_buf)  # visually inspect the pickled bytes
    jitted2 = cloudpickle.loads(jitted1_buf)

    assert jitted1(0.3) == jitted2(0.3), "results differ after the pickle round-trip"
    assert b"TRACEBACK" not in jitted1_buf, "error message in cloudpickle of jax.jit"


test_jax_cloudpickle()
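Instead of eyeballing the raw bytes, the string constants stored in a pickle can be listed with the standard-library pickletools module, which makes messages like the traceback note greppable. This is a minimal sketch; the helpers string_constants and find_markers are hypothetical names, not part of cloudpickle or JAX.

```python
import pickle
import pickletools


def string_constants(buf: bytes) -> list:
    """Collect every str/bytes argument stored in a pickle stream (hypothetical helper)."""
    found = []
    for _opcode, arg, _pos in pickletools.genops(buf):
        # String-like opcodes (e.g. SHORT_BINUNICODE, SHORT_BINBYTES) carry their payload as `arg`.
        if isinstance(arg, (str, bytes, bytearray)):
            found.append(arg)
    return found


def find_markers(buf: bytes, markers) -> dict:
    """Map each marker to the full string constants that contain it, case-insensitively."""
    consts = []
    for c in string_constants(buf):
        # Decode byte payloads so they can be searched and displayed as text.
        consts.append(c.decode("utf-8", "replace") if isinstance(c, (bytes, bytearray)) else c)
    return {
        marker: [text for text in consts if marker.lower() in text.lower()]
        for marker in markers
    }


if __name__ == "__main__":
    demo = pickle.dumps({"note": "set JAX_TRACEBACK_FILTERING=off", "n": 3})
    print(find_markers(demo, ["TRACEBACK"]))
```

Applied to the buffer from the test above, find_markers(jitted1_buf, ["TRACEBACK"]) would return the whole strings containing the marker rather than just a yes/no answer, which should make it easier to see which object drags the message into the pickle.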
Could JAX_TRACEBACK_FILTERING= be greppable?
Thank you for making cloudpickle!