Skip to content

Commit 20cc071

Browse files
authored
Merge pull request #577 from pymc-labs/cetagostini/bsts_tech_debt_v0
Refactor PyMC time series models to use xarray API
2 parents e2ae2e8 + a15e17f commit 20cc071

File tree

9 files changed

+1335
-1373
lines changed

9 files changed

+1335
-1373
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14,3 +14,5 @@ dist/
1414
docs/build/
1515
docs/jupyter_execute/
1616
docs/source/api/generated/
17+
18+
.cursor/

causalpy/experiments/interrupted_time_series.py

Lines changed: 15 additions & 88 deletions
Original file line number | Diff line number | Diff line change
@@ -27,11 +27,7 @@
2727

2828
from causalpy.custom_exceptions import BadIndexException
2929
from causalpy.plot_utils import get_hdi_to_df, plot_xY
30-
from causalpy.pymc_models import (
31-
BayesianBasisExpansionTimeSeries,
32-
PyMCModel,
33-
StateSpaceTimeSeries,
34-
)
30+
from causalpy.pymc_models import PyMCModel
3531
from causalpy.utils import round_num
3632

3733
from .base import BaseExperiment
@@ -202,27 +198,15 @@ def __init__(
202198
)
203199

204200
# fit the model to the observed (pre-intervention) data
201+
# All PyMC models now accept xr.DataArray with consistent API
205202
if isinstance(self.model, PyMCModel):
206-
is_bsts_like = isinstance(
207-
self.model, (BayesianBasisExpansionTimeSeries, StateSpaceTimeSeries)
208-
)
209-
210-
if is_bsts_like:
211-
# BSTS/StateSpace models expect numpy arrays and datetime coords
212-
X_fit = self.pre_X.values if self.pre_X.shape[1] > 0 else None # type: ignore[attr-defined]
213-
y_fit = self.pre_y.isel(treated_units=0).values # type: ignore[attr-defined]
214-
pre_coords: dict[str, Any] = {"datetime_index": self.datapre.index}
215-
if X_fit is not None:
216-
pre_coords["coeffs"] = list(self.labels)
217-
self.model.fit(X=X_fit, y=y_fit, coords=pre_coords)
218-
else:
219-
# General PyMC models expect xarray with treated_units
220-
COORDS = {
221-
"coeffs": self.labels,
222-
"obs_ind": np.arange(self.pre_X.shape[0]),
223-
"treated_units": ["unit_0"],
224-
}
225-
self.model.fit(X=self.pre_X, y=self.pre_y, coords=COORDS)
203+
COORDS: dict[str, Any] = {
204+
"coeffs": self.labels,
205+
"obs_ind": np.arange(self.pre_X.shape[0]),
206+
"treated_units": ["unit_0"],
207+
"datetime_index": self.datapre.index, # For time series models
208+
}
209+
self.model.fit(X=self.pre_X, y=self.pre_y, coords=COORDS)
226210
elif isinstance(self.model, RegressorMixin):
227211
# For OLS models, use 1D y data
228212
self.model.fit(X=self.pre_X, y=self.pre_y.isel(treated_units=0))
@@ -231,85 +215,28 @@ def __init__(
231215

232216
# score the goodness of fit to the pre-intervention data
233217
if isinstance(self.model, PyMCModel):
234-
is_bsts_like = isinstance(
235-
self.model, (BayesianBasisExpansionTimeSeries, StateSpaceTimeSeries)
236-
)
237-
if is_bsts_like:
238-
X_score = self.pre_X.values if self.pre_X.shape[1] > 0 else None # type: ignore[attr-defined]
239-
y_score = self.pre_y.isel(treated_units=0).values # type: ignore[attr-defined]
240-
score_coords: dict[str, Any] = {"datetime_index": self.datapre.index}
241-
if X_score is not None:
242-
score_coords["coeffs"] = list(self.labels)
243-
self.score = self.model.score(X=X_score, y=y_score, coords=score_coords)
244-
else:
245-
self.score = self.model.score(X=self.pre_X, y=self.pre_y)
218+
self.score = self.model.score(X=self.pre_X, y=self.pre_y)
246219
elif isinstance(self.model, RegressorMixin):
247220
self.score = self.model.score(
248221
X=self.pre_X, y=self.pre_y.isel(treated_units=0)
249222
)
250223

251224
# get the model predictions of the observed (pre-intervention) data
252225
if isinstance(self.model, PyMCModel):
253-
is_bsts_like = isinstance(
254-
self.model, (BayesianBasisExpansionTimeSeries, StateSpaceTimeSeries)
255-
)
256-
if is_bsts_like:
257-
X_pre_predict = self.pre_X.values if self.pre_X.shape[1] > 0 else None # type: ignore[attr-defined]
258-
pre_pred_coords: dict[str, Any] = {"datetime_index": self.datapre.index}
259-
self.pre_pred = self.model.predict(
260-
X=X_pre_predict, coords=pre_pred_coords
261-
)
262-
if not isinstance(self.pre_pred, az.InferenceData):
263-
self.pre_pred = az.InferenceData(posterior_predictive=self.pre_pred)
264-
else:
265-
self.pre_pred = self.model.predict(X=self.pre_X)
226+
self.pre_pred = self.model.predict(X=self.pre_X)
266227
elif isinstance(self.model, RegressorMixin):
267228
self.pre_pred = self.model.predict(X=self.pre_X)
268229

269230
# calculate the counterfactual (post period)
270231
if isinstance(self.model, PyMCModel):
271-
is_bsts_like = isinstance(
272-
self.model, (BayesianBasisExpansionTimeSeries, StateSpaceTimeSeries)
273-
)
274-
if is_bsts_like:
275-
X_post_predict = (
276-
self.post_X.values if self.post_X.shape[1] > 0 else None # type: ignore[attr-defined]
277-
)
278-
post_pred_coords: dict[str, Any] = {
279-
"datetime_index": self.datapost.index
280-
}
281-
self.post_pred = self.model.predict(
282-
X=X_post_predict, coords=post_pred_coords, out_of_sample=True
283-
)
284-
if not isinstance(self.post_pred, az.InferenceData):
285-
self.post_pred = az.InferenceData(
286-
posterior_predictive=self.post_pred
287-
)
288-
else:
289-
self.post_pred = self.model.predict(X=self.post_X)
232+
self.post_pred = self.model.predict(X=self.post_X, out_of_sample=True)
290233
elif isinstance(self.model, RegressorMixin):
291234
self.post_pred = self.model.predict(X=self.post_X)
292235

293-
# calculate impact - use appropriate y data format for each model type
236+
# calculate impact - all PyMC models now use 2D data with treated_units
294237
if isinstance(self.model, PyMCModel):
295-
is_bsts_like = isinstance(
296-
self.model, (BayesianBasisExpansionTimeSeries, StateSpaceTimeSeries)
297-
)
298-
if is_bsts_like:
299-
pre_y_for_impact = self.pre_y.isel(treated_units=0)
300-
post_y_for_impact = self.post_y.isel(treated_units=0)
301-
self.pre_impact = self.model.calculate_impact(
302-
pre_y_for_impact, self.pre_pred
303-
)
304-
self.post_impact = self.model.calculate_impact(
305-
post_y_for_impact, self.post_pred
306-
)
307-
else:
308-
# PyMC models with treated_units use 2D data
309-
self.pre_impact = self.model.calculate_impact(self.pre_y, self.pre_pred)
310-
self.post_impact = self.model.calculate_impact(
311-
self.post_y, self.post_pred
312-
)
238+
self.pre_impact = self.model.calculate_impact(self.pre_y, self.pre_pred)
239+
self.post_impact = self.model.calculate_impact(self.post_y, self.post_pred)
313240
elif isinstance(self.model, RegressorMixin):
314241
# SKL models work with 1D data
315242
self.pre_impact = self.model.calculate_impact(

0 commit comments

Comments
 (0)