Hello everyone. This is yet another JIT-related bug; I'm posting it mainly for other people who encounter it, and to hopefully help make the JIT viable in the future.
JitTrace_ELBO produces gradient-related errors when using a MultivariateNormal distribution parametrized by scale_tril. I haven't tested other input combinations, but the MWE below reproduces the error on my system and on Google Colab (OS and package versions quoted below). The same MWE does not produce the error when Trace_ELBO is used.
Environment
- OS: Ubuntu 22.04, Python version: 3.11.13 (main, Jun 4 2025, 08:57:29) [GCC 11.4.0]
- PyTorch version: 2.6.0+cu124
- Pyro version: 1.9.1
MWE
```python
import torch
import pyro

pyro.clear_param_store()

def model():
    # Correlation Cholesky factor and per-dimension variances
    s = pyro.sample("s", pyro.distributions.LKJCholesky(2, 1))
    d = pyro.sample("d", pyro.distributions.Gamma(torch.ones(2), torch.ones(2)).to_event(1))
    # scale_tril: rows of the LKJ factor scaled by the standard deviations
    pyro.sample("x", pyro.distributions.MultivariateNormal(torch.zeros(2), scale_tril=s * torch.sqrt(d[:, None])))

def guide():
    eta = pyro.param("eta", lambda: torch.tensor(1, dtype=torch.float32))
    alpha = pyro.param("alpha", lambda: 11 * torch.ones(2, dtype=torch.float32))
    sigma = pyro.param("sigma", lambda: 10 * torch.ones(2, dtype=torch.float32))
    sx = pyro.param("sx", lambda: pyro.distributions.LKJCholesky(2, 1).sample())
    mu_x = pyro.param("mu_x", torch.zeros(2, dtype=torch.float32))
    s = pyro.sample("s", pyro.distributions.LKJCholesky(2, eta))
    d = pyro.sample("d", pyro.distributions.Gamma(alpha, sigma).to_event(1))
    pyro.sample("x", pyro.distributions.MultivariateNormal(mu_x, scale_tril=sx))

elbo = pyro.infer.JitTrace_ELBO()
svi = pyro.infer.SVI(model, guide, pyro.optim.Adam({"lr": 0.005, "betas": (0.95, 0.999)}), loss=elbo)
svi.step()
```
Error output
```
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [2, 2]], which is output 0 of AddBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
```
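For reference, the swap mentioned above really is a one-liner: replacing the loss with the non-JIT Trace_ELBO lets the same model/guide pair step without error on my setup. The anomaly-detection call at the end is just the debugging aid suggested by the error hint, in case someone wants to localize the offending in-place operation; it is not part of the workaround.

```python
# Same model, guide, and optimizer as in the MWE above; only the loss changes.
elbo = pyro.infer.Trace_ELBO()
svi = pyro.infer.SVI(model, guide, pyro.optim.Adam({"lr": 0.005, "betas": (0.95, 0.999)}), loss=elbo)
svi.step()  # completes without the RuntimeError

# Per the error hint: with JitTrace_ELBO, enabling anomaly detection
# before svi.step() should point at the in-place op that broke the graph.
torch.autograd.set_detect_anomaly(True)
```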