2727
2828from causalpy .custom_exceptions import BadIndexException
2929from causalpy .plot_utils import get_hdi_to_df , plot_xY
30- from causalpy .pymc_models import (
31- BayesianBasisExpansionTimeSeries ,
32- PyMCModel ,
33- StateSpaceTimeSeries ,
34- )
30+ from causalpy .pymc_models import PyMCModel
3531from causalpy .utils import round_num
3632
3733from .base import BaseExperiment
@@ -202,27 +198,15 @@ def __init__(
202198 )
203199
204200 # fit the model to the observed (pre-intervention) data
201+ # All PyMC models now accept xr.DataArray with consistent API
205202 if isinstance (self .model , PyMCModel ):
206- is_bsts_like = isinstance (
207- self .model , (BayesianBasisExpansionTimeSeries , StateSpaceTimeSeries )
208- )
209-
210- if is_bsts_like :
211- # BSTS/StateSpace models expect numpy arrays and datetime coords
212- X_fit = self .pre_X .values if self .pre_X .shape [1 ] > 0 else None # type: ignore[attr-defined]
213- y_fit = self .pre_y .isel (treated_units = 0 ).values # type: ignore[attr-defined]
214- pre_coords : dict [str , Any ] = {"datetime_index" : self .datapre .index }
215- if X_fit is not None :
216- pre_coords ["coeffs" ] = list (self .labels )
217- self .model .fit (X = X_fit , y = y_fit , coords = pre_coords )
218- else :
219- # General PyMC models expect xarray with treated_units
220- COORDS = {
221- "coeffs" : self .labels ,
222- "obs_ind" : np .arange (self .pre_X .shape [0 ]),
223- "treated_units" : ["unit_0" ],
224- }
225- self .model .fit (X = self .pre_X , y = self .pre_y , coords = COORDS )
203+ COORDS : dict [str , Any ] = {
204+ "coeffs" : self .labels ,
205+ "obs_ind" : np .arange (self .pre_X .shape [0 ]),
206+ "treated_units" : ["unit_0" ],
207+ "datetime_index" : self .datapre .index , # For time series models
208+ }
209+ self .model .fit (X = self .pre_X , y = self .pre_y , coords = COORDS )
226210 elif isinstance (self .model , RegressorMixin ):
227211 # For OLS models, use 1D y data
228212 self .model .fit (X = self .pre_X , y = self .pre_y .isel (treated_units = 0 ))
@@ -231,85 +215,28 @@ def __init__(
231215
232216 # score the goodness of fit to the pre-intervention data
233217 if isinstance (self .model , PyMCModel ):
234- is_bsts_like = isinstance (
235- self .model , (BayesianBasisExpansionTimeSeries , StateSpaceTimeSeries )
236- )
237- if is_bsts_like :
238- X_score = self .pre_X .values if self .pre_X .shape [1 ] > 0 else None # type: ignore[attr-defined]
239- y_score = self .pre_y .isel (treated_units = 0 ).values # type: ignore[attr-defined]
240- score_coords : dict [str , Any ] = {"datetime_index" : self .datapre .index }
241- if X_score is not None :
242- score_coords ["coeffs" ] = list (self .labels )
243- self .score = self .model .score (X = X_score , y = y_score , coords = score_coords )
244- else :
245- self .score = self .model .score (X = self .pre_X , y = self .pre_y )
218+ self .score = self .model .score (X = self .pre_X , y = self .pre_y )
246219 elif isinstance (self .model , RegressorMixin ):
247220 self .score = self .model .score (
248221 X = self .pre_X , y = self .pre_y .isel (treated_units = 0 )
249222 )
250223
251224 # get the model predictions of the observed (pre-intervention) data
252225 if isinstance (self .model , PyMCModel ):
253- is_bsts_like = isinstance (
254- self .model , (BayesianBasisExpansionTimeSeries , StateSpaceTimeSeries )
255- )
256- if is_bsts_like :
257- X_pre_predict = self .pre_X .values if self .pre_X .shape [1 ] > 0 else None # type: ignore[attr-defined]
258- pre_pred_coords : dict [str , Any ] = {"datetime_index" : self .datapre .index }
259- self .pre_pred = self .model .predict (
260- X = X_pre_predict , coords = pre_pred_coords
261- )
262- if not isinstance (self .pre_pred , az .InferenceData ):
263- self .pre_pred = az .InferenceData (posterior_predictive = self .pre_pred )
264- else :
265- self .pre_pred = self .model .predict (X = self .pre_X )
226+ self .pre_pred = self .model .predict (X = self .pre_X )
266227 elif isinstance (self .model , RegressorMixin ):
267228 self .pre_pred = self .model .predict (X = self .pre_X )
268229
269230 # calculate the counterfactual (post period)
270231 if isinstance (self .model , PyMCModel ):
271- is_bsts_like = isinstance (
272- self .model , (BayesianBasisExpansionTimeSeries , StateSpaceTimeSeries )
273- )
274- if is_bsts_like :
275- X_post_predict = (
276- self .post_X .values if self .post_X .shape [1 ] > 0 else None # type: ignore[attr-defined]
277- )
278- post_pred_coords : dict [str , Any ] = {
279- "datetime_index" : self .datapost .index
280- }
281- self .post_pred = self .model .predict (
282- X = X_post_predict , coords = post_pred_coords , out_of_sample = True
283- )
284- if not isinstance (self .post_pred , az .InferenceData ):
285- self .post_pred = az .InferenceData (
286- posterior_predictive = self .post_pred
287- )
288- else :
289- self .post_pred = self .model .predict (X = self .post_X )
232+ self .post_pred = self .model .predict (X = self .post_X , out_of_sample = True )
290233 elif isinstance (self .model , RegressorMixin ):
291234 self .post_pred = self .model .predict (X = self .post_X )
292235
293- # calculate impact - use appropriate y data format for each model type
236+ # calculate impact - all PyMC models now use 2D data with treated_units
294237 if isinstance (self .model , PyMCModel ):
295- is_bsts_like = isinstance (
296- self .model , (BayesianBasisExpansionTimeSeries , StateSpaceTimeSeries )
297- )
298- if is_bsts_like :
299- pre_y_for_impact = self .pre_y .isel (treated_units = 0 )
300- post_y_for_impact = self .post_y .isel (treated_units = 0 )
301- self .pre_impact = self .model .calculate_impact (
302- pre_y_for_impact , self .pre_pred
303- )
304- self .post_impact = self .model .calculate_impact (
305- post_y_for_impact , self .post_pred
306- )
307- else :
308- # PyMC models with treated_units use 2D data
309- self .pre_impact = self .model .calculate_impact (self .pre_y , self .pre_pred )
310- self .post_impact = self .model .calculate_impact (
311- self .post_y , self .post_pred
312- )
238+ self .pre_impact = self .model .calculate_impact (self .pre_y , self .pre_pred )
239+ self .post_impact = self .model .calculate_impact (self .post_y , self .post_pred )
313240 elif isinstance (self .model , RegressorMixin ):
314241 # SKL models work with 1D data
315242 self .pre_impact = self .model .calculate_impact (
0 commit comments