Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions pyro/params/param_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ def get_param(
init_tensor: Optional[torch.Tensor] = None,
constraint: constraints.Constraint = constraints.real,
event_dim: Optional[int] = None,
parametrization: Optional[str] = None,
) -> torch.Tensor:
"""
Get parameter from its name. If it does not yet exist in the
Expand All @@ -246,9 +247,19 @@ def get_param(
:rtype: torch.Tensor
"""
if init_tensor is None:
return self[name]
param = self[name]
else:
return self.setdefault(name, init_tensor, constraint)
param = self.setdefault(name, init_tensor, constraint)
# Apply parametrization if requested
if parametrization == "orthogonal":
import torch.nn.utils.parametrizations as parametrizations

if (
not hasattr(param, "parametrizations")
or "orthogonal" not in param.parametrizations
):
param = parametrizations.orthogonal(param)
return param

def match(self, name: str) -> Dict[str, torch.Tensor]:
"""
Expand Down
9 changes: 8 additions & 1 deletion pyro/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def param(
init_tensor: Union[torch.Tensor, Callable[[], torch.Tensor], None] = None,
constraint: constraints.Constraint = constraints.real,
event_dim: Optional[int] = None,
parametrization: Optional[str] = None,
) -> torch.Tensor:
"""
Saves the variable as a parameter in the param store.
Expand Down Expand Up @@ -86,7 +87,13 @@ def param(
"""
# Note effectful(-) requires the double passing of name below.
args = (name,) if init_tensor is None else (name, init_tensor)
value = _param(*args, constraint=constraint, event_dim=event_dim, name=name)
value = _param(
*args,
constraint=constraint,
event_dim=event_dim,
name=name,
parametrization=parametrization,
)
assert value is not None # type narrowing guaranteed by _param
return value

Expand Down
42 changes: 42 additions & 0 deletions tests/params/test_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,3 +243,45 @@ def check_constraint(name):
assert_equal(pyro.param("z"), z0)
check_constraint("x0")
check_constraint("z0")


def test_get_param_behaviour():
    """Exercise ParamStoreDict.get_param across its three entry modes.

    Checks, in order:
    - a lookup of an unknown name with no ``init_tensor`` raises ``KeyError``
    - supplying ``init_tensor`` registers the parameter and hands back the
      constrained value (the unconstrained storage holds ``log(init)`` for
      the positive constraint)
    - asking for ``parametrization="orthogonal"`` succeeds and yields a
      grad-enabled tensor of the same shape as the init
    """
    store = pyro.get_param_store()
    store.clear()

    # An absent name with no initializer must surface as KeyError.
    key_error_seen = False
    try:
        store.get_param("missing_without_init")
    except KeyError:
        key_error_seen = True
    assert key_error_seen

    # Registering via init_tensor under a positive constraint: the call
    # returns the constrained value, while the store keeps log-space data.
    init_value = 2.0 * torch.ones(2, 3)
    constrained = store.get_param(
        "p_with_init", init_tensor=init_value, constraint=constraints.positive
    )
    assert "p_with_init" in store
    # Round-trip: constrained output matches what we initialized with.
    assert_equal(constrained, init_value)
    # Positive constraint stores exp^-1 of the value, i.e. its log.
    unconstrained = store._params["p_with_init"]
    assert torch.allclose(unconstrained, torch.log(init_value))

    # The orthogonal parametrization path must run cleanly and keep
    # shape and autograd tracking intact.
    store.clear()
    square_init = torch.randn(4, 4)
    ortho = store.get_param(
        "p2", init_tensor=square_init, parametrization="orthogonal"
    )
    assert isinstance(ortho, torch.Tensor)
    assert ortho.shape == square_init.shape
    # Parametrizing must not detach the tensor from the graph.
    assert ortho.requires_grad
Loading