# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
| 6 | +"""An inference engine which uses Bean Machine to make |
| 7 | +inferences on optimized Bean Machine models.""" |

from typing import Dict, List, Set

import graphviz
import torch

from beanmachine.ppl.compiler.fix_problems import default_skip_optimizations
from beanmachine.ppl.compiler.gen_bm_python import InferenceType, to_bm_python
from beanmachine.ppl.compiler.gen_bmg_graph import to_bmg_graph
from beanmachine.ppl.compiler.gen_dot import to_dot
from beanmachine.ppl.compiler.gen_mini import to_mini
from beanmachine.ppl.compiler.runtime import BMGRuntime
from beanmachine.ppl.inference.monte_carlo_samples import MonteCarloSamples
from beanmachine.ppl.inference.utils import _verify_queries_and_observations
from beanmachine.ppl.model.rv_identifier import RVIdentifier


class BMInference:
    """
    Interface to Bean Machine inference on optimized models.

    Please note that this is a highly experimental implementation under
    active development, and that it supports only a limited subset of Bean
    Machine models: the runtime graph must be static (that is, it must not
    change during inference), and only a limited set of primitive
    distributions is currently supported.
    """
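
    # Usage sketch (a hypothetical model, not part of this module; assumes
    # Bean Machine's @bm.random_variable decorator and torch distributions):
    #
    #     import beanmachine.ppl as bm
    #     import torch
    #     import torch.distributions as dist
    #
    #     @bm.random_variable
    #     def coin():
    #         return dist.Beta(2.0, 2.0)
    #
    #     @bm.random_variable
    #     def flip(i):
    #         return dist.Bernoulli(coin())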
| 35 | + |
| 36 | + _fix_observe_true: bool = False |
| 37 | + _infer_config = {} |
| 38 | + |
| 39 | + def __init__(self): |
| 40 | + pass |
| 41 | + |

    def _accumulate_graph(
        self,
        queries: List[RVIdentifier],
        observations: Dict[RVIdentifier, torch.Tensor],
    ) -> BMGRuntime:
        _verify_queries_and_observations(queries, observations, True)
        rt = BMGRuntime()
        bmg = rt.accumulate_graph(queries, observations)
        # TODO: Figure out a better way to pass this flag around
        bmg._fix_observe_true = self._fix_observe_true
        return rt

    def _build_mcsamples(
        self,
        queries,
        opt_rv_to_query_map,
        samples,
    ) -> MonteCarloSamples:
        assert len(samples) == len(queries)

        # Map each sampled variable of the optimized model back to the
        # original query it answers.
        results: Dict[RVIdentifier, torch.Tensor] = {}
        for rv in samples.keys():
            query = opt_rv_to_query_map[str(rv)]
            results[query] = samples[rv]
        return MonteCarloSamples(results)

    def _infer(
        self,
        queries: List[RVIdentifier],
        observations: Dict[RVIdentifier, torch.Tensor],
        num_samples: int,
        num_chains: int = 1,
        num_adaptive_samples: int = 0,
        inference_type: InferenceType = InferenceType.GlobalNoUTurnSampler,
        skip_optimizations: Set[str] = default_skip_optimizations,
    ) -> MonteCarloSamples:
        rt = self._accumulate_graph(queries, observations)
        bmg = rt._bmg

        self._infer_config["num_samples"] = num_samples
        self._infer_config["num_chains"] = num_chains
        self._infer_config["num_adaptive_samples"] = num_adaptive_samples

        generated_graph = to_bmg_graph(bmg, skip_optimizations)
        optimized_python, opt_rv_to_query_map = to_bm_python(
            generated_graph.bmg, inference_type, self._infer_config
        )

        # Executing the generated program binds its results to a global
        # named `samples`, which is consumed below.
        try:
            exec(optimized_python, globals())  # noqa
        except RuntimeError as e:
            raise RuntimeError("Error during BM inference") from e

        opt_samples = self._build_mcsamples(
            queries,
            opt_rv_to_query_map,
            # pyre-ignore
            samples,  # noqa
        )
        return opt_samples

    def infer(
        self,
        queries: List[RVIdentifier],
        observations: Dict[RVIdentifier, torch.Tensor],
        num_samples: int,
        num_chains: int = 4,
        num_adaptive_samples: int = 0,
        inference_type: InferenceType = InferenceType.GlobalNoUTurnSampler,
        skip_optimizations: Set[str] = default_skip_optimizations,
    ) -> MonteCarloSamples:
        """
        Perform inference by compiling, at runtime, the Python source code
        associated with the queries and observations, constructing an
        optimized BM graph, and then calling the BM implementation of the
        requested inference method on that graph.

        Args:
            queries: queried random variables
            observations: dict mapping observed random variables to their values
            num_samples: number of samples in each chain
            num_chains: number of chains generated
            num_adaptive_samples: number of burn-in samples to discard
            inference_type: inference method
            skip_optimizations: set of optimizations to disable in this call

        Returns:
            MonteCarloSamples: the requested samples
        """
        # TODO: Add verbose level
        # TODO: Add logging
        samples = self._infer(
            queries,
            observations,
            num_samples,
            num_chains,
            num_adaptive_samples,
            inference_type,
            skip_optimizations,
        )
        return samples
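
    # Example call (a minimal sketch, continuing the hypothetical coin-flip
    # model from the class-level comment above):
    #
    #     samples = BMInference().infer(
    #         queries=[coin()],
    #         observations={flip(0): torch.tensor(1.0)},
    #         num_samples=1000,
    #         num_chains=2,
    #     )
    #     posterior = samples[coin()]  # tensor of shape [2, 1000]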

    def to_dot(
        self,
        queries: List[RVIdentifier],
        observations: Dict[RVIdentifier, torch.Tensor],
        after_transform: bool = True,
        label_edges: bool = False,
        skip_optimizations: Set[str] = default_skip_optimizations,
    ) -> str:
        """Produce a string containing a program in the GraphViz DOT language
        representing the graph deduced from the model."""
        node_types = False
        node_sizes = False
        edge_requirements = False
        bmg = self._accumulate_graph(queries, observations)._bmg
        return to_dot(
            bmg,
            node_types,
            node_sizes,
            edge_requirements,
            after_transform,
            label_edges,
            skip_optimizations,
        )

    def _to_mini(
        self,
        queries: List[RVIdentifier],
        observations: Dict[RVIdentifier, torch.Tensor],
        indent=None,
    ) -> str:
        """Internal test method for Neal's MiniBMG prototype."""
        bmg = self._accumulate_graph(queries, observations)._bmg
        return to_mini(bmg, indent)

    def to_graphviz(
        self,
        queries: List[RVIdentifier],
        observations: Dict[RVIdentifier, torch.Tensor],
        after_transform: bool = True,
        label_edges: bool = False,
        skip_optimizations: Set[str] = default_skip_optimizations,
    ) -> graphviz.Source:
        """Small wrapper to generate an actual graphviz object."""
        s = self.to_dot(
            queries, observations, after_transform, label_edges, skip_optimizations
        )
        return graphviz.Source(s)
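
    # Example (a sketch; a graphviz.Source can be rendered to a file or
    # displayed inline in a notebook):
    #
    #     graph = BMInference().to_graphviz([coin()], {flip(0): torch.tensor(1.0)})
    #     graph.render("model_graph", format="png")  # writes model_graph.png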

    def to_python(
        self,
        queries: List[RVIdentifier],
        observations: Dict[RVIdentifier, torch.Tensor],
        num_samples: int,
        num_chains: int = 4,
        num_adaptive_samples: int = 0,
        inference_type: InferenceType = InferenceType.GlobalNoUTurnSampler,
        skip_optimizations: Set[str] = default_skip_optimizations,
    ) -> str:
        """Produce a string containing a BM Python program from the graph."""
        bmg = self._accumulate_graph(queries, observations)._bmg
        self._infer_config["num_samples"] = num_samples
        self._infer_config["num_chains"] = num_chains
        self._infer_config["num_adaptive_samples"] = num_adaptive_samples
        opt_bm, _ = to_bm_python(bmg, inference_type, self._infer_config)
        return opt_bm
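
    # Example (a sketch; prints the generated BM Python program that
    # infer() would execute):
    #
    #     src = BMInference().to_python([coin()], {flip(0): torch.tensor(1.0)}, 100)
    #     print(src)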