Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Fix progress bar error when nested CompoundStep samplers are assigned #7730

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pymc/backends/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,10 @@ def record(self, point, sampler_stats=None) -> None:
if sampler_stats is not None:
for data, vars in zip(self._stats, sampler_stats):
for key, val in vars.items():
# `step_meta` is a key used by the progress bars to track which draw came from which step instance.
# It is not a sampler statistic and should never be stored as one.
if key == "step_meta":
continue
data[key][draw_idx] = val
elif self._stats is not None:
raise ValueError("Expected sampler_stats")
Expand Down
11 changes: 10 additions & 1 deletion pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Functions for MCMC sampling."""

import contextlib
import itertools
import logging
import pickle
import sys
Expand Down Expand Up @@ -111,6 +112,7 @@ def instantiate_steppers(
step_kwargs: dict[str, dict] | None = None,
initial_point: PointType | None = None,
compile_kwargs: dict | None = None,
step_id_generator: Iterator[int] | None = None,
) -> Step | list[Step]:
"""Instantiate steppers assigned to the model variables.

Expand Down Expand Up @@ -139,6 +141,9 @@ def instantiate_steppers(
if step_kwargs is None:
step_kwargs = {}

if step_id_generator is None:
step_id_generator = itertools.count()

used_keys = set()
if selected_steps:
if initial_point is None:
Expand All @@ -154,6 +159,7 @@ def instantiate_steppers(
model=model,
initial_point=initial_point,
compile_kwargs=compile_kwargs,
step_id_generator=step_id_generator,
**kwargs,
)
steps.append(step)
Expand Down Expand Up @@ -853,16 +859,19 @@ def joined_blas_limiter():
initial_points = [ipfn(seed) for ipfn, seed in zip(ipfns, random_seed_list)]

# Instantiate automatically selected steps
# Use a counter to generate a unique id for each stepper used in the model.
step_id_generator = itertools.count()
step = instantiate_steppers(
model,
steps=provided_steps,
selected_steps=selected_steps,
step_kwargs=kwargs,
initial_point=initial_points[0],
compile_kwargs=compile_kwargs,
step_id_generator=step_id_generator,
)
if isinstance(step, list):
step = CompoundStep(step)
step = CompoundStep(step, step_id_generator=step_id_generator)

if var_names is not None:
trace_vars = [v for v in model.unobserved_RVs if v.name in var_names]
Expand Down
37 changes: 32 additions & 5 deletions pymc/step_methods/arraystep.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from abc import abstractmethod
from collections.abc import Callable
from collections.abc import Callable, Iterator
from typing import cast

import numpy as np
Expand Down Expand Up @@ -43,14 +43,25 @@ class ArrayStep(BlockedStep):
:py:func:`pymc.util.get_random_generator` for more information.
"""

def __init__(self, vars, fs, allvars=False, blocked=True, rng: RandomGenerator = None):
def __init__(
self,
vars,
fs,
allvars=False,
blocked=True,
rng: RandomGenerator = None,
step_id_generator: Iterator[int] | None = None,
):
self.vars = vars
self.fs = fs
self.allvars = allvars
self.blocked = blocked
self.rng = get_random_generator(rng)
self._step_id = next(step_id_generator) if step_id_generator else None

def step(self, point: PointType) -> tuple[PointType, StatsType]:
def step(
self, point: PointType, step_parent_id: int | None = None
) -> tuple[PointType, StatsType]:
partial_funcs_and_point: list[Callable | PointType] = [
DictToArrayBijection.mapf(x, start_point=point) for x in self.fs
]
Expand All @@ -61,6 +72,9 @@ def step(self, point: PointType) -> tuple[PointType, StatsType]:
apoint = DictToArrayBijection.map(var_dict)
apoint_new, stats = self.astep(apoint, *partial_funcs_and_point)

for sts in stats:
sts["step_meta"] = {"step_id": self._step_id, "step_parent_id": step_parent_id}

if not isinstance(apoint_new, RaveledVars):
# We assume that the mapping has stayed the same
apoint_new = RaveledVars(apoint_new, apoint.point_map_info)
Expand All @@ -84,7 +98,14 @@ class ArrayStepShared(BlockedStep):
and unmapping overhead as well as moving fewer variables around.
"""

def __init__(self, vars, shared, blocked=True, rng: RandomGenerator = None):
def __init__(
self,
vars,
shared,
blocked=True,
rng: RandomGenerator = None,
step_id_generator: Iterator[int] | None = None,
):
"""
Create the ArrayStepShared object.

Expand All @@ -103,8 +124,11 @@ def __init__(self, vars, shared, blocked=True, rng: RandomGenerator = None):
self.shared = {get_var_name(var): shared for var, shared in shared.items()}
self.blocked = blocked
self.rng = get_random_generator(rng)
self._step_id = next(step_id_generator) if step_id_generator else None

def step(self, point: PointType) -> tuple[PointType, StatsType]:
def step(
self, point: PointType, step_parent_id: int | None = None
) -> tuple[PointType, StatsType]:
full_point = None
if self.shared:
for name, shared_var in self.shared.items():
Expand All @@ -115,6 +139,9 @@ def step(self, point: PointType) -> tuple[PointType, StatsType]:
q = DictToArrayBijection.map(point)
apoint, stats = self.astep(q)

for sts in stats:
sts["step_meta"] = {"step_id": self._step_id, "step_parent_id": step_parent_id}

if not isinstance(apoint, RaveledVars):
# We assume that the mapping has stayed the same
apoint = RaveledVars(apoint, q.point_map_info)
Expand Down
Loading
Loading