Commit 2e2346a

temporarily adding AutoNormalMessenger with quantiles from Pyro (pyro-ppl/pyro#2988)
1 parent 7027862 commit 2e2346a

File tree

1 file changed (+78, -1)

cell2location/models/base/_pyro_mixin.py

Lines changed: 78 additions & 1 deletion
@@ -1,23 +1,100 @@
 from datetime import date
 from functools import partial
+from typing import Callable, Tuple, Union

 import matplotlib
 import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
 import pyro
+import pyro.distributions as dist
 import torch
 from pyro import poutine
-from pyro.infer.autoguide import AutoNormal, AutoNormalMessenger, init_to_mean
+from pyro.distributions.distribution import Distribution
+from pyro.infer.autoguide import AutoNormal
+from pyro.infer.autoguide import AutoNormalMessenger as AutoNormalMessengerPyro
+from pyro.infer.autoguide import init_to_feasible, init_to_mean
+from pyro.infer.autoguide.utils import helpful_support_errors
 from scipy.sparse import issparse
 from scvi import _CONSTANTS
 from scvi.data._anndata import get_from_registry
 from scvi.dataloaders import AnnDataLoader
 from scvi.model._utils import parse_use_gpu_arg
+from torch.distributions import biject_to

 from ...distributions.AutoNormalEncoder import AutoGuideList, AutoNormalEncoder


+class AutoNormalMessenger(AutoNormalMessengerPyro):
+    """
+    :class:`AutoMessenger` with mean-field normal posterior.
+
+    Copied from Pyro with modifications adding quantile methods.
+
+    The mean-field posterior at any site is a transformed normal distribution.
+    This posterior is equivalent to :class:`~pyro.infer.autoguide.AutoNormal`
+    or :class:`~pyro.infer.autoguide.AutoDiagonalNormal`, but allows
+    customization via subclassing.
+
+    :param callable model: A Pyro model.
+    :param callable init_loc_fn: A per-site initialization function.
+        See :ref:`autoguide-initialization` section for available functions.
+    :param float init_scale: Initial scale for the standard deviation of each
+        (unconstrained transformed) latent variable.
+    :param tuple amortized_plates: A tuple of names of plates over which guide
+        parameters should be shared. This is useful for subsampling, where a
+        guide parameter can be shared across all plates.
+    """
+
+    def __init__(
+        self,
+        model: Callable,
+        *,
+        init_loc_fn: Callable = init_to_mean(fallback=init_to_feasible),
+        init_scale: float = 0.1,
+        amortized_plates: Tuple[str, ...] = (),
+    ):
+        if not isinstance(init_scale, float) or not (init_scale > 0):
+            raise ValueError("Expected init_scale > 0. but got {}".format(init_scale))
+        super().__init__(model, amortized_plates=amortized_plates)
+        self.init_loc_fn = init_loc_fn
+        self._init_scale = init_scale
+        self._computing_median = False
+        self._computing_quantiles = False
+        self._quantile_values = None
+
+    def get_posterior(self, name: str, prior: Distribution) -> Union[Distribution, torch.Tensor]:
+        if self._computing_median:
+            return self._get_posterior_median(name, prior)
+        if self._computing_quantiles:
+            return self._get_posterior_quantiles(name, prior)
+
+        with helpful_support_errors({"name": name, "fn": prior}):
+            transform = biject_to(prior.support)
+        loc, scale = self._get_params(name, prior)
+        posterior = dist.TransformedDistribution(
+            dist.Normal(loc, scale).to_event(transform.domain.event_dim),
+            transform.with_cache(),
+        )
+        return posterior
+
+    def quantiles(self, quantiles, *args, **kwargs):
+        self._computing_quantiles = True
+        self._quantile_values = quantiles
+        try:
+            return self(*args, **kwargs)
+        finally:
+            self._computing_quantiles = False
+
+    @torch.no_grad()
+    def _get_posterior_quantiles(self, name, prior):
+        transform = biject_to(prior.support)
+        loc, scale = self._get_params(name, prior)
+        site_quantiles = torch.tensor(self._quantile_values, dtype=loc.dtype, device=loc.device)
+        site_quantiles_values = dist.Normal(loc, scale).icdf(site_quantiles)
+        return transform(site_quantiles_values)
+
+
 def init_to_value(site=None, values={}):
     if site is None:
         return partial(init_to_value, values=values)
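
For context, the sketch below shows how the quantiles method added above might be used; it is not part of the commit, and the toy model, data, optimiser settings, step count and import path are illustrative assumptions. After SVI training, quantiles evaluates the inverse CDF of each site's mean-field normal posterior in unconstrained space and maps the result back to the constrained support via biject_to, returning a dict mapping site names to tensors.

# Hypothetical usage sketch, not part of the commit: the toy model, data,
# optimiser settings and step count are assumptions; the import path is
# inferred from the file shown above.
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO

from cell2location.models.base._pyro_mixin import AutoNormalMessenger


def model(data):
    # one positive-constrained latent rate per observation
    with pyro.plate("obs", data.shape[0]):
        rate = pyro.sample("rate", dist.Gamma(2.0, 0.5))
        pyro.sample("counts", dist.Poisson(rate), obs=data)


data = dist.Poisson(3.0).sample((10,))
guide = AutoNormalMessenger(model)
svi = SVI(model, guide, pyro.optim.Adam({"lr": 0.01}), Trace_ELBO())
for _ in range(200):
    svi.step(data)

# Posterior median of every latent site: the quantile is evaluated via the
# normal icdf in unconstrained space, then mapped back to the constrained
# support with biject_to(prior.support). One quantile per call broadcasts
# cleanly against each site's shape.
median = guide.quantiles([0.5], data)
print(median["rate"].shape)  # torch.Size([10])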
