Skip to content
This repository was archived by the owner on Dec 18, 2023. It is now read-only.

BMGInference does not handle torch.stack called with other RVs #1565

@feynmanliang

Description

@feynmanliang

Issue Description

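When random variables (RVs) are combined into a single tensor with torch.stack, BMGInference fails to
trace the model. It assumes that every element passed to torch.stack is a Tensor, but during graph accumulation the RV values are SampleNode objects, so torch.stack raises a TypeError.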

The example runs fine if torch.stack is replaced by torch.tensor. However, torch.tensor is not differentiable with respect to its arguments, which precludes gradient-based methods such as VI and HMC.

Steps to Reproduce

import beanmachine.ppl as bm
from beanmachine.ppl.inference import BMGInference

foo = bm.random_variable(lambda: dist.Normal(torch.stack([bar(i) for i in range(2)]).sum(), 1.))
bar = bm.random_variable(lambda i: dist.Normal(0., 1.))
BMGInference().infer(
    queries=[foo()],
    observations={},
    num_samples=1,
)

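raises the error below. (Note: this traceback was captured from a variant of the example in which foo is defined as dist.MultivariateNormal(torch.stack([bar(i) for i in range(2)]), torch.eye(2)). The failure occurs in the same torch.stack call on SampleNode arguments.)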

expected Tensor as element 0 in argument 0, but got SampleNode
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-358-e65cd9e99a94> in <module>
      4 foo = bm.random_variable(lambda: dist.MultivariateNormal(torch.stack([bar(i) for i in range(2)]), torch.eye(2)))
      5 bar = bm.random_variable(lambda i: dist.Normal(0., 1.))
----> 6 BMGInference().infer(
      7     queries=[foo()],
      8     observations={},
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/inference/bmg_inference.py in infer(self, queries, observations, num_samples, num_chains, inference_type, skip_optimizations)
    262         # TODO: Add verbose level
    263         # TODO: Add logging
--> 264         samples, _ = self._infer(
    265             queries,
    266             observations,
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/inference/bmg_inference.py in _infer(self, queries, observations, num_samples, num_chains, inference_type, produce_report, skip_optimizations)
    182             self._pd = prof.ProfilerData()
    183 
--> 184         rt = self._accumulate_graph(queries, observations)
    185         bmg = rt._bmg
    186         report = pr.PerformanceReport()
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/inference/bmg_inference.py in _accumulate_graph(self, queries, observations)
     71         rt = BMGRuntime()
     72         rt._pd = self._pd
---> 73         bmg = rt.accumulate_graph(queries, observations)
     74         # TODO: Figure out a better way to pass this flag around
     75         bmg._fix_observe_true = self._fix_observe_true
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/compiler/runtime.py in accumulate_graph(self, queries, observations)
    719             self._bmg.add_observation(node, val)
    720         for qrv in queries:
--> 721             node = self._rv_to_node(qrv)
    722             q = self._bmg.add_query(node)
    723             self._rv_to_query[qrv] = q
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/compiler/runtime.py in _rv_to_node(self, rv)
    583                 # RVID, and if we're in the second situation, we will not.
    584 
--> 585                 value = self._context.call(rewritten_function, rv.arguments)
    586                 if isinstance(value, RVIdentifier):
    587                     # We have a rewritten function with a decorator already applied.
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/compiler/execution_context.py in call(self, func, args, kwargs)
     92         self._stack.push(FunctionCall(func, args, kwargs))
     93         try:
---> 94             return func(*args, **kwargs)
     95         finally:
     96             self._stack.pop()
<BMGJIT> in a1()
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/compiler/runtime.py in handle_function(self, function, arguments, kwargs)
    510             function, arguments, kwargs
    511         ):
--> 512             result = self._special_function_caller.do_special_call_maybe_stochastic(
    513                 function, arguments, kwargs
    514             )
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/compiler/special_function_caller.py in do_special_call_maybe_stochastic(self, func, args, kwargs)
    629             new_args = (_get_ordinary_value(arg) for arg in args)
    630             new_kwargs = {key: _get_ordinary_value(arg) for key, arg in kwargs.items()}
--> 631             return func(*new_args, **new_kwargs)
    632 
    633         if _is_in_place_operator(func):
TypeError: expected Tensor as element 0 in argument 0, but got SampleNode

Expected Behavior

Inference should succeed and produce the same results as the following version, in which torch.stack is replaced by torch.tensor:

import beanmachine.ppl as bm
from beanmachine.ppl.inference import BMGInference

foo = bm.random_variable(lambda: dist.Normal(torch.tensor([bar(i) for i in range(2)]).sum(), 1.))
bar = bm.random_variable(lambda i: dist.Normal(0., 1.))
BMGInference().infer(
    queries=[foo()],
    observations={},
    num_samples=1,
)

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions