Skip to content

Commit 3e41320

Browse files
authored
simplify the provenance logic to prepare for the removal of jax named_shape (#1837)
1 parent f6eb6ce commit 3e41320

File tree

1 file changed

+3
-11
lines changed

1 file changed

+3
-11
lines changed

numpyro/ops/provenance.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
from jax.interpreters.partial_eval import trace_to_jaxpr_dynamic
1616
from jax.interpreters.pxla import xla_pmap_p
17-
import jax.numpy as jnp
1817

1918

2019
def eval_provenance(fn, **kwargs):
@@ -53,18 +52,11 @@ def eval_provenance(fn, **kwargs):
5352
# get provenances of flatten kwargs
5453
aval_kwargs = {}
5554
for n, v in kwargs.items():
56-
aval = jax.ShapeDtypeStruct((), jnp.bool_, {"provenance": frozenset({n})})
57-
aval_kwargs[n] = jax.tree.map(lambda _: aval, v)
58-
aval_args, _ = jax.tree.flatten(((), aval_kwargs))
59-
provenance_inputs = jax.tree.map(lambda x: x.named_shape["provenance"], aval_args)
55+
aval_kwargs[n] = jax.tree.map(lambda _: frozenset({n}), v)
56+
provenance_inputs, _ = jax.tree.flatten(((), aval_kwargs))
6057

6158
provenance_outputs = track_deps_jaxpr(jaxpr, provenance_inputs)
62-
out_flat = []
63-
for v, p in zip(avals_out, provenance_outputs):
64-
val = jax.ShapeDtypeStruct(jnp.shape(v), jnp.result_type(v), {"provenance": p})
65-
out_flat.append(val)
66-
out = jax.tree.unflatten(out_tree(), out_flat)
67-
return jax.tree.map(lambda x: x.named_shape["provenance"], out)
59+
return jax.tree.unflatten(out_tree(), provenance_outputs)
6860

6961

7062
def track_deps_jaxpr(jaxpr, provenance_inputs):

0 commit comments

Comments
 (0)