
[bug] pyro.render_model omits edges #3441

@YardenRachamim

Issue Description

Context: a linear ordinal regression model using the OrderedLogistic distribution.
Problem: pyro.render_model fails to draw edges when a child pyro.sample site
receives a tensor produced by either

  • a pyro.deterministic site, or
  • a pyro.sample site using a Delta distribution (e.g. to enforce ordering).

This makes the rendered DAG incorrect: parents are silently hidden.
Inference appears to work correctly, but I can't tell for sure.

Environment

  • OS: Ubuntu
  • Python version: 3.10
  • PyTorch version: 2.6.0+cpu
  • Pyro version: 1.9.1

Code Snippet

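import torch
import pyro
import pyro.distributions as dist
from pyro.distributions import constraints  # Pyro's constraints module provides ordered_vector
from pyro.distributions.transforms import OrderedTransform  # only needed by the commented-out cutpoints prior


def ordinal_regression_model(X, y=None):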
    D = X.size(-1)
    num_categories = 3

    cutpoints = pyro.param(
        "cutpoints",
        torch.linspace(-1., 1., num_categories - 1, dtype=X.dtype, device=X.device),
        constraint=constraints.ordered_vector
    )
#     cutpoints = pyro.sample(
#         "cutpoints",
#         dist.TransformedDistribution(                     # base → transform → ordered
#             dist.Normal(
#                 torch.zeros(num_categories - 1, device=X.device, dtype=X.dtype),
#                 torch.ones(num_categories - 1, device=X.device, dtype=X.dtype),
#             ).to_event(1),                                # joint prior over the vector
#             OrderedTransform(),                # enforces monotonicity
#         ),
#     )
    
    loc = torch.zeros(D, dtype=X.dtype, device=X.device)
    scale = torch.ones_like(loc) * 5
    beta = pyro.sample("beta", dist.Normal(loc=loc, scale=scale).to_event(1))  # D
#     eta = pyro.deterministic("eta", X.float() @ beta.float())
    eta = pyro.sample(          # now a sample site
        "eta",
        dist.Delta((X.float() @ beta.float())).to_event(1)
    )
    with pyro.plate("data", X.size(0)):
        pyro.sample(
            "obs",
            dist.OrderedLogistic(eta, cutpoints),  # OrderedLogistic does its own broadcast
            obs=y
        )

pyro.render_model(
    ordinal_regression_model,
    model_args=(batch[0], batch[1]),  # batch = (X, y) from the user's data (not shown)
    render_params=True,
    render_distributions=True,
    render_deterministic=True
)

Expected Behaviour

The graph should explicitly show the dependency chain beta → eta → obs,
including the plate context and cutpoints as a parameter. Any tensor passed as a parameter to a sample site's distribution should be treated as a parent when computing edges, regardless of whether it was computed via pyro.deterministic or via a Delta sample site.
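The expected edges can be written down as (parent, child) pairs and compared against a {child: [parents]} mapping, such as the "sample_sample" entry returned by pyro.infer.inspect.get_model_relations. A minimal sketch: EXPECTED_EDGES, edges_from_relations and missing_edges are hypothetical helpers, and the reported mapping is a made-up input illustrating the symptom, not output captured from Pyro.

```python
from typing import Dict, Iterable, List, Set, Tuple

Edge = Tuple[str, str]  # (parent, child)

# Dependency chain expected in the rendered graph: beta -> eta -> obs
EXPECTED_EDGES: Set[Edge] = {("beta", "eta"), ("eta", "obs")}


def edges_from_relations(sample_sample: Dict[str, List[str]]) -> Set[Edge]:
    """Flatten a {child: [parents]} mapping into (parent, child) pairs."""
    return {(parent, child) for child, parents in sample_sample.items() for parent in parents}


def missing_edges(
    sample_sample: Dict[str, List[str]], expected: Iterable[Edge] = EXPECTED_EDGES
) -> Set[Edge]:
    """Return the expected edges that do not appear in the mapping."""
    return set(expected) - edges_from_relations(sample_sample)


# Hypothetical mapping illustrating the symptom: obs lists no parent sites
reported = {"beta": [], "eta": ["beta"], "obs": []}
print(sorted(missing_edges(reported)))  # [('eta', 'obs')]
```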

Note: I tried the following workarounds:

  • Wrapping the tensor explicitly with pyro.deterministic.
  • Using a pyro.sample site with a Delta distribution.

Neither of them displays the missing edge.

Actual Behaviour

[Image: output of pyro.render_model]
