Issue Description
Hi all,
Apparently the bug originally reported in #1778 reappears when a PyroModule is used instead and a simple inference is run.
Parameters that have requires_grad=False are still added to the param_store and, additionally, are switched to requires_grad=True. This makes them learnable even though they should not be.
Environment
- Debian, Python 3.11.6
- PyTorch version: 2.2.1
- Pyro version: 1.9.1
Code Snippet
Here is a quick example based on the existing test for the previous issue:
```python
import torch
from torch.nn import Parameter

import pyro
from pyro.nn import PyroModule
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDelta


# Create a generic PyroModule with one learnable and one non-learnable parameter
class net_class(PyroModule):
    def __init__(self):
        super().__init__()
        self.x = Parameter(torch.zeros(1))                       # learnable
        self.y = Parameter(torch.zeros(1), requires_grad=False)  # should stay frozen

    def forward(self, s, y=None):
        # Build a new tensor instead of mutating the input in place
        return s + self.x + self.y


# Initialize model
net = net_class()

# Create guide
guide = AutoDelta(net)

# Create optimizer wrapped in a learning-rate scheduler
optimizer = torch.optim.Adam
scheduler = pyro.optim.CosineAnnealingLR(
    {'optimizer': optimizer, 'optim_args': {'lr': 1e-3}, 'T_max': 250}
)

# Create loss
loss_method = Trace_ELBO(num_particles=1)

# Create SVI object
svi = SVI(net, guide, scheduler, loss=loss_method)

# Run an eval step to populate the param_store
net.eval()
loss = svi.evaluate_loss(torch.tensor([1.0, 2.0]), torch.tensor([1.0, 1.0]))

# Check which parameters ended up in the param_store
print('Parameters in Pyro param_store:')
for key in pyro.get_param_store().keys():
    print(key, pyro.get_param_store()[key])

assert "x" in pyro.get_param_store().keys()
assert "y" not in pyro.get_param_store().keys()  # currently fails: y is registered
```

This will print:

```
Parameters in Pyro param_store:
x Parameter containing:
tensor([0.], requires_grad=True)
y Parameter containing:
tensor([0.], requires_grad=True)
```
This means that y is now learnable and part of the param_store, even though it should be neither.
Could a workaround similar to the one in #1779 work here as well?