Skip to content

Commit 666ac8c

Browse files
Add DiscreteMarkovChain distribution (#100)
1 parent 0a0f544 commit 666ac8c

File tree

5 files changed

+1881
-3
lines changed

5 files changed

+1881
-3
lines changed

docs/api_reference.rst

+1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ Distributions
3030

3131
GenExtreme
3232
histogram_utils.histogram_approximation
33+
DiscreteMarkovChain
3334

3435

3536
Gaussian Processes

notebooks/discrete_markov_chain.ipynb

+1,403
Large diffs are not rendered by default.

pymc_experimental/distributions/__init__.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
"""
1919

2020
from pymc_experimental.distributions.continuous import GenExtreme
21+
from pymc_experimental.distributions.timeseries import DiscreteMarkovChain
2122

22-
__all__ = [
23-
"GenExtreme",
24-
]
23+
__all__ = ["GenExtreme", "DiscreteMarkovChain"]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,267 @@
1+
import warnings
2+
from typing import List, Union
3+
4+
import numpy as np
5+
import pymc as pm
6+
import pytensor
7+
import pytensor.tensor as pt
8+
from pymc.distributions.dist_math import check_parameters
9+
from pymc.distributions.distribution import (
10+
Distribution,
11+
SymbolicRandomVariable,
12+
_moment,
13+
moment,
14+
)
15+
from pymc.distributions.shape_utils import (
16+
_change_dist_size,
17+
change_dist_size,
18+
get_support_shape_1d,
19+
)
20+
from pymc.logprob.abstract import _logprob
21+
from pymc.logprob.basic import logp
22+
from pymc.logprob.utils import ignore_logprob
23+
from pymc.pytensorf import intX
24+
from pymc.util import check_dist_not_registered
25+
from pytensor.graph.basic import Node
26+
from pytensor.tensor import TensorVariable
27+
from pytensor.tensor.random.op import RandomVariable
28+
29+
30+
def _make_outputs_info(n_lags: int, init_dist: Distribution) -> List[Union[Distribution, dict]]:
31+
"""
32+
Two cases are needed for outputs_info in the scans used by DiscreteMarkovRv. If n_lags = 1, we need to throw away
33+
the first dimension of init_dist_ or else markov_chain will have shape (steps, 1, *batch_size) instead of
34+
desired (steps, *batch_size)
35+
36+
Parameters
37+
----------
38+
n_lags: int
39+
Number of lags the Markov Chain considers when transitioning to the next state
40+
init_dist: RandomVariable
41+
Distribution over initial states
42+
43+
Returns
44+
-------
45+
taps: list
46+
Lags to be fed into pytensor.scan when drawing a markov chain
47+
"""
48+
49+
if n_lags > 1:
50+
return [{"initial": init_dist, "taps": list(range(-n_lags, 0))}]
51+
else:
52+
return [init_dist[0]]
53+
54+
55+
class DiscreteMarkovChainRV(SymbolicRandomVariable):
    """Symbolic random variable op used as ``rv_type`` by ``DiscreteMarkovChain``."""

    # Number of lagged states; set per instance in __init__.
    n_lags: int
    # DiscreteMarkovChain.rv_op builds the op with outputs [next_rng, chain];
    # index 1 (the chain) is the output exposed by default.
    default_output = 1
    _print_name = ("DiscreteMC", "\\operatorname{DiscreteMC}")

    def __init__(self, *args, n_lags, **kwargs):
        """Store ``n_lags`` on the op and forward all other arguments to the parent class."""
        self.n_lags = n_lags
        super().__init__(*args, **kwargs)

    def update(self, node: Node):
        """Return a one-entry dict mapping the node's last input to its first output."""
        rng_in = node.inputs[-1]
        rng_out = node.outputs[0]
        return {rng_in: rng_out}
66+
67+
68+
class DiscreteMarkovChain(Distribution):
    r"""
    A Discrete Markov Chain is a sequence of random variables

    .. math::

        \{x_t\}_{t=0}^T

    Where transition probability :math:`P(x_t | x_{t-1})` depends only on the state of the system at :math:`x_{t-1}`.

    Parameters
    ----------
    P: tensor
        Matrix of transition probabilities between states. Rows must sum to 1.
        Exactly one of P or logit_P must be provided.
    logit_P: tensor, optional
        Matrix of transition logits. Converted to probabilities via Softmax activation
        along the last axis. Exactly one of P or logit_P must be provided.
    steps: tensor, optional
        Number of transitions drawn after the ``n_lags`` initial states, so that the chain has
        length ``steps + n_lags``. Only needed if shape is not provided.
    init_dist : unnamed distribution, optional
        Distribution for the initial states. Unnamed refers to distributions
        created with the ``.dist()`` API. Its support must be scalar or vector. It will be
        automatically resized to ``(n_lags, *batch_size)``. If omitted, a uniform
        ``Categorical`` over the states is used and a ``UserWarning`` is emitted.

        .. warning:: init_dist will be cloned, rendering it independent of the one passed as input.
    n_lags: int, default 1
        Number of previous states the next state is conditioned on. The last ``n_lags + 1``
        dimensions of P must all have the same length.

    Notes
    -----
    The initial distribution will be cloned, rendering it distinct from the one passed as
    input.

    Examples
    --------
    Create a Markov Chain of length 100 with 3 states. The number of states is given by the shape of P,
    3 in this case.

    >>> import numpy as np
    >>> import pymc as pm
    >>> from pymc_experimental.distributions import DiscreteMarkovChain
    >>> with pm.Model() as markov_chain:
    ...     P = pm.Dirichlet("P", a=[1, 1, 1], size=(3,))
    ...     init_dist = pm.Categorical.dist(p=np.full(3, 1 / 3))
    ...     markov_chain = DiscreteMarkovChain("markov_chain", P=P, init_dist=init_dist, shape=(100,))

    """

    rv_type = DiscreteMarkovChainRV

    def __new__(cls, *args, steps=None, n_lags=1, **kwargs):
        """Resolve ``steps`` from ``dims``/``observed`` (offset by ``n_lags``) before registering."""
        steps = get_support_shape_1d(
            support_shape=steps,
            # shape is not consulted here; ``dist`` reads it from its own kwargs.
            shape=None,
            dims=kwargs.get("dims", None),
            observed=kwargs.get("observed", None),
            support_shape_offset=n_lags,
        )

        return super().__new__(cls, *args, steps=steps, n_lags=n_lags, **kwargs)

    @classmethod
    def dist(cls, P=None, logit_P=None, steps=None, init_dist=None, n_lags=1, **kwargs):
        """
        Validate the parameters and create the unnamed distribution.

        Raises
        ------
        ValueError
            If neither ``steps`` nor ``shape`` is given, if not exactly one of ``P`` and
            ``logit_P`` is given, if ``init_dist`` is not an unnamed distribution, or if
            ``init_dist`` has more than one support dimension.
        """
        steps = get_support_shape_1d(
            support_shape=steps, shape=kwargs.get("shape", None), support_shape_offset=n_lags
        )

        if steps is None:
            raise ValueError("Must specify steps or shape parameter")
        if P is None and logit_P is None:
            raise ValueError("Must specify P or logit_P parameter")
        if P is not None and logit_P is not None:
            raise ValueError("Must specify only one of either P or logit_P parameter")

        if logit_P is not None:
            P = pm.math.softmax(logit_P, axis=-1)

        P = pt.as_tensor_variable(P)
        steps = pt.as_tensor_variable(intX(steps))

        if init_dist is not None:
            # ``owner`` is checked explicitly so that a tensor that is not the output of any op
            # gets the ValueError below instead of an AttributeError on ``None.op``.
            if (
                not isinstance(init_dist, TensorVariable)
                or init_dist.owner is None
                or not isinstance(init_dist.owner.op, (RandomVariable, SymbolicRandomVariable))
            ):
                raise ValueError(
                    f"Init dist must be a distribution created via the `.dist()` API, "
                    f"got {type(init_dist)}"
                )

            check_dist_not_registered(init_dist)
            if init_dist.owner.op.ndim_supp > 1:
                # The two literals are concatenated into a single message (no comma between them).
                raise ValueError(
                    "Init distribution must have a scalar or vector support dimension, "
                    f"got ndim_supp={init_dist.owner.op.ndim_supp}."
                )
        else:
            warnings.warn(
                "Initial distribution not specified, defaulting to "
                "`Categorical.dist(p=pt.full((k_states, ), 1/k_states), shape=...)`. You can specify an init_dist "
                "manually to suppress this warning.",
                UserWarning,
            )
            # Uniform distribution over the k states given by the last axis of P.
            k = P.shape[-1]
            init_dist = pm.Categorical.dist(p=pt.full((k,), 1 / k))

        # We can ignore init_dist, as it will be accounted for in the logp term
        init_dist = ignore_logprob(init_dist)

        return super().dist([P, steps, init_dist], n_lags=n_lags, **kwargs)

    @classmethod
    def rv_op(cls, P, steps, init_dist, n_lags, size=None):
        """Build the ``DiscreteMarkovChainRV`` op and apply it to ``P``, ``steps`` and ``init_dist``."""
        if size is not None:
            batch_size = size
        else:
            # Index away the last (n_lags + 1) axes of P and the last axis of init_dist;
            # whatever is left over is broadcast to give the batch shape.
            batch_size = pt.broadcast_shape(
                P[tuple([...] + [0] * (n_lags + 1))], pt.atleast_1d(init_dist)[..., 0]
            )

        init_dist = change_dist_size(init_dist, (n_lags, *batch_size))

        # Fresh variables of the same types as the real inputs; the op's inner graph is
        # written in terms of these.
        init_dist_ = init_dist.type()
        P_ = P.type()
        steps_ = steps.type()

        state_rng = pytensor.shared(np.random.default_rng())

        def transition(*args):
            # scan passes the lagged states first, then the non_sequences (P_, state_rng).
            *states, transition_probs, old_rng = args
            p = transition_probs[tuple(states)]
            next_rng, next_state = pm.Categorical.dist(p=p, rng=old_rng).owner.outputs
            return next_state, {old_rng: next_rng}

        markov_chain, state_updates = pytensor.scan(
            transition,
            non_sequences=[P_, state_rng],
            outputs_info=_make_outputs_info(n_lags, init_dist_),
            n_steps=steps_,
            strict=True,
        )

        (state_next_rng,) = tuple(state_updates.values())

        # Prepend the initial states on the leading axis, then move that axis to the end.
        discrete_mc_ = pt.moveaxis(pt.concatenate([init_dist_, markov_chain], axis=0), 0, -1)

        discrete_mc_op = DiscreteMarkovChainRV(
            inputs=[P_, steps_, init_dist_],
            outputs=[state_next_rng, discrete_mc_],
            ndim_supp=1,
            n_lags=n_lags,
        )

        discrete_mc = discrete_mc_op(P, steps, init_dist)
        return discrete_mc
217+
218+
219+
@_change_dist_size.register(DiscreteMarkovChainRV)
def change_mc_size(op, dist, new_size, expand=False):
    """
    Rebuild a DiscreteMarkovChain variable with a new size.

    When ``expand`` is true, ``new_size`` is prepended to the existing shape of ``dist`` with its
    last axis dropped; otherwise ``new_size`` is used as given. The variable is then recreated
    through ``DiscreteMarkovChain.rv_op`` from all inputs of ``dist`` except the last one.
    """
    if expand:
        current_size = dist.shape[:-1]
        new_size = (*new_size, *current_size)

    chain_params = dist.owner.inputs[:-1]
    return DiscreteMarkovChain.rv_op(*chain_params, size=new_size, n_lags=op.n_lags)
226+
227+
228+
@_moment.register(DiscreteMarkovChainRV)
def discrete_mc_moment(op, rv, P, steps, init_dist, state_rng):
    """
    Moment of a DiscreteMarkovChain: a chain that always takes the most probable transition.

    The first ``n_lags`` entries are the moment of ``init_dist``; every following entry is the
    argmax of the transition probabilities selected by the preceding states. The time axis is
    placed last.
    """
    init_dist_moment = moment(init_dist)
    n_lags = op.n_lags

    def greedy_transition(*args):
        # Same argument layout as the sampling scan: lagged states, then P, then the rng
        # (unused here because the transition is deterministic).
        *states, transition_probs, _unused_rng = args
        p = transition_probs[tuple(states)]
        # Reduce over the state axis only. Without an explicit axis argmax flattens p, which
        # collapses any batch dimensions; for a 1-d p the result is identical.
        return pt.argmax(p, axis=-1)

    # NOTE(review): the scan is seeded with ``init_dist`` itself rather than with
    # ``init_dist_moment``, so the greedy continuation starts from a random draw while the
    # prefix below uses the moment. This looks unintended -- confirm before changing.
    chain_moment, _ = pytensor.scan(
        greedy_transition,
        non_sequences=[P, state_rng],
        outputs_info=_make_outputs_info(n_lags, init_dist),
        n_steps=steps,
        strict=True,
    )

    # Stack along the leading (time) axis, then move time to the end. This is a no-op for an
    # unbatched chain.
    chain_moment = pt.concatenate([init_dist_moment, chain_moment], axis=0)
    return pt.moveaxis(chain_moment, 0, -1)
247+
248+
249+
@_logprob.register(DiscreteMarkovChainRV)
def discrete_mc_logp(op, values, P, steps, init_dist, state_rng, **kwargs):
    """
    Log-probability of a DiscreteMarkovChain value.

    The result is the log-probability of the first ``n_lags`` entries under ``init_dist`` plus
    the log of the entries of ``P`` selected by every window of ``n_lags + 1`` consecutive
    states, each summed over the last axis. The parameter checks listed in ``msg`` are attached
    to the returned expression.
    """
    n_lags = op.n_lags
    value = values[0]

    # shifted[i] is ``value`` offset by i along the last axis and trimmed so that all
    # n_lags + 1 slices have equal length; together they index P once per transition.
    shifted = []
    for offset in range(n_lags + 1):
        stop = None if offset == n_lags else -(n_lags - offset)
        shifted.append(value[..., offset:stop])

    init_logp = logp(init_dist, value[..., :n_lags]).sum(axis=-1)
    transition_logp = pt.log(P[tuple(shifted)]).sum(axis=-1)
    mc_logprob = init_logp + transition_logp

    p_is_square = pt.all(pt.eq(P.shape[-(n_lags + 1) :], P.shape[-1]))
    p_rows_sum_to_one = pt.all(pt.allclose(P.sum(axis=-1), 1.0))
    init_has_n_lags = pt.eq(pt.atleast_1d(init_dist).shape[0], n_lags)

    return check_parameters(
        mc_logprob,
        p_is_square,
        p_rows_sum_to_one,
        init_has_n_lags,
        msg="Last (n_lags + 1) dimensions of P must be square, "
        "P must sum to 1 along the last axis, "
        "First dimension of init_dist must be n_lags",
    )

0 commit comments

Comments
 (0)