|
14 | 14 |
|
15 | 15 | from jax.interpreters.partial_eval import trace_to_jaxpr_dynamic |
16 | 16 | from jax.interpreters.pxla import xla_pmap_p |
17 | | -import jax.numpy as jnp |
18 | 17 |
|
19 | 18 |
|
20 | 19 | def eval_provenance(fn, **kwargs): |
@@ -53,18 +52,11 @@ def eval_provenance(fn, **kwargs): |
53 | 52 | # get provenances of flatten kwargs |
54 | 53 | aval_kwargs = {} |
55 | 54 | for n, v in kwargs.items(): |
56 | | - aval = jax.ShapeDtypeStruct((), jnp.bool_, {"provenance": frozenset({n})}) |
57 | | - aval_kwargs[n] = jax.tree.map(lambda _: aval, v) |
58 | | - aval_args, _ = jax.tree.flatten(((), aval_kwargs)) |
59 | | - provenance_inputs = jax.tree.map(lambda x: x.named_shape["provenance"], aval_args) |
| 55 | + aval_kwargs[n] = jax.tree.map(lambda _: frozenset({n}), v) |
| 56 | + provenance_inputs, _ = jax.tree.flatten(((), aval_kwargs)) |
60 | 57 |
|
61 | 58 | provenance_outputs = track_deps_jaxpr(jaxpr, provenance_inputs) |
62 | | - out_flat = [] |
63 | | - for v, p in zip(avals_out, provenance_outputs): |
64 | | - val = jax.ShapeDtypeStruct(jnp.shape(v), jnp.result_type(v), {"provenance": p}) |
65 | | - out_flat.append(val) |
66 | | - out = jax.tree.unflatten(out_tree(), out_flat) |
67 | | - return jax.tree.map(lambda x: x.named_shape["provenance"], out) |
| 59 | + return jax.tree.unflatten(out_tree(), provenance_outputs) |
68 | 60 |
|
69 | 61 |
|
70 | 62 | def track_deps_jaxpr(jaxpr, provenance_inputs): |
|
0 commit comments