Skip to content

Add support for multiple BART random variables per model. #231

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

derekpowell
Copy link

This PR adds new functionality to support multiple BART RVs defined within a single pymc model, addressing and resolve the limitation discussed in #86.

  • added ability to support multiple BART RVs in a single model (each gets its own pgbart sampler with combined sampling handled by pymc's compound sampling functionality)
  • tests of new functionality

Example use

This addition allows for some new modeling possibilities. One example is the Bayesian Causal Forests model of Hahn et al.. This model aims to capture heterogenous treatment effects of a binary treatment $Z$ by separately capturing the prognostic predictiveness of those covariates $X$ and their interactions with $z$ (note that different $X$ may also be modeled by $\mu$ and $\tau$).

$$f(x_i, z_i) = \mu(x_i) + \tau(x_i) \cdot z_i$$

This can be implemented as:

import pymc as pm
import pymc_bart as pmb

with pm.Model() as bcf_model:
    # Prognostic function (baseline outcome) - BART1
    mu = pmb.BART("mu", X=X_prog, Y=y, m=50) 
    
    # Treatment effect function - BART2
    tau = pmb.BART("tau", X=X_trt, Y=y, m=50)
    
    # Combine the two BART functions
    # Mean function: mu + Z * tau
    mean_y = mu + Z * tau
    
    # Observation noise
    sigma = pm.HalfNormal("sigma", sigma=1.0)
    
    # Likelihood
    y_obs = pm.Normal("y_obs", mu=mean_y, sigma=sigma, observed=y)

Hope this looks good. I also plan to offer a more detailed demo example for the pymc-examples doc repo if desired!

@aloctavodia
Copy link
Member

Thanks, this is great! Adding one example will be fantastic!

@aloctavodia
Copy link
Member

Before approving, I tested locally with a couple of examples, and they ran. But the test is failing. Funny that the order of the variables in the model affects the result. This works

        sigma = pm.HalfNormal("sigma", 1)
        mu1 = pmb.BART("mu1", X1, Y1, m=5)
        mu2 = pmb.BART("mu2", X2, Y2, m=5)

but this fails

        mu1 = pmb.BART("mu1", X1, Y1, m=5)
        mu2 = pmb.BART("mu2", X2, Y2, m=5)
        sigma = pm.HalfNormal("sigma", 1)

@@ -346,7 +364,7 @@ def resample(
new_particles.append(particles[idx].copy())
else:
new_particles.append(particles[idx])
seen.append(idx)
seen.append(int(idx))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this necessary?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope! Sorry about that, unnecessary since it's already int

@derekpowell
Copy link
Author

Before approving, I tested locally with a couple of examples, and they ran. But the test is failing. Funny that the order of the variables in the model affects the result. This works

        sigma = pm.HalfNormal("sigma", 1)
        mu1 = pmb.BART("mu1", X1, Y1, m=5)
        mu2 = pmb.BART("mu2", X2, Y2, m=5)

but this fails

        mu1 = pmb.BART("mu1", X1, Y1, m=5)
        mu2 = pmb.BART("mu2", X2, Y2, m=5)
        sigma = pm.HalfNormal("sigma", 1)

Hmm that's odd---I can't reproduce this behavior. I tried swapping the order of NUTS and pgbart variables in my tests and they still passed, and in my example notebook I'm working on and it ran. Or are you saying that it runs but the results are just different? To some extent I'd expect that with the sequential nature of the sampling steps but ideally they wouldn't be too different. If you can share reprex I can look into it?

@aloctavodia
Copy link
Member

Locally I am getting the same error than here, it does not run unless the bart RVs are defined last. I am away from my laptop most likely until Monday. Are you using the last pymc version?

@derekpowell
Copy link
Author

Oh, looks like I am on pymc = 5.19.1 (I used environment-dev.yml)

@derekpowell
Copy link
Author

derekpowell commented Jun 13, 2025

Alright I am able to reproduce this issue with pymc 5.23. Things work if you manually specify the compound step as in tests/test_bart.py::test_multiple_bart_variables_manual_step. Looks like the breaking change / regression (if that's what it is) comes in pymc v5.20.1.

For now, I suppose I could:

  1. Remove the failing test
  2. Add a version check in PGBART to warn that compound steps must be manually specified for versions >=5.20.1 (I can also figure out what most up-to-date release that works is). This is a bit suboptimal, not sure if this is too hacky.

I'm afraid that diagnosing what has changed on the pymc side beyond that is likely to be beyond me.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants