Skip to content

Commit 1559a97

Browse files
authored
Handle multivariate responses with HSGP (#856)
* Make HSGP terms aware of multivariate responses
* Make sure two-dimensional outputs have two dims
* Remove redundant classes from checks
* Remove prints and add comments
* Remove commented code
1 parent 5d772ff commit 1559a97

File tree

2 files changed

+23
-16
lines changed

2 files changed

+23
-16
lines changed

Diff for: bambi/backend/model_components.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def build(self, pymc_backend, bmb_model):
5353
self.build_intercept(bmb_model)
5454
self.build_offsets()
5555
self.build_common_terms(pymc_backend, bmb_model)
56-
self.build_hsgp_terms(pymc_backend)
56+
self.build_hsgp_terms(bmb_model, pymc_backend)
5757
self.build_group_specific_terms(pymc_backend, bmb_model)
5858

5959
def build_intercept(self, bmb_model):
@@ -109,7 +109,7 @@ def build_common_terms(self, pymc_backend, bmb_model):
109109
# Add term to linear predictor
110110
self.output += pt.dot(data, coefs)
111111

112-
def build_hsgp_terms(self, pymc_backend):
112+
def build_hsgp_terms(self, bmb_model, pymc_backend):
113113
"""Add HSGP (Hilbert-Space Gaussian Process approximation) terms to the PyMC model.
114114
115115
The linear predictor 'X @ b + Z @ u' can be augmented with non-parametric HSGP terms
@@ -120,7 +120,7 @@ def build_hsgp_terms(self, pymc_backend):
120120
for name, values in hsgp_term.coords.items():
121121
if name not in pymc_backend.model.coords:
122122
pymc_backend.model.add_coords({name: values})
123-
self.output += hsgp_term.build()
123+
self.output += hsgp_term.build(bmb_model)
124124

125125
def build_group_specific_terms(self, pymc_backend, bmb_model):
126126
"""Add group-specific (random or varying) terms to the PyMC model

Diff for: bambi/backend/terms.py

+20-13
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
make_weighted_distribution,
1313
GP_KERNELS,
1414
)
15-
from bambi.families.multivariate import MultivariateFamily, Multinomial, DirichletMultinomial
15+
from bambi.families.multivariate import MultivariateFamily
1616
from bambi.families.univariate import Categorical, Cumulative, StoppingRatio
1717
from bambi.priors import Prior
1818

@@ -234,22 +234,16 @@ def build(self, pymc_backend, bmb_model):
234234
# Auxiliary parameters and data
235235
kwargs = {"observed": data, "dims": ("__obs__",)}
236236

237-
if isinstance(
238-
self.family,
239-
(
240-
MultivariateFamily,
241-
Categorical,
242-
Cumulative,
243-
StoppingRatio,
244-
Multinomial,
245-
DirichletMultinomial,
246-
),
247-
):
237+
if isinstance(self.family, (MultivariateFamily, Categorical, Cumulative, StoppingRatio)):
248238
response_term = bmb_model.response_component.term
249239
response_name = response_term.alias or response_term.name
250240
dim_name = response_name + "_dim"
251241
pymc_backend.model.add_coords({dim_name: response_term.levels})
252242
dims = ("__obs__", dim_name)
243+
244+
# For multivariate families, the outcome variable has two dimensions too.
245+
if isinstance(self.family, MultivariateFamily):
246+
kwargs["dims"] = dims
253247
else:
254248
dims = ("__obs__",)
255249

@@ -447,7 +441,7 @@ def __init__(self, term):
447441
if self.term.by_levels is not None:
448442
self.coords[f"{self.term.alias}_by"] = self.coords.pop(f"{self.term.name}_by")
449443

450-
def build(self):
444+
def build(self, spec):
451445
# Get the name of the term
452446
label = self.name
453447

@@ -507,6 +501,19 @@ def build(self):
507501
phi = phi.eval()
508502

509503
# Build weights coefficient
504+
# Handle the case where the outcome is multivariate
505+
if isinstance(spec.family, (MultivariateFamily, Categorical)):
506+
# Append the dims of the response variables to the coefficient and contribution dims
507+
# In general:
508+
# coeff_dims: ('weights_dim', ) -> ('weights_dim', f'{response}_dim')
509+
# contribution_dims: ('__obs__', ) -> ('__obs__', f'{response}_dim')
510+
response_dims = tuple(spec.response_component.term.coords)
511+
coeff_dims = coeff_dims + response_dims
512+
contribution_dims = contribution_dims + response_dims
513+
514+
# Append a dimension to sqrt_psd: ('weights_dim', ) -> ('weights_dim', 1)
515+
sqrt_psd = sqrt_psd[:, np.newaxis]
516+
510517
if self.term.centered:
511518
coeffs = pm.Normal(f"{label}_weights", sigma=sqrt_psd, dims=coeff_dims)
512519
else:

0 commit comments

Comments
 (0)