diff --git a/pymc/model/transform/basic.py b/pymc/model/transform/basic.py index fcf42fdf8c..877814cd61 100644 --- a/pymc/model/transform/basic.py +++ b/pymc/model/transform/basic.py @@ -13,9 +13,11 @@ # limitations under the License. from collections.abc import Sequence -from pytensor import Variable +from pytensor import Variable, clone_replace from pytensor.graph import ancestors +from pytensor.graph.fg import FunctionGraph +from pymc.data import MinibatchOp from pymc.model.core import Model from pymc.model.fgraph import ( ModelObservedRV, @@ -58,3 +60,25 @@ def parse_vars(model: Model, vars: ModelVariable | Sequence[ModelVariable]) -> l else: vars_seq = (vars,) return [model[var] if isinstance(var, str) else var for var in vars_seq] + + +def remove_minibatched_nodes(model: Model) -> Model: + """Remove all uses of pm.Minibatch in the Model.""" + fgraph, _ = fgraph_from_model(model) + + replacements = {} + for var in fgraph.apply_nodes: + if isinstance(var.op, MinibatchOp): + for inp, out in zip(var.inputs, var.outputs): + replacements[out] = inp + + old_outs, old_coords, old_dim_lengths = fgraph.outputs, fgraph._coords, fgraph._dim_lengths # type: ignore[attr-defined] + # Using `rebuild_strict=False` means all coords, names, and dim information is lost + # So we need to restore it from the old fgraph + new_outs = clone_replace(old_outs, replacements, rebuild_strict=False) # type: ignore[arg-type] + for old_out, new_out in zip(old_outs, new_outs): + new_out.name = old_out.name + fgraph = FunctionGraph(outputs=new_outs, clone=False) + fgraph._coords = old_coords # type: ignore[attr-defined] + fgraph._dim_lengths = old_dim_lengths # type: ignore[attr-defined] + return model_from_fgraph(fgraph, mutate_fgraph=True) diff --git a/tests/model/transform/test_basic.py b/tests/model/transform/test_basic.py index 25bf2324ec..856fbf0b2b 100644 --- a/tests/model/transform/test_basic.py +++ b/tests/model/transform/test_basic.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np + import pymc as pm -from pymc.model.transform.basic import prune_vars_detached_from_observed +from pymc.model.transform.basic import prune_vars_detached_from_observed, remove_minibatched_nodes def test_prune_vars_detached_from_observed(): @@ -30,3 +32,20 @@ def test_prune_vars_detached_from_observed(): assert set(m.named_vars.keys()) == {"obs_data", "a0", "a1", "a2", "obs", "d0", "d1"} pruned_m = prune_vars_detached_from_observed(m) assert set(pruned_m.named_vars.keys()) == {"obs_data", "a0", "a1", "a2", "obs"} + + +def test_remove_minibatches(): + data_size = 100 + data = np.zeros((data_size,)) + batch_size = 10 + with pm.Model(coords={"d": range(5)}) as m1: + mb = pm.Minibatch(data, batch_size=batch_size) + mu = pm.Normal("mu", dims="d") + x = pm.Normal("x") + y = pm.Normal("y", x, observed=mb, total_size=100) + + m2 = remove_minibatched_nodes(m1) + assert m1.y.shape[0].eval() == batch_size + assert m2.y.shape[0].eval() == data_size + assert m1.coords == m2.coords + assert m1.dim_lengths["d"].eval() == m2.dim_lengths["d"].eval()