
[Bug] JIT-related gradient error #3444

@nomadbl

Description


Hello everyone. As this is yet another JIT-related bug, I'm posting it mainly for other people who run into it, and hopefully to help make the JIT viable in the future.
JitTrace_ELBO produces gradient-related errors when using a MultivariateNormal distribution parametrized by scale_tril. I haven't tested other combinations of inputs, but I provide an MWE below that reproduces the error on my system and on Google Colab (OS and package versions quoted below). The same MWE does not produce the error if Trace_ELBO is used.

Environment

  • OS: Ubuntu 22.04
  • Python version: 3.11.13 (main, Jun 4 2025, 08:57:29) [GCC 11.4.0]
  • PyTorch version: 2.6.0+cu124
  • Pyro version: 1.9.1

MWE

import torch
import pyro
import pyro.distributions
import pyro.infer
import pyro.optim

pyro.clear_param_store()

def model():
    # Correlation Cholesky factor and per-dimension variances
    s = pyro.sample("s", pyro.distributions.LKJCholesky(2, 1))
    d = pyro.sample("d", pyro.distributions.Gamma(torch.ones(2), torch.ones(2)).to_event(1))
    # MVN parametrized by scale_tril = diag(sqrt(d)) @ s
    pyro.sample("x", pyro.distributions.MultivariateNormal(torch.zeros(2), scale_tril=s * torch.sqrt(d[:, None])))

def guide():
    eta = pyro.param("eta", lambda: torch.tensor(1, dtype=torch.float32))
    alpha = pyro.param("alpha", lambda: 11 * torch.ones(2, dtype=torch.float32))
    sigma = pyro.param("sigma", lambda: 10 * torch.ones(2, dtype=torch.float32))
    sx = pyro.param("sx", lambda: pyro.distributions.LKJCholesky(2, 1).sample())
    mu_x = pyro.param("mu_x", torch.zeros(2, dtype=torch.float32))

    s = pyro.sample("s", pyro.distributions.LKJCholesky(2, eta))
    d = pyro.sample("d", pyro.distributions.Gamma(alpha, sigma).to_event(1))
    pyro.sample("x", pyro.distributions.MultivariateNormal(mu_x, scale_tril=sx))

elbo = pyro.infer.JitTrace_ELBO()
svi = pyro.infer.SVI(model, guide, pyro.optim.Adam({"lr": 0.005, "betas": (0.95, 0.999)}), loss=elbo)
svi.step()

Error output

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [2, 2]], which is output 0 of AddBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
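
Workaround

For anyone else hitting this, a minimal sketch of the workaround is below, assuming the model and guide defined in the MWE above: switch the loss to the non-JIT Trace_ELBO, which (as noted above) does not trigger the error. The sketch also enables torch.autograd.set_detect_anomaly(True), as the error hint suggests, to get a traceback pointing at the offending in-place operation when debugging the JIT version.

# Sketch of a workaround; reuses model/guide from the MWE above.
import torch
import pyro
import pyro.infer
import pyro.optim

# Optional: follow the hint in the error message to locate the in-place op.
torch.autograd.set_detect_anomaly(True)

pyro.clear_param_store()
svi = pyro.infer.SVI(
    model,
    guide,
    pyro.optim.Adam({"lr": 0.005, "betas": (0.95, 0.999)}),
    loss=pyro.infer.Trace_ELBO(),  # non-JIT ELBO does not raise the gradient error
)
svi.step()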
