Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TypeError: where() got some positional-only arguments passed as keyword arguments: 'condition, x, y' #321

Open
datainsight1 opened this issue May 10, 2024 · 15 comments

Comments

@datainsight1
Copy link

TypeError Traceback (most recent call last)
Cell In[9], line 4
2 number_warmup=100
3 number_samples=100
----> 4 mmm.fit(
5 media=media_data_train,
6 media_prior=costs,
7 target=target_train,
8 extra_features=extra_features_train,
9 number_warmup=number_warmup,
10 number_samples=number_samples,
11 seed=SEED)

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/lightweight_mmm/lightweight_mmm.py:363, in LightweightMMM.fit(self, media, media_prior, target, extra_features, degrees_seasonality, seasonality_frequency, weekday_seasonality, media_names, number_warmup, number_samples, number_chains, target_accept_prob, init_strategy, custom_priors, seed)
353 kernel = numpyro.infer.NUTS(
354 model=self._model_function,
355 target_accept_prob=target_accept_prob,
356 init_strategy=init_strategy)
358 mcmc = numpyro.infer.MCMC(
359 sampler=kernel,
360 num_warmup=number_warmup,
361 num_samples=number_samples,
362 num_chains=number_chains)
--> 363 mcmc.run(
364 rng_key=jax.random.PRNGKey(seed),
365 media_data=jnp.array(media),
366 extra_features=extra_features,
367 target_data=jnp.array(target),
368 media_prior=jnp.array(media_prior),
369 degrees_seasonality=degrees_seasonality,
370 frequency=seasonality_frequency,
371 transform_function=self._model_transform_function,
372 weekday_seasonality=weekday_seasonality,
373 custom_priors=custom_priors)
375 self.custom_priors = custom_priors
376 if media_names is not None:

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/infer/mcmc.py:638, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
636 else:
637 if self.chain_method == "sequential":
--> 638 states, last_state = _laxmap(partial_map_fn, map_args)
639 elif self.chain_method == "parallel":
640 states, last_state = pmap(partial_map_fn)(map_args)

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/infer/mcmc.py:166, in _laxmap(f, xs)
164 for i in range(n):
165 x = jit(_get_value_from_index)(xs, i)
--> 166 ys.append(f(x))
168 return tree_map(lambda *args: jnp.stack(args), *ys)

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/infer/mcmc.py:416, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
414 # Check if _sample_fn is None, then we need to initialize the sampler.
415 if init_state is None or (getattr(self.sampler, "_sample_fn", None) is None):
--> 416 new_init_state = self.sampler.init(
417 rng_key,
418 self.num_warmup,
419 init_params,
420 model_args=args,
421 model_kwargs=kwargs,
422 )
423 init_state = new_init_state if init_state is None else init_state
424 sample_fn, postprocess_fn = self._get_cached_fns()

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/infer/hmc.py:713, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
708 # vectorized
709 else:
710 rng_key, rng_key_init_model = jnp.swapaxes(
711 vmap(random.split)(rng_key), 0, 1
712 )
--> 713 init_params = self._init_state(
714 rng_key_init_model, model_args, model_kwargs, init_params
715 )
716 if self._potential_fn and init_params is None:
717 raise ValueError(
718 "Valid value of init_params must be provided with" " potential_fn."
719 )

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/infer/hmc.py:657, in HMC._init_state(self, rng_key, model_args, model_kwargs, init_params)
650 def _init_state(self, rng_key, model_args, model_kwargs, init_params):
651 if self._model is not None:
652 (
653 new_init_params,
654 potential_fn,
655 postprocess_fn,
656 model_trace,
--> 657 ) = initialize_model(
658 rng_key,
659 self._model,
660 dynamic_args=True,
661 init_strategy=self._init_strategy,
662 model_args=model_args,
663 model_kwargs=model_kwargs,
664 forward_mode_differentiation=self._forward_mode_differentiation,
665 )
666 if init_params is None:
667 init_params = new_init_params

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/infer/util.py:656, in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
646 model_kwargs = {} if model_kwargs is None else model_kwargs
647 substituted_model = substitute(
648 seed(model, rng_key if is_prng_key(rng_key) else rng_key[0]),
649 substitute_fn=init_strategy,
650 )
651 (
652 inv_transforms,
653 replay_model,
654 has_enumerate_support,
655 model_trace,
--> 656 ) = _get_model_transforms(substituted_model, model_args, model_kwargs)
657 # substitute param sites from model_trace to model so
658 # we don't need to generate again parameters of numpyro.module
659 model = substitute(
660 model,
661 data={
(...)
665 },
666 )

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/infer/util.py:450, in _get_model_transforms(model, model_args, model_kwargs)
448 def _get_model_transforms(model, model_args=(), model_kwargs=None):
449 model_kwargs = {} if model_kwargs is None else model_kwargs
--> 450 model_trace = trace(model).get_trace(*model_args, **model_kwargs)
451 inv_transforms = {}
452 # model code may need to be replayed in the presence of deterministic sites

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/handlers.py:171, in trace.get_trace(self, *args, **kwargs)
163 def get_trace(self, *args, **kwargs):
164 """
165 Run the wrapped callable and return the recorded trace.
166
(...)
169 :return: OrderedDict containing the execution trace.
170 """
--> 171 self(*args, **kwargs)
172 return self.trace

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/primitives.py:105, in Messenger.call(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/primitives.py:105, in Messenger.call(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/primitives.py:105, in Messenger.call(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/lightweight_mmm/models.py:385, in media_mix_model(media_data, target_data, media_prior, degrees_seasonality, frequency, transform_function, custom_priors, transform_kwargs, weekday_seasonality, extra_features)
380 elif transform_function == "carryover" and not transform_kwargs:
381 transform_kwargs = {"number_lags": 13 * 7}
383 media_transformed = numpyro.deterministic(
384 name="media_transformed",
--> 385 value=transform_function(media_data,
386 custom_priors=custom_priors,
387 **transform_kwargs if transform_kwargs else {}))
388 seasonality = media_transforms.calculate_seasonality(
389 number_periods=data_size,
390 degrees=degrees_seasonality,
391 frequency=frequency,
392 gamma_seasonality=gamma_seasonality)
393 # For national model's case

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/lightweight_mmm/models.py:280, in transform_carryover(media_data, custom_priors, number_lags)
278 if media_data.ndim == 3:
279 exponent = jnp.expand_dims(exponent, axis=-1)
--> 280 return media_transforms.apply_exponent_safe(data=carryover, exponent=exponent)

[... skipping hidden 11 frame]

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/lightweight_mmm/media_transforms.py:189, in apply_exponent_safe(data, exponent)
172 @jax.jit
173 def apply_exponent_safe(
174 data: jnp.ndarray,
175 exponent: jnp.ndarray,
176 ) -> jnp.ndarray:
177 """Applies an exponent to given data in a gradient safe way.
178
179 More info on the double jnp.where can be found:
(...)
187 The result of the exponent operation with the inputs provided.
188 """
--> 189 exponent_safe = jnp.where(condition=(data == 0), x=1, y=data) ** exponent
190 return jnp.where(condition=(data == 0), x=0, y=exponent_safe)

TypeError: where() got some positional-only arguments passed as keyword arguments: 'condition, x, y'

@ShirleyChai730
Copy link

Hi, I had the same issue, have you resolved it?

@datainsight1
Copy link
Author

Hi @ShirleyChai730: I haven't yet been able to resolve the above issue.

@Munger245
Copy link

Hi @ShirleyChai730
This is probably due to an update of the jax library.
I explicitly installed jax and jaxlib version 0.4.20 and it works fine in my local environment.

@datainsight1
Copy link
Author

Thank you @Munger245 . It works.

@ShirleyChai730
Copy link

Hi @ShirleyChai730 This is probably due to an update of the jax library. I explicitly installed jax and jaxlib version 0.4.20 and it works fine in my local.

Thanks for pointing this out. I tried 0.4.20 and it still didn't work, but when I tried the older version 0.4.19, it worked.

@rahulmisal27
Copy link

@ShirleyChai730 I am also getting this error on a Mac M2. Which version of lightweight_mmm worked on your machine? Could you please share your requirements file here, along with your Python version?

@datainsight1
Copy link
Author

@rahulmisal27 : I am using the latest version of lightweight mmm and it works.

@bristobal
Copy link

I tried installing jax and jaxlib 0.4.20 and still have the same error. How did you fix it? @datainsight1

@jamesvrt
Copy link

jamesvrt commented Jun 27, 2024

In a fresh Python 3.10 environment I needed to fix these versions to get things working:

jax==0.4.20 jaxlib==0.4.20 scipy==1.12.0

@ezjsiwu
Copy link

ezjsiwu commented Jul 2, 2024

Hi there! I'm running into the same error in a Python 3.11 environment. Has anyone figured out which version of jax is appropriate for this environment?

@8-u8
Copy link

8-u8 commented Jul 11, 2024

Hi, I have the same issue.

[7/13/24 edit]
Thanks to @jamesvrt, it worked in my environment (pipenv virtual environment, Python 3.10)!

@rora00
Copy link

rora00 commented Jul 20, 2024

I had the same error message, and installing jax and jaxlib version 0.4.20 did not work for me. I have since fixed it, and I'll list below the steps I took in case anyone has the same issue. Firstly, I created a Python virtual environment using Anaconda with Python version 3.10.14, as that's the latest version we know works according to lightweight_mmm/setup.py. Secondly, I checked the lightweight_mmm/requirements/requirements.txt file to find the package versions listed there, which say that jax and jaxlib have to be version 0.3.18 or higher. Apparently, this version does not even exist, so I used 0.4.18 instead. The final error I was facing was with the version of numpyro, so I once again used the version listed in the requirements.txt file and installed version 0.9.2. The bit of code that does all this is:
%pip install jax==0.4.18 jaxlib==0.4.18 numpyro==0.9.2. Finally, I am using the latest version of lightweight_mmm 0.1.9. You can check the versions of your packages by running %pip show jax jaxlib numpyro lightweight_mmm.

@lsypro
Copy link

lsypro commented Aug 29, 2024

I encountered the same problem. My python version is 3.11.5. Finally, I followed the instructions of the two issues and installed the following versions:

seaborn==0.11.1
scipy==1.12.0
numpy==1.26.0
pyarrow==14.0.0
jax==0.4.18
jaxlib==0.4.18
numpyro==0.11.0
lightweight-mmm==0.1.9

This is useful for me!

@AdeuAndreu
Copy link

AdeuAndreu commented Oct 23, 2024

I am using Python 3.10. In my case I also had to update the numpyro library to make it work. The updated packages are listed below:

scipy==1.12.0
jax==0.4.19
jaxlib==0.4.19
numpyro==0.13.2
lightweight-mmm==0.1.9

@anudanda
Copy link

I encountered the same problem. My python version is 3.11.5. Finally, I followed the instructions of the two issues and installed the following versions:

seaborn==0.11.1
scipy==1.12.0
numpy==1.26.0
pyarrow==14.0.0
jax==0.4.18
jaxlib==0.4.18
numpyro==0.11.0
lightweight-mmm==0.1.9

This is useful for me!

Thank you ! This set up worked for me with python 3.11.7

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests