Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TypeError: add got incompatible shapes for broadcasting: (58,), (54,). #309

Open
MarkusStefan opened this issue Mar 4, 2024 · 8 comments

Comments

@MarkusStefan
Copy link


TypeError Traceback (most recent call last)
in <cell line: 2>()
4 seed=SEED)
5 else:
----> 6 new_predictions = mmm.predict(media=media_scaler.transform(media_data_test),
7 extra_features=extra_features_scaler.transform(extra_features_test),
8 seed=SEED)

17 frames
/usr/local/lib/python3.10/dist-packages/lightweight_mmm/lightweight_mmm.py in predict(self, media, extra_features, media_gap, target_scaler, seed)
518 if seed is None:
519 seed = utils.get_time_seed()
--> 520 prediction = self._predict(
521 rng_key=jax.random.PRNGKey(seed=seed),
522 media_data=full_media,

[... skipping hidden 12 frame]

/usr/local/lib/python3.10/dist-packages/lightweight_mmm/lightweight_mmm.py in _predict(self, rng_key, media_data, extra_features, media_prior, degrees_seasonality, frequency, transform_function, weekday_seasonality, model, posterior_samples, custom_priors)
441 The predictions for the given data.
442 """
--> 443 return infer.Predictive(
444 model=model, posterior_samples=posterior_samples)(
445 rng_key=rng_key,

/usr/local/lib/python3.10/dist-packages/numpyro/infer/util.py in __call__(self, rng_key, *args, **kwargs)
1009 """
1010 if self.batch_ndims == 0 or self.params == {} or self.guide is None:
-> 1011 return self._call_with_params(rng_key, self.params, args, kwargs)
1012 elif self.batch_ndims == 1: # batch over parameters
1013 batch_size = jnp.shape(tree_flatten(self.params)[0][0])[0]

/usr/local/lib/python3.10/dist-packages/numpyro/infer/util.py in _call_with_params(self, rng_key, params, args, kwargs)
986 )
987 model = substitute(self.model, self.params)
--> 988 return _predictive(
989 rng_key,
990 model,

/usr/local/lib/python3.10/dist-packages/numpyro/infer/util.py in _predictive(rng_key, model, posterior_samples, batch_shape, return_sites, infer_discrete, parallel, model_args, model_kwargs)
823 rng_key = rng_key.reshape(batch_shape + key_shape)
824 chunk_size = num_samples if parallel else 1
--> 825 return soft_vmap(
826 single_prediction, (rng_key, posterior_samples), len(batch_shape), chunk_size
827 )

/usr/local/lib/python3.10/dist-packages/numpyro/util.py in soft_vmap(fn, xs, batch_ndims, chunk_size)
417 fn = vmap(fn)
418
--> 419 ys = lax.map(fn, xs) if num_chunks > 1 else fn(xs)
420 map_ndims = int(num_chunks > 1) + int(chunk_size > 1)
421 ys = tree_map(

[... skipping hidden 12 frame]

/usr/local/lib/python3.10/dist-packages/numpyro/infer/util.py in single_prediction(val)
796 )
797 else:
--> 798 model_trace = trace(
799 seed(substitute(masked_model, samples), rng_key)
800 ).get_trace(*model_args, **model_kwargs)

/usr/local/lib/python3.10/dist-packages/numpyro/handlers.py in get_trace(self, *args, **kwargs)
169 :return: OrderedDict containing the execution trace.
170 """
--> 171 self(*args, **kwargs)
172 return self.trace
173

/usr/local/lib/python3.10/dist-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
106
107

/usr/local/lib/python3.10/dist-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
106
107

/usr/local/lib/python3.10/dist-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
106
107

/usr/local/lib/python3.10/dist-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
106
107

/usr/local/lib/python3.10/dist-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
106
107

/usr/local/lib/python3.10/dist-packages/lightweight_mmm/models.py in media_mix_model(media_data, target_data, media_prior, degrees_seasonality, frequency, transform_function, custom_priors, transform_kwargs, weekday_seasonality, extra_features)
410 # expo_trend is B(1, 1) so that the exponent on time is in [.5, 1.5].
411 prediction = (
--> 412 intercept + coef_trend * trend ** expo_trend +
413 seasonality * coef_seasonality +
414 jnp.einsum(media_einsum, media_transformed, coef_media))

/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/array_methods.py in op(self, *args)
741 def _forward_operator_to_aval(name):
742 def op(self, *args):
--> 743 return getattr(self.aval, f"_{name}")(self, *args)
744 return op
745

/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/array_methods.py in deferring_binary_op(self, other)
269 args = (other, self) if swap else (self, other)
270 if isinstance(other, _accepted_binop_types):
--> 271 return binary_op(*args)
272 # Note: don't use isinstance here, because we don't want to raise for
273 # subclasses, e.g. NamedTuple objects that may override operators.

[... skipping hidden 12 frame]

/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/ufuncs.py in fn(x1, x2)
97 def fn(x1, x2, /):
98 x1, x2 = promote_args(numpy_fn.__name__, x1, x2)
---> 99 return lax_fn(x1, x2) if x1.dtype != np.bool_ else bool_lax_fn(x1, x2)
100 fn.__qualname__ = f"jax.numpy.{numpy_fn.__name__}"
101 fn = jit(fn, inline=True)

[... skipping hidden 7 frame]

/usr/local/lib/python3.10/dist-packages/jax/_src/lax/lax.py in broadcasting_shape_rule(name, *avals)
1597 result_shape.append(non_1s[0])
1598 else:
-> 1599 raise TypeError(f'{name} got incompatible shapes for broadcasting: '
1600 f'{", ".join(map(str, map(tuple, shapes)))}.')
1601

TypeError: add got incompatible shapes for broadcasting: (58,), (54,).

@Pavantejapenugonda
Copy link

I am getting this issue as well and am looking for a solution.

@MarkusStefan
Copy link
Author

Installing an older version of numpyro resolved my issue: !pip install numpyro==0.13.2

@masifkingpin
Copy link

I had the same problem, and version 0.13.2 of numpyro did not work for me, so I used the following command to install numpyro while installing mmm, matplotlib, etc.:

!pip install numpyro==0.13.1

@shivahari15091994
Copy link

I am also facing the same problem. I would appreciate it if anyone has a solution for this. Thanks.

@MarkusStefan
Copy link
Author

Just install an older version of numpyro, as stated in the comments above.

@jingwg
Copy link

jingwg commented Mar 26, 2024

When I install an older version of numpyro, I get the following import error. Any idea how to solve this?

ModuleNotFoundError Traceback (most recent call last)
~\AppData\Local\Temp\ipykernel_12544\2990551649.py in
1 import pandas as pd
----> 2 from lightweight_mmm import preprocessing, lightweight_mmm, plot, optimize_media
3 import jax.numpy as jnp
4 from sklearn.metrics import mean_absolute_percentage_error

~\Anaconda3\envs\python3\lib\site-packages\lightweight_mmm\preprocessing.py in
24 from statsmodels.stats.outliers_influence import variance_inflation_factor
25 from statsmodels.tools.tools import add_constant
---> 26 from lightweight_mmm.core import core_utils
27
28

~\Anaconda3\envs\python3\lib\site-packages\lightweight_mmm\core\core_utils.py in
20 import jax.numpy as jnp
21
---> 22 from numpyro import distributions as dist
23
24 # pylint: disable=g-import-not-at-top

~\Anaconda3\envs\python3\lib\site-packages\numpyro\__init__.py in
4 import logging
5
----> 6 from numpyro import compat, diagnostics, distributions, handlers, infer, ops, optim
7 from numpyro.distributions.distribution import enable_validation, validation_enabled
8 from numpyro.infer.inspect import render_model

~\Anaconda3\envs\python3\lib\site-packages\numpyro\infer\__init__.py in
3
4 from numpyro.infer.barker import BarkerMH
----> 5 from numpyro.infer.elbo import (
6 ELBO,
7 RenyiELBO,

~\Anaconda3\envs\python3\lib\site-packages\numpyro\infer\elbo.py in
23 log_density,
24 )
---> 25 from numpyro.ops.provenance import eval_provenance
26 from numpyro.util import _validate_model, check_model_guide_match, find_stack_level
27

~\Anaconda3\envs\python3\lib\site-packages\numpyro\ops\provenance.py in
6 import jax.core as core
7 from jax.experimental.pjit import pjit_p
----> 8 import jax.extend.linear_util as lu
9 from jax.interpreters.partial_eval import trace_to_jaxpr_dynamic
10 from jax.interpreters.pxla import xla_pmap_p

ModuleNotFoundError: No module named 'jax.extend.linear_util'

@AkiroSR
Copy link

AkiroSR commented Apr 3, 2024

install an older version of jax.
'jax.extend.linear_util' was removed in jax after 0.4.23 (currently in 0.4.25)

@fehiepsi
Copy link
Member

fehiepsi commented Apr 3, 2024

Sorry for the breakage. Could you try the following:

pip install --upgrade git+https://github.com/google/lightweight_mmm.git

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

7 participants