Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix progress bar error when nested CompoundStep samplers are assigned #7730

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

jessegrabowski
Copy link
Member

@jessegrabowski jessegrabowski commented Mar 20, 2025

#7721 reports an error in the presence of nested CompoundStep. Here's a prettier version of what pymc gives for the example in that issue:

CompoundStep
├─CompoundStep
│   ├─ Metropolis: [a]
│   ├─ Metropolis: [b]
│   └─Metropolis: [c]
└─NUTS: [d]

So there are 4 steps, but there's a compound step on the outside and on the inside. At each step, we get a flat list of 4 dictionaries holding statistics for each step. Currently, the logic for updating the progress bars makes the assumption that the list of step statistics returned at each step matches the list of step samplers. It was assumed that, if there is a compound step, there should only be one, so it can zip over the steps. Here is the display stat update for CompoundStep:

            for step_stat, update_fn in zip(step_stats, update_fns):
                displayed_stats = update_fn(displayed_stats, step_stat, chain_idx)

The problem is that if there's a nested structure, one of the udpate_fns will do this loop again. Since step_stats does not have the same nested structure, it ends up iterating over dictionary keys and raising the error.

Open to suggestions on how to proceed, because the solution isn't obvious. Opened this as a draft PR to address the problem ASAP, since its been reported 3 times now.

Description

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7730.org.readthedocs.build/en/7730/

Copy link

codecov bot commented Mar 20, 2025

Codecov Report

Attention: Patch coverage is 97.29730% with 1 line in your changes missing coverage. Please review.

Project coverage is 92.82%. Comparing base (af81955) to head (a8af2e8).
Report is 1 commits behind head on main.

Files with missing lines Patch % Lines
pymc/step_methods/compound.py 92.30% 1 Missing ⚠️
Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #7730      +/-   ##
==========================================
- Coverage   92.82%   92.82%   -0.01%     
==========================================
  Files         107      107              
  Lines       18324    18322       -2     
==========================================
- Hits        17010    17007       -3     
- Misses       1314     1315       +1     
Files with missing lines Coverage Δ
pymc/step_methods/hmc/nuts.py 97.68% <100.00%> (-0.02%) ⬇️
pymc/step_methods/metropolis.py 93.20% <100.00%> (-0.03%) ⬇️
pymc/step_methods/slicer.py 97.32% <100.00%> (ø)
pymc/step_methods/compound.py 97.47% <92.30%> (-0.42%) ⬇️
🚀 New features to boost your workflow:
  • Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@ricardoV94
Copy link
Member

Can we have a smoke test that tries a bunch of step samplers for like 10 tune, and which also acts as a regression test for the issue?

@github-actions github-actions bot added the bug label Mar 23, 2025
@jessegrabowski
Copy link
Member Author

I added a solution to the bug, but I think it might be over-engineered. Basically the problem is that once we're inside the sampling loop, we don't have information about which step generated which stats. I assumed the stats could just be paired up with step samplers positionally, but that assumption is broken when there are nested CompoundSteps -- the nesting structure requires that we know more about which stats came from where.

My solution was to add a step_id to steps, which is created at sample time by a step_id_generator function (just itertools.count() by default). The step_id is stored in a meta_info statistic that is not collected or stored as a real sampling statistic. CompoundSteps then choose which update function to apply based on the step_id.

Maybe I'm missing a simpler solution? Hoping this one at least starts a convo to get to somewhere better.

@ricardoV94
Copy link
Member

I don't quite get the original problem to know if the solution is reasonable. Sounds to me like the CompoundStep should aggregate the stats for display, so the outer CompoundStep would only see two 2 entries , from NUTS and the inner CompoundStep?

@jessegrabowski
Copy link
Member Author

The update to the displayed statistics is done here. The progress manager has access to the return from iter_sample e.g. here. This is the current MCMC point (not helpful) and a "stats" dictionary. The stats dictionary is always a flat list of dictionaries with length equal to the number of assigned samplers. At sampling time, the ProgressManager only ever sees a single step (the outer-most CompoundStep, or the joint BlockedStep over everything). The hierarchical relationships between steps (if any) is destroyed before sampling begins here

Currently, the logic for a CompoundStep is to loop over the steps it contains, and apply each stats update function. So the

If the list of steps is "flat", for example there is only one joint sampler (e.g. NUTS) or if every variable has its own independent step (e.g. Metropolis with blocking=False), this logic works, because the list step methods held by the outer-most CompoundStep will be aligned with the flat stats lists. See here for where this happens.

If, however, the CompoundStep itself holds another CompoundStep, this looping update will be triggered again, and it will try to iterate over stat_dict as if it were the flat list of dicts stats, and we get an error (because it starts iterating over the dictionary keys).

@ricardoV94
Copy link
Member

Where does this flattening of stats happen?

@jessegrabowski
Copy link
Member Author

I think it's here, due to the use of .extend

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
2 participants