Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented rng_fn to CAR/ICAR #7723

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 53 additions & 2 deletions pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2354,8 +2354,59 @@ def __call__(self, W, sigma, zero_sum_stdev, size=None, **kwargs):
return super().__call__(W, sigma, zero_sum_stdev, size=size, **kwargs)

@classmethod
def rng_fn(cls, rng, size, W, sigma, zero_sum_stdev):
raise NotImplementedError("Cannot sample from ICAR prior")
def rng_fn(cls, rng, W, sigma, zero_sum_stdev, size=None):
"""Sample from the ICAR distribution.

The ICAR distribution is a special case of the CAR distribution with alpha=1.
It generates spatial random effects where neighboring areas tend to have
similar values. The precision matrix is the graph Laplacian of W.

Parameters
----------
rng : numpy.random.Generator
Random number generator
W : ndarray
Symmetric adjacency matrix
sigma : float
Standard deviation parameter
zero_sum_stdev : float
Controls how strongly to enforce the zero-sum constraint
size : tuple, optional
Size of the samples to generate

Returns
-------
ndarray
Samples from the ICAR distribution
"""
W = np.asarray(W)
N = W.shape[0]

# Construct the precision matrix (graph Laplacian)
D = np.diag(W.sum(axis=1))
Q = D - W

# Add regularization for the zero eigenvalue based on zero_sum_stdev
zero_sum_precision = 1.0 / (zero_sum_stdev * N)**2
Q_reg = Q + zero_sum_precision * np.ones((N, N)) / N

# Use eigendecomposition to handle the degenerate covariance
eigvals, eigvecs = np.linalg.eigh(Q_reg)

# Construct the covariance matrix
cov = eigvecs @ np.diag(1.0 / eigvals) @ eigvecs.T

# Scale by sigma^2
cov = sigma**2 * cov

# Generate samples
mean = np.zeros(N)

# Handle different size specifications
if size is None:
return rng.multivariate_normal(mean, cov)
else:
return rng.multivariate_normal(mean, cov, size=size)


icar = ICARRV()
Expand Down
Loading