Skip to content

Allow transforms to work with multiple-valued nodes #6341

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 60 additions & 32 deletions pymc/logprob/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
import abc

from copy import copy
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union

import pytensor.tensor as at

Expand Down Expand Up @@ -133,43 +133,55 @@ def transform_values(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
``Y`` on the natural scale.
"""

rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
values_to_transforms = getattr(fgraph, "values_to_transforms", None)
rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
values_to_transforms: Optional[TransformValuesMapping] = getattr(
fgraph, "values_to_transforms", None
)

if rv_map_feature is None or values_to_transforms is None:
return None # pragma: no cover

try:
rv_var = node.default_output()
rv_var_out_idx = node.outputs.index(rv_var)
except ValueError:
return None
rv_vars = []
value_vars = []

value_var = rv_map_feature.rv_values.get(rv_var, None)
if value_var is None:
for out in node.outputs:
value = rv_map_feature.rv_values.get(out, None)
if value is None:
continue
rv_vars.append(out)
value_vars.append(value)

if not value_vars:
return None

transform = values_to_transforms.get(value_var, None)
transforms = [values_to_transforms.get(value_var, None) for value_var in value_vars]

if transform is None:
if all(transform is None for transform in transforms):
return None

new_op = _create_transformed_rv_op(node.op, transform)
new_op = _create_transformed_rv_op(node.op, transforms)
# Create a new `Apply` node and outputs
trans_node = node.clone()
trans_node.op = new_op
trans_node.outputs[rv_var_out_idx].name = node.outputs[rv_var_out_idx].name

# We now assume that the old value variable represents the *transformed space*.
# This means that we need to replace all instance of the old value variable
# with "inversely/un-" transformed versions of itself.
new_value_var = transformed_variable(
transform.backward(value_var, *trans_node.inputs), value_var
)
if value_var.name and getattr(transform, "name", None):
new_value_var.name = f"{value_var.name}_{transform.name}"
for rv_var, value_var, transform in zip(rv_vars, value_vars, transforms):
rv_var_out_idx = node.outputs.index(rv_var)
trans_node.outputs[rv_var_out_idx].name = rv_var.name

rv_map_feature.update_rv_maps(rv_var, new_value_var, trans_node.outputs[rv_var_out_idx])
if transform is None:
continue

new_value_var = transformed_variable(
transform.backward(value_var, *trans_node.inputs), value_var
)

if value_var.name and getattr(transform, "name", None):
new_value_var.name = f"{value_var.name}_{transform.name}"

rv_map_feature.update_rv_maps(rv_var, new_value_var, trans_node.outputs[rv_var_out_idx])

return trans_node.outputs

Expand Down Expand Up @@ -549,7 +561,7 @@ def log_jac_det(self, value, *inputs):

def _create_transformed_rv_op(
rv_op: Op,
transform: RVTransform,
transforms: Union[RVTransform, Sequence[Union[None, RVTransform]]],
*,
cls_dict_extra: Optional[Dict] = None,
) -> Op:
Expand All @@ -572,14 +584,20 @@ def _create_transformed_rv_op(

"""

trans_name = getattr(transform, "name", "transformed")
if not isinstance(transforms, Sequence):
transforms = (transforms,)

trans_names = [
getattr(transform, "name", "transformed") if transform is not None else "None"
for transform in transforms
]
rv_op_type = type(rv_op)
rv_type_name = rv_op_type.__name__
cls_dict = rv_op_type.__dict__.copy()
rv_name = cls_dict.get("name", "")
if rv_name:
cls_dict["name"] = f"{rv_name}_{trans_name}"
cls_dict["transform"] = transform
cls_dict["name"] = f"{rv_name}_{'_'.join(trans_names)}"
cls_dict["transforms"] = transforms

if cls_dict_extra is not None:
cls_dict.update(cls_dict_extra)
Expand All @@ -595,17 +613,27 @@ def transformed_logprob(op, values, *inputs, use_jacobian=True, **kwargs):
We assume that the value variable was back-transformed to be on the natural
support of the respective random variable.
"""
(value,) = values
logprobs = _logprob(rv_op, values, *inputs, **kwargs)

logprob = _logprob(rv_op, values, *inputs, **kwargs)
if not isinstance(logprobs, Sequence):
logprobs = [logprobs]

if use_jacobian:
assert isinstance(value.owner.op, TransformedVariable)
original_forward_value = value.owner.inputs[1]
jacobian = op.transform.log_jac_det(original_forward_value, *inputs)
logprob += jacobian

return logprob
assert len(values) == len(logprobs) == len(op.transforms)
logprobs_jac = []
for value, transform, logprob in zip(values, op.transforms, logprobs):
if transform is None:
logprobs_jac.append(logprob)
continue
assert isinstance(value.owner.op, TransformedVariable)
original_forward_value = value.owner.inputs[1]
jacobian = transform.log_jac_det(original_forward_value, *inputs).copy()
if value.name:
jacobian.name = f"{value.name}_jacobian"
logprobs_jac.append(logprob + jacobian)
logprobs = logprobs_jac

return logprobs

new_op = copy(rv_op)
new_op.__class__ = new_op_type
Expand Down
62 changes: 62 additions & 0 deletions pymc/tests/logprob/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,12 @@
import scipy.special

from numdifftools import Jacobian
from pytensor.compile.builders import OpFromGraph
from pytensor.graph.basic import equal_computations
from pytensor.graph.fg import FunctionGraph

from pymc.distributions.transforms import _default_transform, log, logodds
from pymc.logprob.abstract import MeasurableVariable, _get_measurable_outputs, _logprob
from pymc.logprob.joint_logprob import factorized_joint_logprob, joint_logprob
from pymc.logprob.transforms import (
ChainedTransform,
Expand Down Expand Up @@ -437,6 +439,66 @@ def test_default_transform_multiout():
)


@pytest.fixture(scope="module")
def multiout_measurable_op():
# Create a dummy Op that just returns the two inputs
mu1, mu2 = at.scalars("mu1", "mu2")

class TestOpFromGraph(OpFromGraph):
def do_constant_folding(self, fgraph, node):
False

multiout_op = TestOpFromGraph([mu1, mu2], [mu1 + 0.0, mu2 + 0.0])

MeasurableVariable.register(TestOpFromGraph)

@_logprob.register(TestOpFromGraph)
def logp_multiout(op, values, mu1, mu2):
value1, value2 = values
return value1 + mu1, value2 + mu2

@_get_measurable_outputs.register(TestOpFromGraph)
def measurable_multiout_op_outputs(op, node):
return node.outputs

return multiout_op


@pytest.mark.parametrize("transform_x", (True, False))
@pytest.mark.parametrize("transform_y", (True, False))
def test_nondefault_transform_multiout(transform_x, transform_y, multiout_measurable_op):
x, y = multiout_measurable_op(1, 2)
x.name = "x"
y.name = "y"
x_vv = x.clone()
y_vv = y.clone()

transform_rewrite = TransformValuesRewrite(
{
x_vv: LogTransform() if transform_x else None,
y_vv: ExpTransform() if transform_y else None,
}
)

logp = joint_logprob({x: x_vv, y: y_vv}, extra_rewrites=transform_rewrite)

x_vv_test = np.random.normal()
y_vv_test = np.abs(np.random.normal())

expected_logp = 0
if not transform_x:
expected_logp += x_vv_test + 1
else:
expected_logp += np.exp(x_vv_test) + 1 + x_vv_test
# y logp
if not transform_y:
expected_logp += y_vv_test + 2
else:
expected_logp += np.log(y_vv_test) + 2 - np.log(y_vv_test)

np.testing.assert_almost_equal(logp.eval({x_vv: x_vv_test, y_vv: y_vv_test}), expected_logp)


def test_TransformValuesMapping():
x = at.vector()
fg = FunctionGraph(outputs=[x])
Expand Down