Issue Description
Context: a linear ordinal regression model using the `OrderedLogistic` distribution.
Problem: `pyro.render_model` fails to draw edges when a child `pyro.sample` site receives a tensor produced by either
- a `pyro.deterministic` site, or
- a `pyro.sample` site wrapped in a `Delta` distribution (e.g. to enforce ordering).
This makes the rendered DAG incorrect: parents are silently hidden.
I think inference works fine, but I can't tell for sure.
Environment
- OS: Ubuntu
- python version: 3.10.
- PyTorch version: 2.6.0+cpu
- Pyro version: 1.9.1
Code Snippet
```python
import torch
import pyro
import pyro.distributions as dist
from pyro.distributions import constraints
from pyro.distributions.transforms import OrderedTransform  # used by the commented-out prior


def ordinal_regression_model(X, y=None):
    D = X.size(-1)
    num_categories = 3
    cutpoints = pyro.param(
        "cutpoints",
        torch.linspace(-1., 1., num_categories - 1, dtype=X.dtype, device=X.device),
        constraint=constraints.ordered_vector,
    )
    # cutpoints = pyro.sample(
    #     "cutpoints",
    #     dist.TransformedDistribution(  # base → transform → ordered
    #         dist.Normal(
    #             torch.zeros(num_categories - 1, device=X.device, dtype=X.dtype),
    #             torch.ones(num_categories - 1, device=X.device, dtype=X.dtype),
    #         ).to_event(1),  # joint prior over the vector
    #         OrderedTransform(),  # enforces monotonicity
    #     ),
    # )
    loc = torch.zeros(D, dtype=X.dtype, device=X.device)
    scale = torch.ones_like(loc) * 5
    beta = pyro.sample("beta", dist.Normal(loc=loc, scale=scale).to_event(1))  # D
    # eta = pyro.deterministic("eta", X.float() @ beta.float())
    eta = pyro.sample(  # now a sample site
        "eta",
        dist.Delta(X.float() @ beta.float()).to_event(1),
    )
    with pyro.plate("data", X.size(0)):
        pyro.sample(
            "obs",
            dist.OrderedLogistic(eta, cutpoints),  # OrderedLogistic does its own broadcast
            obs=y,
        )


# `batch` comes from my data loader as (features, integer labels).
# Hypothetical stand-in so the snippet runs on its own:
batch = (torch.randn(100, 4), torch.randint(0, 3, (100,)))

pyro.render_model(
    ordinal_regression_model,
    model_args=(batch[0], batch[1]),
    render_params=True,
    render_distributions=True,
    render_deterministic=True,
)
```
Expected Behaviour
The graph should explicitly show the dependency chain `beta → eta → obs`, including the plate context, with `cutpoints` shown as a parameter. Any tensor passed as a parameter to a sample site's distribution should count as a parent when computing edges, whether it was computed via `pyro.deterministic` or via a `Delta`-distributed sample site.
Note: I attempted the following workarounds:
- wrapping the tensor explicitly with `pyro.deterministic`;
- using a `pyro.sample` site with a `Delta` distribution.

Neither displayed the missing edge.
Actual Behaviour
