|
12 | 12 | make_weighted_distribution,
|
13 | 13 | GP_KERNELS,
|
14 | 14 | )
|
15 |
| -from bambi.families.multivariate import MultivariateFamily, Multinomial, DirichletMultinomial |
| 15 | +from bambi.families.multivariate import MultivariateFamily |
16 | 16 | from bambi.families.univariate import Categorical, Cumulative, StoppingRatio
|
17 | 17 | from bambi.priors import Prior
|
18 | 18 |
|
@@ -234,22 +234,16 @@ def build(self, pymc_backend, bmb_model):
|
234 | 234 | # Auxiliary parameters and data
|
235 | 235 | kwargs = {"observed": data, "dims": ("__obs__",)}
|
236 | 236 |
|
237 |
| - if isinstance( |
238 |
| - self.family, |
239 |
| - ( |
240 |
| - MultivariateFamily, |
241 |
| - Categorical, |
242 |
| - Cumulative, |
243 |
| - StoppingRatio, |
244 |
| - Multinomial, |
245 |
| - DirichletMultinomial, |
246 |
| - ), |
247 |
| - ): |
| 237 | + if isinstance(self.family, (MultivariateFamily, Categorical, Cumulative, StoppingRatio)): |
248 | 238 | response_term = bmb_model.response_component.term
|
249 | 239 | response_name = response_term.alias or response_term.name
|
250 | 240 | dim_name = response_name + "_dim"
|
251 | 241 | pymc_backend.model.add_coords({dim_name: response_term.levels})
|
252 | 242 | dims = ("__obs__", dim_name)
|
| 243 | + |
| 244 | + # For multivariate families, the outcome variable has two dimensions too. |
| 245 | + if isinstance(self.family, MultivariateFamily): |
| 246 | + kwargs["dims"] = dims |
253 | 247 | else:
|
254 | 248 | dims = ("__obs__",)
|
255 | 249 |
|
@@ -447,7 +441,7 @@ def __init__(self, term):
|
447 | 441 | if self.term.by_levels is not None:
|
448 | 442 | self.coords[f"{self.term.alias}_by"] = self.coords.pop(f"{self.term.name}_by")
|
449 | 443 |
|
450 |
| - def build(self): |
| 444 | + def build(self, spec): |
451 | 445 | # Get the name of the term
|
452 | 446 | label = self.name
|
453 | 447 |
|
@@ -507,6 +501,19 @@ def build(self):
|
507 | 501 | phi = phi.eval()
|
508 | 502 |
|
509 | 503 | # Build weights coefficient
|
| 504 | + # Handle the case where the outcome is multivariate |
| 505 | + if isinstance(spec.family, (MultivariateFamily, Categorical)): |
| 506 | + # Append the dims of the response variables to the coefficient and contribution dims |
| 507 | + # In general: |
| 508 | + # coeff_dims: ('weights_dim', ) -> ('weights_dim', f'{response}_dim') |
| 509 | + # contribution_dims: ('__obs__', ) -> ('__obs__', f'{response}_dim') |
| 510 | + response_dims = tuple(spec.response_component.term.coords) |
| 511 | + coeff_dims = coeff_dims + response_dims |
| 512 | + contribution_dims = contribution_dims + response_dims |
| 513 | + |
| 514 | + # Append a dimension to sqrt_psd: ('weights_dim', ) -> ('weights_dim', 1) |
| 515 | + sqrt_psd = sqrt_psd[:, np.newaxis] |
| 516 | + |
510 | 517 | if self.term.centered:
|
511 | 518 | coeffs = pm.Normal(f"{label}_weights", sigma=sqrt_psd, dims=coeff_dims)
|
512 | 519 | else:
|
|
0 commit comments