
Pyro sets all Parameters to requires_grad=True, when using PyroModule #3438

@heborras

Issue Description

Hi all,
apparently the bug initially seen in #1778 reappears when using a PyroModule and running a simple inference.

That is, parameters with requires_grad=False set are still added to the param_store and are additionally set to requires_grad=True, making them learnable even though they should not be.
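For reference, here is the plain PyTorch behavior the bug breaks (a quick sanity check, not part of the repro): a Parameter constructed with requires_grad=False stays frozen unless someone explicitly flips the flag.

```python
import torch
from torch.nn import Parameter

# Plain PyTorch semantics: requires_grad=False is preserved on the
# Parameter itself; nothing in torch silently re-enables it.
p = Parameter(torch.zeros(1), requires_grad=False)
print(p.requires_grad)  # False
```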

Environment


  • Debian, Python 3.11.6
  • PyTorch version: 2.2.1
  • Pyro version: 1.9.1

Code Snippet

Here is a quick example based on the existing test for the previous issue:

import torch
from torch.nn import Parameter

import pyro
from pyro.nn import PyroModule
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDelta

# Create a generic PyroModule with some non-learnable parameters
class net_class(PyroModule):
    def __init__(self):
        super().__init__()
        self.x = Parameter(torch.zeros(1))
        self.y = Parameter(torch.zeros(1), requires_grad=False)
    
    def forward(self, s, y=None):
        s += self.x
        s += self.y
        
        return s

# Initialize model
net = net_class()

# Create guide
guide = AutoDelta(net)

# Create optimizer
optimizer = torch.optim.Adam
scheduler = pyro.optim.CosineAnnealingLR({'optimizer': optimizer, 'optim_args': {'lr': 1e-3}, 'T_max': 250})

# Create loss
loss_method = Trace_ELBO(num_particles=1)

# Create svi object
svi = SVI(net, guide, scheduler, loss=loss_method)

# Run an eval step to populate the param_store
net.eval()
loss = svi.evaluate_loss(torch.tensor([1.0, 2.0]), torch.tensor([1.0, 1.0]))

# Test for parameters in param_store
print('Parameters in Pyro param_store:')
for key in pyro.get_param_store().keys():
    print(key, pyro.get_param_store()[key])

assert "x" in pyro.get_param_store().keys()
assert "y" not in pyro.get_param_store().keys()

This will print:

Parameters in Pyro param_store:
x Parameter containing:
tensor([0.], requires_grad=True)
y Parameter containing:
tensor([0.], requires_grad=True)

This means that y is now learnable and part of the param_store, even though it should be neither.

Could a similar workaround to the one in #1779 work here as well?
