BMGInference does not handle torch.stack called with other RVs #1565
Issue Description
When other RVs are concatenated together using torch.stack, BMGInference fails to trace execution because it assumes that all arguments to stack are of type Tensor. The example runs fine if stack is replaced by torch.tensor, but torch.tensor is not differentiable with respect to its arguments, which precludes methods such as VI and HMC.
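For context, here is a minimal plain-PyTorch sketch (independent of Bean Machine) of the differentiability point above: torch.stack keeps its tensor arguments in the autograd graph, whereas torch.tensor copies their values into a new leaf tensor.
import torch
a = torch.tensor(1.0, requires_grad=True)
b = torch.tensor(2.0, requires_grad=True)
# torch.stack keeps a and b in the autograd graph, so gradients flow back to them.
torch.stack([a, b]).sum().backward()
print(a.grad, b.grad)  # tensor(1.) tensor(1.)
# torch.tensor(...) would instead copy the values into a new leaf tensor,
# so the result would not be differentiable with respect to a and b.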
Steps to Reproduce
import torch
import torch.distributions as dist

import beanmachine.ppl as bm
from beanmachine.ppl.inference import BMGInference

foo = bm.random_variable(lambda: dist.Normal(torch.stack([bar(i) for i in range(2)]).sum(), 1.))
bar = bm.random_variable(lambda i: dist.Normal(0., 1.))

BMGInference().infer(
    queries=[foo()],
    observations={},
    num_samples=1,
)
raises
expected Tensor as element 0 in argument 0, but got SampleNode
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-358-e65cd9e99a94> in <module>
4 foo = bm.random_variable(lambda: dist.MultivariateNormal(torch.stack([bar(i) for i in range(2)]), torch.eye(2)))
5 bar = bm.random_variable(lambda i: dist.Normal(0., 1.))
----> 6 BMGInference().infer(
7 queries=[foo()],
8 observations={},
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/inference/bmg_inference.py in infer(self, queries, observations, num_samples, num_chains, inference_type, skip_optimizations)
262 # TODO: Add verbose level
263 # TODO: Add logging
--> 264 samples, _ = self._infer(
265 queries,
266 observations,
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/inference/bmg_inference.py in _infer(self, queries, observations, num_samples, num_chains, inference_type, produce_report, skip_optimizations)
182 self._pd = prof.ProfilerData()
183
--> 184 rt = self._accumulate_graph(queries, observations)
185 bmg = rt._bmg
186 report = pr.PerformanceReport()
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/inference/bmg_inference.py in _accumulate_graph(self, queries, observations)
71 rt = BMGRuntime()
72 rt._pd = self._pd
---> 73 bmg = rt.accumulate_graph(queries, observations)
74 # TODO: Figure out a better way to pass this flag around
75 bmg._fix_observe_true = self._fix_observe_true
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/compiler/runtime.py in accumulate_graph(self, queries, observations)
719 self._bmg.add_observation(node, val)
720 for qrv in queries:
--> 721 node = self._rv_to_node(qrv)
722 q = self._bmg.add_query(node)
723 self._rv_to_query[qrv] = q
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/compiler/runtime.py in _rv_to_node(self, rv)
583 # RVID, and if we're in the second situation, we will not.
584
--> 585 value = self._context.call(rewritten_function, rv.arguments)
586 if isinstance(value, RVIdentifier):
587 # We have a rewritten function with a decorator already applied.
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/compiler/execution_context.py in call(self, func, args, kwargs)
92 self._stack.push(FunctionCall(func, args, kwargs))
93 try:
---> 94 return func(*args, **kwargs)
95 finally:
96 self._stack.pop()
<BMGJIT> in a1()
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/compiler/runtime.py in handle_function(self, function, arguments, kwargs)
510 function, arguments, kwargs
511 ):
--> 512 result = self._special_function_caller.do_special_call_maybe_stochastic(
513 function, arguments, kwargs
514 )
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/compiler/special_function_caller.py in do_special_call_maybe_stochastic(self, func, args, kwargs)
629 new_args = (_get_ordinary_value(arg) for arg in args)
630 new_kwargs = {key: _get_ordinary_value(arg) for key, arg in kwargs.items()}
--> 631 return func(*new_args, **new_kwargs)
632
633 if _is_in_place_operator(func):
TypeError: expected Tensor as element 0 in argument 0, but got SampleNode
Expected Behavior
Successful execution, with results identical to the same model with torch.stack replaced by torch.tensor (s/stack/tensor), i.e.
import torch
import torch.distributions as dist

import beanmachine.ppl as bm
from beanmachine.ppl.inference import BMGInference

foo = bm.random_variable(lambda: dist.Normal(torch.tensor([bar(i) for i in range(2)]).sum(), 1.))
bar = bm.random_variable(lambda i: dist.Normal(0., 1.))

BMGInference().infer(
    queries=[foo()],
    observations={},
    num_samples=1,
)
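A possible workaround (hypothetical, not verified against the BMG compiler) is to avoid torch.stack entirely and sum the samples with Python's built-in sum, which only requires tensor addition:
import torch.distributions as dist

import beanmachine.ppl as bm
from beanmachine.ppl.inference import BMGInference

bar = bm.random_variable(lambda i: dist.Normal(0., 1.))

# Hypothetical rewrite: sum the samples directly instead of stacking them.
# Whether BMGInference accepts this form has not been verified here.
foo = bm.random_variable(lambda: dist.Normal(sum(bar(i) for i in range(2)), 1.))

BMGInference().infer(
    queries=[foo()],
    observations={},
    num_samples=1,
)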