
Commit 86990bc
BMInference API to compile and run BM models from Beanstalk (#1801)
Summary: Pull Request resolved: #1801

We currently have `BMGInference`, which takes a BM Python program, produces a valid BMG graph, and runs inference on the BMG C++ backend. In the same vein, this diff implements a `BMInference` API that takes a BM Python program, optimizes it (currently using the same optimization path as BMG), and runs BM inference methods to compute the samples. This separates the kinds of optimizations each backend needs and allows end-to-end testing with an inference method.

Differential Revision: D40853055

fbshipit-source-id: 8b5ce58385c5fe17e0a688e428261f3276e9b03b
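For context, a minimal usage sketch of the new API. The model is hypothetical, and the import path for `BMInference` is an assumption (the new file's path is not shown in this view); the `infer()` signature and the `InferenceType` enum are taken from the diffs below.

# Hypothetical end-to-end use of the new BMInference API.
# The coin-flip model is illustrative; the bm_inference module path is assumed.
import beanmachine.ppl as bm
import torch
import torch.distributions as dist

from beanmachine.ppl.compiler.gen_bm_python import InferenceType
from beanmachine.ppl.inference.bm_inference import BMInference  # assumed path

@bm.random_variable
def weight():
    return dist.Beta(2.0, 2.0)

@bm.random_variable
def toss():
    return dist.Bernoulli(weight())

# Compile the model to an optimized BM program, then run NUTS on it.
samples = BMInference().infer(
    queries=[weight()],
    observations={toss(): torch.tensor(1.0)},
    num_samples=1000,
    num_chains=4,
    num_adaptive_samples=500,
    inference_type=InferenceType.GlobalNoUTurnSampler,
)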
1 parent 2880349 commit 86990bc

File tree

4 files changed: +308 -33 lines

src/beanmachine/ppl/compiler/gen_bm_python.py

+44 -9
@@ -5,13 +5,23 @@


 from collections import defaultdict
-from typing import Dict, List
+from enum import Enum
+from typing import Dict, List, Tuple

 from beanmachine.ppl.compiler import bmg_nodes as bn

 from beanmachine.ppl.compiler.bm_graph_builder import BMGraphBuilder
 from beanmachine.ppl.compiler.fix_problems import fix_problems
 from beanmachine.ppl.compiler.internal_error import InternalError
+from beanmachine.ppl.inference.nuts_inference import GlobalNoUTurnSampler
+from beanmachine.ppl.inference.single_site_nmc import SingleSiteNewtonianMonteCarlo
+from beanmachine.ppl.model.rv_identifier import RVIdentifier
+
+
+class InferenceType(Enum):
+    SingleSiteNewtonianMonteCarlo = SingleSiteNewtonianMonteCarlo
+    GlobalNoUTurnSampler = GlobalNoUTurnSampler

 _node_type_to_distribution = {
     bn.BernoulliNode: "torch.distributions.Bernoulli",

@@ -37,6 +47,8 @@ class ToBMPython:
     no_dist_samples: Dict[bn.BMGNode, int]
     queries: List[str]
     observations: List[str]
+    node_to_rv_id: Dict[str, str]
+    node_to_query_map: Dict[str, RVIdentifier]

     def __init__(self, bmg: BMGraphBuilder) -> None:
         self.code = ""

@@ -51,6 +63,8 @@ def __init__(self, bmg: BMGraphBuilder) -> None:
         self.no_dist_samples = defaultdict(lambda: 0)
         self.queries = []
         self.observations = []
+        self.node_to_rv_id = {}
+        self.node_to_query_map = {}

     def _get_node_id_mapping(self, node: bn.BMGNode) -> str:
         if node in self.node_to_var_id:

@@ -132,12 +146,17 @@ def _add_sample(self, node: bn.SampleNode) -> None:
         total_samples = self._no_dist_samples(node.operand)
         if total_samples > 1:
             param = f"{self.no_dist_samples[node.operand]}"
+            self._code.append(f"v{var_id} = rv{rv_id}({param},)")
+            self.node_to_rv_id[f"v{var_id}"] = f"rv{rv_id}({param},)"
         else:
             param = ""
-        self._code.append(f"v{var_id} = rv{rv_id}({param})")
+            self._code.append(f"v{var_id} = rv{rv_id}({param})")
+            self.node_to_rv_id[f"v{var_id}"] = f"rv{rv_id}({param})"

     def _add_query(self, node: bn.Query) -> None:
-        self.queries.append(f"{self._get_node_id_mapping(node.operator)}")
+        query_id = self._get_node_id_mapping(node.operator)
+        self.node_to_query_map[self.node_to_rv_id[query_id]] = node.rv_identifier
+        self.queries.append(f"{query_id}")

     def _add_observation(self, node: bn.Observation) -> None:
         val = node.value

@@ -163,18 +182,34 @@ def _generate_python(self, node: bn.BMGNode) -> None:
         elif isinstance(node, bn.Observation):
             self._add_observation(node)

-    def _generate_bm_python(self) -> str:
+    def _generate_bm_python(
+        self, inference_type, infer_config
+    ) -> Tuple[str, Dict[str, RVIdentifier]]:
         bmg, error_report = fix_problems(self.bmg)
         self.bmg = bmg
         error_report.raise_errors()
+        self._code.append(
+            f"from {inference_type.value.__module__} import {inference_type.value.__name__}"
+        )
         for node in self.bmg.all_ancestor_nodes():
             self._generate_python(node)
-        self._code.append(f"queries = [{(','.join(self.queries))}]")
-        self._code.append(f"observations = {{{','.join(self.observations)}}}")
+        self._code.append(f"opt_queries = [{(','.join(self.queries))}]")
+        self._code.append(f"opt_observations = {{{','.join(self.observations)}}}")
+        self._code.append(
+            f"""samples = {inference_type.value.__name__}().infer(
+    opt_queries,
+    opt_observations,
+    num_samples={infer_config['num_samples']},
+    num_chains={infer_config['num_chains']},
+    num_adaptive_samples={infer_config['num_adaptive_samples']}
+)"""
+        )
         self.code = "\n".join(self._code)
-        return self.code
+        return self.code, self.node_to_query_map


-def to_bm_python(bmg: BMGraphBuilder) -> str:
+def to_bm_python(
+    bmg: BMGraphBuilder, inference_type: InferenceType, infer_config: Dict
+) -> Tuple[str, Dict[str, RVIdentifier]]:
     bmp = ToBMPython(bmg)
-    return bmp._generate_bm_python()
+    return bmp._generate_bm_python(inference_type, infer_config)
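The updated `to_bm_python` entry point can also be driven directly. A minimal sketch follows (the coin/toss model is hypothetical; the `BMGRuntime.accumulate_graph` usage and `infer_config` keys mirror the `BMInference` code added below):

# Sketch: calling the updated to_bm_python directly. The model is
# hypothetical; accumulate_graph and the infer_config keys follow
# BMInference._infer from this commit.
import beanmachine.ppl as bm
import torch
import torch.distributions as dist

from beanmachine.ppl.compiler.gen_bm_python import InferenceType, to_bm_python
from beanmachine.ppl.compiler.runtime import BMGRuntime

@bm.random_variable
def coin():
    return dist.Beta(2.0, 2.0)

@bm.random_variable
def toss():
    return dist.Bernoulli(coin())

bmg = BMGRuntime().accumulate_graph([coin()], {toss(): torch.tensor(1.0)})
infer_config = {"num_samples": 500, "num_chains": 2, "num_adaptive_samples": 0}

# Returns the generated BM Python source plus a map from generated rv strings
# back to the original query RVIdentifiers.
code, node_to_query_map = to_bm_python(
    bmg, InferenceType.GlobalNoUTurnSampler, infer_config
)
print(code)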
New file (+209); path not shown in this view

@@ -0,0 +1,209 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""An inference engine which uses Bean Machine to make
+inferences on optimized Bean Machine models."""
+
+from typing import Dict, List, Set
+
+import graphviz
+import torch
+
+from beanmachine.ppl.compiler.fix_problems import default_skip_optimizations
+from beanmachine.ppl.compiler.gen_bm_python import InferenceType, to_bm_python
+from beanmachine.ppl.compiler.gen_bmg_graph import to_bmg_graph
+from beanmachine.ppl.compiler.gen_dot import to_dot
+from beanmachine.ppl.compiler.gen_mini import to_mini
+from beanmachine.ppl.compiler.runtime import BMGRuntime
+from beanmachine.ppl.inference.monte_carlo_samples import MonteCarloSamples
+from beanmachine.ppl.inference.utils import _verify_queries_and_observations
+from beanmachine.ppl.model.rv_identifier import RVIdentifier
+
+
+class BMInference:
+    """
+    Interface to Bean Machine Inference on optimized models.
+
+    Please note that this is a highly experimental implementation under active
+    development, and that the subset of Bean Machine model is limited. Limitations
+    include that the runtime graph should be static (meaning, it does not change
+    during inference), and that the types of primitive distributions supported
+    is currently limited.
+    """
+
+    _fix_observe_true: bool = False
+    _infer_config = {}
+
+    def __init__(self):
+        pass
+
+    def _accumulate_graph(
+        self,
+        queries: List[RVIdentifier],
+        observations: Dict[RVIdentifier, torch.Tensor],
+    ) -> BMGRuntime:
+        _verify_queries_and_observations(queries, observations, True)
+        rt = BMGRuntime()
+        bmg = rt.accumulate_graph(queries, observations)
+        # TODO: Figure out a better way to pass this flag around
+        bmg._fix_observe_true = self._fix_observe_true
+        return rt
+
+    def _build_mcsamples(
+        self,
+        queries,
+        opt_rv_to_query_map,
+        samples,
+    ) -> MonteCarloSamples:
+        assert len(samples) == len(queries)
+
+        results: Dict[RVIdentifier, torch.Tensor] = {}
+        for rv in samples.keys():
+            query = opt_rv_to_query_map[rv.__str__()]
+            results[query] = samples[rv]
+        mcsamples = MonteCarloSamples(results)
+        return mcsamples
+
+    def _infer(
+        self,
+        queries: List[RVIdentifier],
+        observations: Dict[RVIdentifier, torch.Tensor],
+        num_samples: int,
+        num_chains: int = 1,
+        num_adaptive_samples: int = 0,
+        inference_type: InferenceType = InferenceType.GlobalNoUTurnSampler,
+        skip_optimizations: Set[str] = default_skip_optimizations,
+    ) -> MonteCarloSamples:
+
+        rt = self._accumulate_graph(queries, observations)
+        bmg = rt._bmg
+
+        self._infer_config["num_samples"] = num_samples
+        self._infer_config["num_chains"] = num_chains
+        self._infer_config["num_adaptive_samples"] = num_adaptive_samples
+
+        generated_graph = to_bmg_graph(bmg, skip_optimizations)
+        optimized_python, opt_rv_to_query_map = to_bm_python(
+            generated_graph.bmg, inference_type, self._infer_config
+        )
+
+        try:
+            exec(optimized_python, globals())  # noqa
+        except RuntimeError as e:
+            raise RuntimeError("Error during BM inference\n") from e
+
+        opt_samples = self._build_mcsamples(
+            queries,
+            opt_rv_to_query_map,
+            # pyre-ignore
+            samples,  # noqa
+        )
+        return opt_samples
+
+    def infer(
+        self,
+        queries: List[RVIdentifier],
+        observations: Dict[RVIdentifier, torch.Tensor],
+        num_samples: int,
+        num_chains: int = 4,
+        num_adaptive_samples: int = 0,
+        inference_type: InferenceType = InferenceType.GlobalNoUTurnSampler,
+        skip_optimizations: Set[str] = default_skip_optimizations,
+    ) -> MonteCarloSamples:
+        """
+        Perform inference by (runtime) compilation of Python source code associated
+        with its parameters, constructing an optimized BM graph, and then calling the
+        BM implementation of a particular inference method on this graph.
+
+        Args:
+            queries: queried random variables
+            observations: observations dict
+            num_samples: number of samples in each chain
+            num_chains: number of chains generated
+            num_adaptive_samples: number of burn in samples to discard
+            inference_type: inference method
+            skip_optimizations: list of optimization to disable in this call
+
+        Returns:
+            MonteCarloSamples: The requested samples
+        """
+        # TODO: Add verbose level
+        # TODO: Add logging
+        samples = self._infer(
+            queries,
+            observations,
+            num_samples,
+            num_chains,
+            num_adaptive_samples,
+            inference_type,
+            skip_optimizations,
+        )
+        return samples
+
+    def to_dot(
+        self,
+        queries: List[RVIdentifier],
+        observations: Dict[RVIdentifier, torch.Tensor],
+        after_transform: bool = True,
+        label_edges: bool = False,
+        skip_optimizations: Set[str] = default_skip_optimizations,
+    ) -> str:
+        """Produce a string containing a program in the GraphViz DOT language
+        representing the graph deduced from the model."""
+        node_types = False
+        node_sizes = False
+        edge_requirements = False
+        bmg = self._accumulate_graph(queries, observations)._bmg
+        return to_dot(
+            bmg,
+            node_types,
+            node_sizes,
+            edge_requirements,
+            after_transform,
+            label_edges,
+            skip_optimizations,
+        )
+
+    def _to_mini(
+        self,
+        queries: List[RVIdentifier],
+        observations: Dict[RVIdentifier, torch.Tensor],
+        indent=None,
+    ) -> str:
+        """Internal test method for Neal's MiniBMG prototype."""
+        bmg = self._accumulate_graph(queries, observations)._bmg
+        return to_mini(bmg, indent)
+
+    def to_graphviz(
+        self,
+        queries: List[RVIdentifier],
+        observations: Dict[RVIdentifier, torch.Tensor],
+        after_transform: bool = True,
+        label_edges: bool = False,
+        skip_optimizations: Set[str] = default_skip_optimizations,
+    ) -> graphviz.Source:
+        """Small wrapper to generate an actual graphviz object"""
+        s = self.to_dot(
+            queries, observations, after_transform, label_edges, skip_optimizations
+        )
+        return graphviz.Source(s)
+
+    def to_python(
+        self,
+        queries: List[RVIdentifier],
+        observations: Dict[RVIdentifier, torch.Tensor],
+        num_samples: int,
+        num_chains: int = 4,
+        num_adaptive_samples: int = 0,
+        inference_type: InferenceType = InferenceType.GlobalNoUTurnSampler,
+        skip_optimizations: Set[str] = default_skip_optimizations,
+    ) -> str:
+        """Produce a string containing a BM Python program from the graph."""
+        bmg = self._accumulate_graph(queries, observations)._bmg
+        self._infer_config["num_samples"] = num_samples
+        self._infer_config["num_chains"] = num_chains
+        self._infer_config["num_adaptive_samples"] = num_adaptive_samples
+        opt_bm, _ = to_bm_python(bmg, inference_type, self._infer_config)
+        return opt_bm
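For reference, the program string that `_generate_bm_python` emits and `BMInference._infer` then `exec`s has roughly this shape. This is illustrative only: the `rvN`/`vN` naming and the `opt_queries`/`opt_observations`/`samples` lines follow the appends in `gen_bm_python.py` above, but the generated rv bodies depend on the model, and the decorator shown here is an assumption.

# Illustrative shape of the generated BM Python program; not literal
# compiler output. The random_variable bodies are assumed for the sake
# of a runnable sketch.
import beanmachine.ppl as bm
import torch

from beanmachine.ppl.inference.nuts_inference import GlobalNoUTurnSampler

@bm.random_variable
def rv0():
    return torch.distributions.Beta(2.0, 2.0)

@bm.random_variable
def rv1():
    return torch.distributions.Bernoulli(rv0())

v0 = rv0()
v1 = rv1()
opt_queries = [v0]
opt_observations = {v1: torch.tensor(1.0)}
samples = GlobalNoUTurnSampler().infer(
    opt_queries,
    opt_observations,
    num_samples=1000,
    num_chains=4,
    num_adaptive_samples=0
)

The `node_to_query_map` returned alongside this program is what lets `_build_mcsamples` translate the generated `rvN` identifiers back to the caller's original `RVIdentifier` queries.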

src/beanmachine/ppl/inference/bmg_inference.py

+1 -11
@@ -16,7 +16,6 @@

 from beanmachine.ppl.compiler.bm_graph_builder import rv_to_query
 from beanmachine.ppl.compiler.fix_problems import default_skip_optimizations
-from beanmachine.ppl.compiler.gen_bm_python import to_bm_python
 from beanmachine.ppl.compiler.gen_bmg_cpp import to_bmg_cpp
 from beanmachine.ppl.compiler.gen_bmg_graph import to_bmg_graph
 from beanmachine.ppl.compiler.gen_bmg_python import to_bmg_python

@@ -167,7 +166,7 @@ def _build_mcsamples(
         # but it requires the input to be of a different type in the
         # cases of num_chains==1 and !=1 respectively. Furthermore,
         # we had to tweak it to support the right operator for merging
-        # saumple values when num_chains!=1.
+        # sample values when num_chains!=1.
         if num_chains == 1:
             mcsamples = MonteCarloSamples(
                 results[0], num_adaptive_samples, stack_not_cat=True

@@ -365,15 +364,6 @@ def to_python(
         bmg = self._accumulate_graph(queries, observations)._bmg
         return to_bmg_python(bmg).code

-    def to_bm_python(
-        self,
-        queries: List[RVIdentifier],
-        observations: Dict[RVIdentifier, torch.Tensor],
-    ) -> str:
-        """Produce a string containing a BM Python program from the graph."""
-        bmg = self._accumulate_graph(queries, observations)._bmg
-        return to_bm_python(bmg)
-
     def to_graph(
         self,
         queries: List[RVIdentifier],