Skip to content

Commit 9c65ca1

Browse files
committed
Allow truncation of hurdle distributions
1 parent 43646d6 commit 9c65ca1

File tree

4 files changed

+247
-13
lines changed

4 files changed

+247
-13
lines changed

notebooks/xmodel.ipynb

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"id": "8b46af09bc772f64",
7+
"metadata": {},
8+
"outputs": [],
9+
"source": [
10+
"import numpy as np\n",
11+
"import pytensor.tensor as pt\n",
12+
"import pytensor.xtensor as px\n",
13+
"\n",
14+
"import pymc as pm"
15+
]
16+
},
17+
{
18+
"cell_type": "code",
19+
"execution_count": null,
20+
"id": "eaca7be1e40a81c6",
21+
"metadata": {},
22+
"outputs": [],
23+
"source": [
24+
"class XModel(pm.Model):\n",
25+
" def register_rv(self, rv, *args, dims=None, **kwargs):\n",
26+
" rv = super().register_rv(rv, *args, dims=dims, **kwargs)\n",
27+
" if dims is not None:\n",
28+
" rv = px.as_xtensor(rv, dims=dims)\n",
29+
" return rv\n",
30+
"\n",
31+
" def add_named_variable(self, var, dims=None):\n",
32+
" if isinstance(var.type, px.type.XTensorType):\n",
33+
" if dims is None:\n",
34+
" dims = var.dims\n",
35+
" else:\n",
36+
" if dims != var.dims:\n",
37+
" raise ValueError(\n",
38+
" f\"Provided dims {dims} do not match variable pre-existing {var.dims}. \"\n",
39+
" \"Use rename and/or transpose to match new dims\"\n",
40+
" )\n",
41+
" super().add_named_variable(var, dims)\n",
42+
"\n",
43+
"\n",
44+
"def XData(name, x, *args, **kwargs):\n",
45+
" x = pm.Data(name, x, *args, **kwargs)\n",
46+
" model = pm.modelcontext(None)\n",
47+
" if (dims := model.named_vars_to_dims.get(x.name, None)) is not None:\n",
48+
" x = px.as_xtensor(x, dims=dims)\n",
49+
" return x"
50+
]
51+
},
52+
{
53+
"cell_type": "code",
54+
"execution_count": null,
55+
"id": "efeb5d5820e2efe7",
56+
"metadata": {},
57+
"outputs": [],
58+
"source": [
59+
"N = 100\n",
60+
"seed = sum(map(ord, \"xarray>=numpy?\"))\n",
61+
"rng = np.random.default_rng(seed)\n",
62+
"\n",
63+
"x_np = np.linspace(0, 10, N)\n",
64+
"y_np = np.piecewise(\n",
65+
" x_np,\n",
66+
" [x_np <= 3, (x_np > 3) & (x_np <= 7), x_np > 7],\n",
67+
" [lambda x: 0.5 * x, lambda x: 1.5 + 0.2 * (x - 3), lambda x: 2.3 - 0.1 * (x - 7)],\n",
68+
")\n",
69+
"y_np += rng.normal(0, 0.2, size=N)\n",
70+
"group_idx = rng.choice(3, size=N)\n",
71+
"\n",
72+
"N_knots = 13\n",
73+
"knots_np = np.linspace(0, 10, num=N_knots)"
74+
]
75+
},
76+
{
77+
"cell_type": "code",
78+
"execution_count": null,
79+
"id": "6f5476abb800b402",
80+
"metadata": {},
81+
"outputs": [],
82+
"source": [
83+
"coords = {\n",
84+
" \"group\": range(3),\n",
85+
" \"knots\": range(N_knots),\n",
86+
" \"obs\": range(N),\n",
87+
"}"
88+
]
89+
},
90+
{
91+
"cell_type": "code",
92+
"execution_count": null,
93+
"id": "ca734923d4d51c4c",
94+
"metadata": {},
95+
"outputs": [],
96+
"source": [
97+
"with pm.Model(coords=coords) as model:\n",
98+
" x = pm.Data(\"x\", x_np, dims=\"obs\")\n",
99+
" knots = pm.Data(\"knots\", knots_np, dims=\"knot\")\n",
100+
"\n",
101+
" sigma = pm.HalfCauchy(\"sigma\", beta=1)\n",
102+
" sigma_beta0 = pm.HalfNormal(\"sigma_beta0\", sigma=10)\n",
103+
" beta0 = pm.HalfNormal(\"beta_0\", sigma=sigma_beta0, dims=\"group\")\n",
104+
" z = pm.Normal(\"z\", dims=(\"group\", \"knot\"))\n",
105+
"\n",
106+
" delta_factors = pt.special.softmax(z, axis=-1) # (groups, knot)\n",
107+
" slope_factors = 1 - pt.cumsum(delta_factors[:, :-1], axis=-1) # (groups, knot-1)\n",
108+
" spline_slopes = pt.join(-1, beta0[:, None], beta0[:, None] * slope_factors) # (groups, knot-1)\n",
109+
" beta = pt.join(-1, beta0[:, None], pt.diff(spline_slopes, axis=-1)) # (groups, knot)\n",
110+
"\n",
111+
" beta = pm.Deterministic(\"beta\", beta, dims=(\"group\", \"knot\"))\n",
112+
"\n",
113+
" X = pt.maximum(0, x[:, None] - knots[None, :]) # (n, knot)\n",
114+
" mu = (X * beta[group_idx]).sum(-1) # ((n, knots) * (n, knots)).sum(-1) = (n,)\n",
115+
" y = pm.Normal(\"y\", mu=mu, sigma=sigma, observed=y_np, dims=\"obs\")"
116+
]
117+
},
118+
{
119+
"cell_type": "code",
120+
"execution_count": null,
121+
"id": "48d4d69fcc838be3",
122+
"metadata": {},
123+
"outputs": [],
124+
"source": [
125+
"with XModel(coords=coords) as xmodel:\n",
126+
" x = XData(\"x\", x_np, dims=\"obs\")\n",
127+
" knots = XData(\"knots\", knots_np, dims=\"knot\")\n",
128+
"\n",
129+
" sigma = pm.HalfCauchy(\"sigma\", beta=1)\n",
130+
" sigma_beta0 = pm.HalfNormal(\"sigma_beta0\", sigma=10)\n",
131+
" beta0 = pm.HalfNormal(\"beta_0\", sigma=sigma_beta0, dims=\"group\")\n",
132+
" z = pm.Normal(\"z\", dims=(\"group\", \"knot\"))\n",
133+
"\n",
134+
" delta_factors = px.special.softmax(z, dim=\"knot\")\n",
135+
" slope_factors = 1 - delta_factors.isel(knot=slice(None, -1)).cumsum(\"knot\")\n",
136+
" spline_slopes = px.concat([beta0, beta0 * slope_factors], dim=\"knot\")\n",
137+
" beta = px.concat([beta0, spline_slopes.diff(\"knot\")], dim=\"knot\")\n",
138+
"\n",
139+
" beta = pm.Deterministic(\"beta\", beta, dims=(\"group\", \"knot\"))\n",
140+
"\n",
141+
" X = px.math.scalar_maximum(0, x - knots)\n",
142+
" mu = (X * beta.isel(group=group_idx).rename(group=\"obs\")).sum(\"knot\")\n",
143+
" y_obs = pm.Normal(\"y_obs\", mu=mu.values, sigma=sigma, observed=y_np, dims=\"obs\")"
144+
]
145+
},
146+
{
147+
"cell_type": "code",
148+
"execution_count": null,
149+
"id": "da17a5c329187db6",
150+
"metadata": {},
151+
"outputs": [],
152+
"source": [
153+
"print(f\"{model.compile_logp()(model.initial_point()):,}\")\n",
154+
"print(f\"{xmodel.compile_logp()(xmodel.initial_point()):,}\")"
155+
]
156+
},
157+
{
158+
"cell_type": "code",
159+
"execution_count": null,
160+
"id": "85841107447a1ddd",
161+
"metadata": {},
162+
"outputs": [],
163+
"source": []
164+
}
165+
],
166+
"metadata": {},
167+
"nbformat": 4,
168+
"nbformat_minor": 5
169+
}

pymc/distributions/mixture.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
)
3737
from pymc.distributions.shape_utils import _change_dist_size, change_dist_size, rv_size_is_none
3838
from pymc.distributions.transforms import _default_transform
39-
from pymc.distributions.truncated import Truncated
4039
from pymc.logprob.abstract import _logcdf, _logcdf_helper, _logprob
4140
from pymc.logprob.basic import logp
4241
from pymc.logprob.transforms import IntervalTransform
@@ -809,13 +808,13 @@ def dist(cls, psi, mu=None, alpha=None, p=None, n=None, **kwargs):
809808
)
810809

811810

812-
class _MarginalHurdleRV(_BaseMixtureRV):
811+
class _HurdleRV(_BaseMixtureRV):
813812
pass
814813

815814

816815
class _Hurdle(_BaseMixtureDistribution):
817-
rv_type = _MarginalHurdleRV
818-
rv_op = _MarginalHurdleRV.rv_op
816+
rv_type = _HurdleRV
817+
rv_op = _HurdleRV.rv_op
819818

820819
@classmethod
821820
def _create(cls, *, name, nonzero_p, nonzero_dist, max_n_steps=10_000, **kwargs):
@@ -826,6 +825,8 @@ def _create(cls, *, name, nonzero_p, nonzero_dist, max_n_steps=10_000, **kwargs)
826825
In hurdle models, the zeros come from a completely different process than the rest of the data.
827826
In other words, the zeros are not inflated, they come from a different process.
828827
"""
828+
from pymc.distributions.truncated import Truncated
829+
829830
dtype = nonzero_dist.dtype
830831

831832
if dtype.startswith("int"):
@@ -848,12 +849,12 @@ def _create(cls, *, name, nonzero_p, nonzero_dist, max_n_steps=10_000, **kwargs)
848849
return cls.dist(weights, comp_dists, **kwargs)
849850

850851

851-
@_logprob.register(_MarginalHurdleRV)
852+
@_logprob.register(_HurdleRV)
852853
def marginal_hurdle_logprob(op, values, rng, weights, *components, **kwargs):
853854
(value,) = values
854855

855856
if len(components) != 2:
856-
raise TypeError(
857+
raise NotImplementedError(
857858
f"MarginalHurdleRV logp only supports 2 components, got {(len(components))}"
858859
)
859860

pymc/distributions/truncated.py

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,13 @@
3030
from pymc.distributions.continuous import TruncatedNormal, bounded_cont_transform
3131
from pymc.distributions.dist_math import check_parameters
3232
from pymc.distributions.distribution import (
33+
DiracDeltaRV,
3334
Distribution,
3435
SymbolicRandomVariable,
3536
_support_point,
3637
support_point,
3738
)
39+
from pymc.distributions.mixture import _HurdleRV
3840
from pymc.distributions.shape_utils import (
3941
_change_dist_size,
4042
change_dist_size,
@@ -79,7 +81,9 @@ def rv_op(cls, dist, lower, upper, max_n_steps, *, size=None):
7981

8082
# Try to use specialized Op
8183
try:
82-
return _truncated(dist.owner.op, lower, upper, size, *dist.owner.inputs)
84+
return _truncated(
85+
dist.owner.op, lower, upper, size, *dist.owner.inputs, max_n_steps=max_n_steps
86+
)
8387
except NotImplementedError:
8488
pass
8589

@@ -222,7 +226,7 @@ def update(self, node: Apply):
222226

223227

224228
@singledispatch
225-
def _truncated(op: Op, lower, upper, size, *params):
229+
def _truncated(op: Op, lower, upper, size, *params, max_n_steps: int):
226230
"""Return the truncated equivalent of another `RandomVariable`."""
227231
raise NotImplementedError(f"{op} does not have an equivalent truncated version implemented")
228232

@@ -307,13 +311,14 @@ def dist(cls, dist, lower=None, upper=None, max_n_steps: int = 10_000, **kwargs)
307311
f"Truncation dist must be a distribution created via the `.dist()` API, got {type(dist)}"
308312
)
309313

310-
if (
311-
isinstance(dist.owner.op, SymbolicRandomVariable)
312-
and "[size]" not in dist.owner.op.extended_signature
314+
if isinstance(dist.owner.op, SymbolicRandomVariable) and not (
315+
"[size]" in dist.owner.op.extended_signature
316+
# If there's a specific _truncated dispatch for this RV, that's also fine
317+
or _truncated.dispatch(type(dist.owner.op)) is not _truncated.dispatch(object)
313318
):
314319
# Truncation needs to wrap the underlying dist, but not all SymbolicRandomVariables encapsulate the whole
315320
# random graph and as such we don't know where the actual inputs begin. This happens mostly for
316-
# distribution factories like `Censored` and `Mixture` which would have a very complex signature if they
321+
# distribution factories like `Censored` which would have a very complex signature if they
317322
# encapsulated the random components instead of taking them as inputs like they do now.
318323
# SymbolicRandomVariables that encapsulate the whole random graph can be identified for having a size parameter.
319324
raise NotImplementedError(f"Truncation not implemented for {dist.owner.op}")
@@ -462,7 +467,7 @@ def truncated_logcdf(op: TruncatedRV, value, *inputs, **kwargs):
462467

463468

464469
@_truncated.register(NormalRV)
465-
def _truncated_normal(op, lower, upper, size, rng, old_size, mu, sigma):
470+
def _truncated_normal(op, lower, upper, size, rng, old_size, mu, sigma, *, max_n_steps):
466471
return TruncatedNormal.dist(
467472
mu=mu,
468473
sigma=sigma,
@@ -472,3 +477,34 @@ def _truncated_normal(op, lower, upper, size, rng, old_size, mu, sigma):
472477
size=size,
473478
dtype=op.dtype,
474479
)
480+
481+
482+
@_truncated.register(_HurdleRV)
483+
def _truncated_hurdle(op: _HurdleRV, lower, upper, size, rng, weights, *components, max_n_steps):
484+
if len(components) != 2:
485+
raise TypeError("Truncated HurdleRV only supports two components")
486+
487+
dirac_delta_dist, other_dist = components
488+
489+
if not isinstance(dirac_delta_dist.owner.op, DiracDeltaRV):
490+
raise TypeError("First component of HurdleRV must be a DiracDeltaRV")
491+
492+
# If the DiracDelta value is outside the truncation bounds, this is effectively a non-hurdle distribution
493+
# We achieve this by adjusting the weights of the DiracDelta component, so it's never selected in that case
494+
[dirac_delta_value] = dirac_delta_dist.owner.op.dist_params(dirac_delta_dist.owner)
495+
nonzero_p = weights[..., 1]
496+
lower_check = np.array(True) if lower is None else lower <= dirac_delta_value
497+
upper_check = np.array(True) if upper is None else dirac_delta_value <= upper
498+
adjusted_nonzero_p = pt.where(
499+
lower_check & upper_check,
500+
nonzero_p,
501+
1,
502+
)
503+
adjusted_weights = pt.stack([1 - adjusted_nonzero_p, adjusted_nonzero_p], axis=-1)
504+
505+
# The only remaining step is to truncate the other distribution
506+
truncated_dist = Truncated.dist(other_dist, lower=lower, upper=upper, max_n_steps=max_n_steps)
507+
508+
# Creating a hurdle with the adjusted weights and the truncated distribution
509+
# Should be equivalent to truncating the original hurdle distribution
510+
return op.rv_op(adjusted_weights, dirac_delta_dist, truncated_dist, size=size)

tests/distributions/test_mixture.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
Poisson,
5050
StickBreakingWeights,
5151
Triangular,
52+
Truncated,
5253
Uniform,
5354
ZeroInflatedBinomial,
5455
ZeroInflatedNegativeBinomial,
@@ -1710,3 +1711,30 @@ def logp_fn(value, psi, mu, sigma):
17101711
return np.log(psi) + st.lognorm.logpdf(value, sigma, 0, np.exp(mu))
17111712

17121713
check_logp(HurdleLogNormal, Rplus, {"psi": Unit, "mu": R, "sigma": Rplusbig}, logp_fn)
1714+
1715+
@pytest.mark.parametrize("lower", (-np.inf, 0, None, 1))
1716+
def test_truncated_hurdle_lognormal(self, lower):
1717+
psi = 0.7
1718+
x = HurdleLogNormal.dist(psi=psi, mu=3, sigma=1)
1719+
x_trunc = Truncated.dist(x, lower=lower, upper=30, size=(1000,))
1720+
1721+
x_trunc_draws = draw(x_trunc)
1722+
assert ((x_trunc_draws >= (lower or -np.inf)) & (x_trunc_draws <= 30)).all()
1723+
1724+
x_trunc = Truncated.dist(x, lower=lower, upper=30, size=(4,))
1725+
x_trunc_logp = logp(x_trunc, [0, 5.5, 30.0, 30.1]).eval()
1726+
effective_psi = psi if (lower or -np.inf) <= 0 else 1
1727+
np.testing.assert_allclose(
1728+
x_trunc_logp,
1729+
[
1730+
np.log(1 - effective_psi), # 0 is not in the support of the distribution
1731+
*(
1732+
np.log(effective_psi)
1733+
+ logp(
1734+
Truncated.dist(LogNormal.dist(mu=3, sigma=1), lower=lower, upper=30),
1735+
[5.5, 30.0],
1736+
)
1737+
).eval(),
1738+
-np.inf, # 30.1 is outside the upper bound
1739+
],
1740+
)

0 commit comments

Comments
 (0)