Skip to content

Commit dd83700

Browse files
Jake VanderPlascopybara-github
authored andcommitted
Silence some pytype errors related to a JAX build refactor
This build change allows pytype to propagate annotations that it previously did not, and because of this it starts flagging existing incorrect annotations. PiperOrigin-RevId: 772480864 Change-Id: If8bd74414306e120764eaf0f9eabe725d5282680
1 parent 9b42565 commit dd83700

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

lightweight_mmm/core/transformations/lagging.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def adstock_internal(
152152
adstock_value = prev_adstock * lag_weight + data
153153
return adstock_value, adstock_value# jax-ndarray
154154

155-
_, adstock_values = jax.lax.scan(
155+
_, adstock_values = jax.lax.scan(# lax-types
156156
f=adstock_internal, init=data[0, ...], xs=data[1:, ...])
157157
adstock_values = jnp.concatenate([jnp.array([data[0, ...]]), adstock_values])
158158
return jax.lax.cond(

lightweight_mmm/media_transforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def adstock_internal(prev_adstock: jnp.ndarray,
8484
adstock_value = prev_adstock * lag_weight + data
8585
return adstock_value, adstock_value# jax-ndarray
8686

87-
_, adstock_values = jax.lax.scan(
87+
_, adstock_values = jax.lax.scan(# lax-types
8888
f=adstock_internal, init=data[0, ...], xs=data[1:, ...])
8989
adstock_values = jnp.concatenate([jnp.array([data[0, ...]]), adstock_values])
9090
return jax.lax.cond(

0 commit comments

Comments
 (0)