diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py
index d63d8c6617..6de89f8314 100644
--- a/pymc/logprob/transforms.py
+++ b/pymc/logprob/transforms.py
@@ -37,7 +37,7 @@
 import abc
 
 from copy import copy
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
 
 import pytensor.tensor as at
 
@@ -133,43 +133,55 @@ def transform_values(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
     ``Y`` on the natural scale.
     """
 
-    rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
-    values_to_transforms = getattr(fgraph, "values_to_transforms", None)
+    rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
+    values_to_transforms: Optional[TransformValuesMapping] = getattr(
+        fgraph, "values_to_transforms", None
+    )
 
     if rv_map_feature is None or values_to_transforms is None:
         return None  # pragma: no cover
 
-    try:
-        rv_var = node.default_output()
-        rv_var_out_idx = node.outputs.index(rv_var)
-    except ValueError:
-        return None
+    rv_vars = []
+    value_vars = []
 
-    value_var = rv_map_feature.rv_values.get(rv_var, None)
-    if value_var is None:
+    for out in node.outputs:
+        value = rv_map_feature.rv_values.get(out, None)
+        if value is None:
+            continue
+        rv_vars.append(out)
+        value_vars.append(value)
+
+    if not value_vars:
         return None
 
-    transform = values_to_transforms.get(value_var, None)
+    transforms = [values_to_transforms.get(value_var, None) for value_var in value_vars]
 
-    if transform is None:
+    if all(transform is None for transform in transforms):
         return None
 
-    new_op = _create_transformed_rv_op(node.op, transform)
+    new_op = _create_transformed_rv_op(node.op, transforms)
     # Create a new `Apply` node and outputs
     trans_node = node.clone()
     trans_node.op = new_op
-    trans_node.outputs[rv_var_out_idx].name = node.outputs[rv_var_out_idx].name
 
     # We now assume that the old value variable represents the *transformed space*.
     # This means that we need to replace all instance of the old value variable
     # with "inversely/un-" transformed versions of itself.
-    new_value_var = transformed_variable(
-        transform.backward(value_var, *trans_node.inputs), value_var
-    )
-    if value_var.name and getattr(transform, "name", None):
-        new_value_var.name = f"{value_var.name}_{transform.name}"
+    for rv_var, value_var, transform in zip(rv_vars, value_vars, transforms):
+        rv_var_out_idx = node.outputs.index(rv_var)
+        trans_node.outputs[rv_var_out_idx].name = rv_var.name
 
-    rv_map_feature.update_rv_maps(rv_var, new_value_var, trans_node.outputs[rv_var_out_idx])
+        if transform is None:
+            continue
+
+        new_value_var = transformed_variable(
+            transform.backward(value_var, *trans_node.inputs), value_var
+        )
+
+        if value_var.name and getattr(transform, "name", None):
+            new_value_var.name = f"{value_var.name}_{transform.name}"
+
+        rv_map_feature.update_rv_maps(rv_var, new_value_var, trans_node.outputs[rv_var_out_idx])
 
     return trans_node.outputs
 
@@ -549,7 +561,7 @@ def log_jac_det(self, value, *inputs):
 
 def _create_transformed_rv_op(
     rv_op: Op,
-    transform: RVTransform,
+    transforms: Union[RVTransform, Sequence[Union[None, RVTransform]]],
     *,
     cls_dict_extra: Optional[Dict] = None,
 ) -> Op:
@@ -572,14 +584,20 @@
 
     """
-    trans_name = getattr(transform, "name", "transformed")
+    if not isinstance(transforms, Sequence):
+        transforms = (transforms,)
+
+    trans_names = [
+        getattr(transform, "name", "transformed") if transform is not None else "None"
+        for transform in transforms
+    ]
     rv_op_type = type(rv_op)
     rv_type_name = rv_op_type.__name__
     cls_dict = rv_op_type.__dict__.copy()
     rv_name = cls_dict.get("name", "")
     if rv_name:
-        cls_dict["name"] = f"{rv_name}_{trans_name}"
-    cls_dict["transform"] = transform
+        cls_dict["name"] = f"{rv_name}_{'_'.join(trans_names)}"
+    cls_dict["transforms"] = transforms
 
     if cls_dict_extra is not None:
         cls_dict.update(cls_dict_extra)
 
@@ -595,17 +613,27 @@ def transformed_logprob(op, values, *inputs, use_jacobian=True, **kwargs):
 
         We assume that the value variable was back-transformed to be on the natural
         support of the respective random variable.
""" - (value,) = values + logprobs = _logprob(rv_op, values, *inputs, **kwargs) - logprob = _logprob(rv_op, values, *inputs, **kwargs) + if not isinstance(logprobs, Sequence): + logprobs = [logprobs] if use_jacobian: - assert isinstance(value.owner.op, TransformedVariable) - original_forward_value = value.owner.inputs[1] - jacobian = op.transform.log_jac_det(original_forward_value, *inputs) - logprob += jacobian - - return logprob + assert len(values) == len(logprobs) == len(op.transforms) + logprobs_jac = [] + for value, transform, logprob in zip(values, op.transforms, logprobs): + if transform is None: + logprobs_jac.append(logprob) + continue + assert isinstance(value.owner.op, TransformedVariable) + original_forward_value = value.owner.inputs[1] + jacobian = transform.log_jac_det(original_forward_value, *inputs).copy() + if value.name: + jacobian.name = f"{value.name}_jacobian" + logprobs_jac.append(logprob + jacobian) + logprobs = logprobs_jac + + return logprobs new_op = copy(rv_op) new_op.__class__ = new_op_type diff --git a/pymc/tests/logprob/test_transforms.py b/pymc/tests/logprob/test_transforms.py index 3bb0dc5d6b..474f8e4bb7 100644 --- a/pymc/tests/logprob/test_transforms.py +++ b/pymc/tests/logprob/test_transforms.py @@ -42,10 +42,12 @@ import scipy.special from numdifftools import Jacobian +from pytensor.compile.builders import OpFromGraph from pytensor.graph.basic import equal_computations from pytensor.graph.fg import FunctionGraph from pymc.distributions.transforms import _default_transform, log, logodds +from pymc.logprob.abstract import MeasurableVariable, _get_measurable_outputs, _logprob from pymc.logprob.joint_logprob import factorized_joint_logprob, joint_logprob from pymc.logprob.transforms import ( ChainedTransform, @@ -437,6 +439,66 @@ def test_default_transform_multiout(): ) +@pytest.fixture(scope="module") +def multiout_measurable_op(): + # Create a dummy Op that just returns the two inputs + mu1, mu2 = at.scalars("mu1", "mu2") + + class TestOpFromGraph(OpFromGraph): + def do_constant_folding(self, fgraph, node): + False + + multiout_op = TestOpFromGraph([mu1, mu2], [mu1 + 0.0, mu2 + 0.0]) + + MeasurableVariable.register(TestOpFromGraph) + + @_logprob.register(TestOpFromGraph) + def logp_multiout(op, values, mu1, mu2): + value1, value2 = values + return value1 + mu1, value2 + mu2 + + @_get_measurable_outputs.register(TestOpFromGraph) + def measurable_multiout_op_outputs(op, node): + return node.outputs + + return multiout_op + + +@pytest.mark.parametrize("transform_x", (True, False)) +@pytest.mark.parametrize("transform_y", (True, False)) +def test_nondefault_transform_multiout(transform_x, transform_y, multiout_measurable_op): + x, y = multiout_measurable_op(1, 2) + x.name = "x" + y.name = "y" + x_vv = x.clone() + y_vv = y.clone() + + transform_rewrite = TransformValuesRewrite( + { + x_vv: LogTransform() if transform_x else None, + y_vv: ExpTransform() if transform_y else None, + } + ) + + logp = joint_logprob({x: x_vv, y: y_vv}, extra_rewrites=transform_rewrite) + + x_vv_test = np.random.normal() + y_vv_test = np.abs(np.random.normal()) + + expected_logp = 0 + if not transform_x: + expected_logp += x_vv_test + 1 + else: + expected_logp += np.exp(x_vv_test) + 1 + x_vv_test + # y logp + if not transform_y: + expected_logp += y_vv_test + 2 + else: + expected_logp += np.log(y_vv_test) + 2 - np.log(y_vv_test) + + np.testing.assert_almost_equal(logp.eval({x_vv: x_vv_test, y_vv: y_vv_test}), expected_logp) + + def test_TransformValuesMapping(): x = at.vector() fg = 
FunctionGraph(outputs=[x])
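
For reviewers, a minimal usage sketch of the value-transform mapping this patch generalizes (not part of the patch; the lognormal variable, names, and test point are illustrative, single-output case shown for brevity):

# Usage sketch: values mapped to None keep their natural support, while
# values mapped to a transform are rewritten to the transformed space.
import pytensor.tensor as at

from pymc.logprob.joint_logprob import joint_logprob
from pymc.logprob.transforms import LogTransform, TransformValuesRewrite

x = at.random.lognormal(0, 1, name="x")
x_vv = x.clone()
x_vv.name = "x_log"

# Request a log transform for x's value variable; the resulting logp graph
# is a function of the log-space value and includes the Jacobian term.
transform_rewrite = TransformValuesRewrite({x_vv: LogTransform()})
logp = joint_logprob({x: x_vv}, extra_rewrites=transform_rewrite)

print(logp.eval({x_vv: 0.5}))  # evaluated at a log-space point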