diff --git a/notebooks/auto-center-uncenter.ipynb b/notebooks/auto-center-uncenter.ipynb
new file mode 100644
index 0000000000..e0510c4bf2
--- /dev/null
+++ b/notebooks/auto-center-uncenter.ipynb
@@ -0,0 +1,496 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "id": "112b01e8-9263-43c6-a7f7-ed486e914b5d",
+   "metadata": {},
+   "source": [
+    "import numpy as np\n",
+    "import pytensor\n",
+    "import pytensor.tensor as pt\n",
+    "\n",
+    "from pytensor.graph import FunctionGraph\n",
+    "\n",
+    "import pymc as pm\n",
+    "\n",
+    "from pymc.model.transform.conditioning import remove_value_transforms\n",
+    "from pymc.pytensorf import toposort_replace"
+   ],
+   "outputs": [],
+   "execution_count": 1
+  },
+  {
+   "cell_type": "code",
+   "id": "54eb55c842adb7d9",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-01-14T13:03:11.416319Z",
+     "start_time": "2025-01-14T13:03:11.413613Z"
+    }
+   },
+   "source": [
+    "# %load_ext autoreload\n",
+    "# %autoreload 2"
+   ],
+   "outputs": [],
+   "execution_count": 2
+  },
+  {
+   "cell_type": "code",
+   "id": "656ab86e-7a7a-4afe-8423-8ee2a6ff89fa",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-01-14T13:03:11.514485Z",
+     "start_time": "2025-01-14T13:03:11.509320Z"
+    }
+   },
+   "source": [
+    "class CenterTransform(pm.distributions.transforms.Transform):\n",
+    "    ndim_supp = 0\n",
+    "\n",
+    "    name = \"CenterTransform\"\n",
+    "\n",
+    "    def __init__(self, trafo_param, sigma_fn=lambda args: args[-1]):\n",
+    "        self._trafo_params = (trafo_param,)\n",
+    "        self._trafo_param = trafo_param\n",
+    "        self.sigma_fn = sigma_fn\n",
+    "\n",
+    "    def get_sigma(self, args):\n",
+    "        # *rv_params, hyper = args\n",
+    "        return self.sigma_fn(args)\n",
+    "\n",
+    "    def get_hyperparam(self):\n",
+    "        return pt.sigmoid(self._trafo_param)\n",
+    "\n",
+    "    def forward(self, x, *params):\n",
+    "        sigma = self.get_sigma(params)\n",
+    "        hyper = self.get_hyperparam()\n",
+    "        return x / (sigma**hyper)\n",
+    "\n",
+    "    def backward(self, y, *params):\n",
+    "        sigma = self.get_sigma(params)\n",
+    "        hyper = self.get_hyperparam()\n",
+    "        return y * (sigma**hyper)\n",
+    "\n",
+    "\n",
+    "def forward_and_grad(transform, constrained_point, constrained_grad, *params):\n",
+    "    if transform is None:\n",
+    "        return constrained_point.copy(), constrained_grad.copy(), 0.0\n",
+    "\n",
+    "    unconstrained_point = transform.forward(constrained_point, *params)\n",
+    "    # Redefine forward so that L_op considers these as separate paths\n",
+    "    backward_log_jac_det = transform.log_jac_det(\n",
+    "        transform.forward(constrained_point, *params), *params\n",
+    "    )\n",
+    "    unconstrained_grad = pytensor.gradient.Lop(\n",
+    "        f=[unconstrained_point, backward_log_jac_det],\n",
+    "        wrt=constrained_point,\n",
+    "        eval_points=[constrained_grad, pt.ones_like(backward_log_jac_det)],\n",
+    "    )\n",
+    "\n",
+    "    return unconstrained_point, unconstrained_grad, -backward_log_jac_det"
+   ],
+   "outputs": [],
+   "execution_count": 3
+  },
+  {
+   "cell_type": "markdown",
+   "id": "6814b3de-5287-4077-aad7-951039cd023f",
+   "metadata": {},
+   "source": [
+    "Steps:\n",
+    "\n",
+    "Split the free RVs into two groups:\n",
+    " - learnable: RVs whose transforms have hyperparameters and are bijections\n",
+    " - constant: the remaining RVs\n",
+    "\n",
+    "In the following, treat the constant transforms as usual, i.e. always apply them.\n",
+    "\n",
+    "Compile the following functions:\n",
+    "\n",
+    "- new_transformation:\n",
+    "    Given a point in the untransformed parameter space, initialize the hyperparameters\n",
+    "    and store them in an array.\n",
+    "- transform_position_and_gradient:\n",
+    "    Given a hyperparameter vector and an untransformed position and gradient\n",
+    "    as vectors, compute the transformed position and gradient as vectors. Also\n",
+    "    compute the sum of the logdets of the transforms.\n",
+    "- init_from_untransformed:\n",
+    "    Given a hyperparameter vector and an untransformed point, compute\n",
+    "    - the untransformed total logp and gradient.\n",
+    "    Reuse transform_position_and_gradient to get\n",
+    "    - the transformed point and gradient as vectors and the sum of all logdets.\n",
+    "\n",
+    "- init_from_transformed:\n",
+    "    Given a hyperparameter vector and a transformed position, compute the other three.\n",
+    "    Also compute the total logp and the sum of all logdets.\n",
+    "\n",
+    "- update_transformation:\n",
+    "    Given a set of points and gradients in the untransformed space, return optimized hyperparameters."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "id": "d599592eab8d8a77",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-01-14T13:03:11.572345Z",
+     "start_time": "2025-01-14T13:03:11.551615Z"
+    }
+   },
+   "source": [
+    "with pm.Model() as unconstrained_model:\n",
+    "    sigma = pm.HalfNormal(\"sigma\")\n",
+    "\n",
+    "    x_hyper = pytensor.shared(np.array(0.5), name=\"x_hyper\")\n",
+    "    trafo = CenterTransform(x_hyper)\n",
+    "    pm.Normal(\"x\", mu=0, sigma=sigma, transform=trafo, shape=(3,))"
+   ],
+   "outputs": [],
+   "execution_count": 4
+  },
+  {
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-01-14T13:03:12.278619Z",
+     "start_time": "2025-01-14T13:03:11.600332Z"
+    }
+   },
+   "cell_type": "code",
+   "source": [
+    "constrained_point_value = {\"sigma\": np.exp(-0.3), \"x\": [-1.0, 0.0, 1.0]}\n",
+    "\n",
+    "constrained_model = remove_value_transforms(unconstrained_model)\n",
+    "\n",
+    "constrained_model_logp_value = constrained_model.compile_logp()(constrained_point_value)\n",
+    "\n",
+    "raveled_dlogp = constrained_model.compile_dlogp()(constrained_point_value)\n",
+    "constrained_point_grad_value = {\n",
+    "    \"sigma\": np.asarray(raveled_dlogp[0]),\n",
+    "    \"x\": raveled_dlogp[1:],\n",
+    "}"
+   ],
+   "id": "9241d794e112d2ac",
+   "outputs": [],
+   "execution_count": 5
+  },
+  {
+   "cell_type": "code",
+   "id": "17b1fe0d8e7f4ba3",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-01-14T13:03:12.300498Z",
+     "start_time": "2025-01-14T13:03:12.294451Z"
+    }
+   },
+   "source": [
+    "constrained_point_value, constrained_point_grad_value, constrained_model_logp_value"
+   ],
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "({'sigma': 0.7408182206817179, 'x': [-1.0, 0.0, 1.0]},\n",
+       " {'sigma': array(0.12881158),\n",
+       "  'x': array([ 1.8221188, -0.        , -1.8221188])},\n",
+       " array(-4.17913157))"
+      ]
+     },
+     "execution_count": 6,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "execution_count": 6
+  },
+  {
+   "cell_type": "code",
+   "id": "569efd3f850f197f",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-01-14T13:03:12.507094Z",
+     "start_time": "2025-01-14T13:03:12.352471Z"
+    }
+   },
+   "source": [
+    "logp_f = unconstrained_model.compile_logp()\n",
+    "[x_hyper] = logp_f.f.get_shared()\n",
+    "\n",
+    "# ip = unconstrained_model.initial_point()\n",
+    "ip = {\"sigma_log__\": np.array(-0.3), \"x_CenterTransform__\": np.array([0.0, 0.0, 0.0])}\n",
+    "\n",
+    "x_hyper.set_value(np.array(0.5))\n",
+    "print(logp_f(ip))\n",
+    "\n",
+    "x_hyper.set_value(np.array(1 - 0.99999))\n",
+    "print(logp_f(ip))"
+   ],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "-3.2172261683874277\n",
+      "-3.107015020305758\n"
+     ]
+    }
+   ],
+   "execution_count": 7
+  },
+  {
+   "cell_type": "code",
+   "id": "f0b399a92b1b3fae",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-01-14T13:03:12.610505Z",
+     "start_time": "2025-01-14T13:03:12.520626Z"
+    }
+   },
+   "source": [
+    "# Avoid mutating variables in place\n",
+    "um_copy = unconstrained_model.copy()\n",
+    "\n",
+    "constrained_points = []  # root variables, created in the loop\n",
+    "constrained_grads = []  # root variables, created in the loop\n",
+    "unconstrained_points = []\n",
+    "unconstrained_grads = []\n",
+    "sum_log_det_jacobians = 0.0\n",
+    "for rv in um_copy.free_RVs:\n",
+    "    transform = um_copy.rvs_to_transforms[rv]\n",
+    "    constrained_point = rv.type(name=rv.name)\n",
+    "    constrained_grad = rv.type(name=f\"{rv.name}_grad\")\n",
+    "    unconstrained_point, unconstrained_grad, log_det_jacobian = forward_and_grad(\n",
+    "        transform, constrained_point, constrained_grad, *rv.owner.inputs\n",
+    "    )\n",
+    "    unconstrained_point.name = f\"{rv.name}_unconstrained\"\n",
+    "    unconstrained_grad.name = f\"{rv.name}_grad_unconstrained\"\n",
+    "\n",
+    "    constrained_points.append(constrained_point)\n",
+    "    constrained_grads.append(constrained_grad)\n",
+    "    unconstrained_points.append(unconstrained_point)\n",
+    "    unconstrained_grads.append(unconstrained_grad)\n",
+    "    sum_log_det_jacobians += log_det_jacobian.sum()\n",
+    "\n",
+    "# Replace rvs by the constrained_points\n",
+    "fgraph = FunctionGraph(\n",
+    "    outputs=[*unconstrained_points, *unconstrained_grads, sum_log_det_jacobians], clone=False\n",
+    ")\n",
+    "toposort_replace(fgraph, tuple(zip(um_copy.free_RVs, constrained_points)))\n",
+    "\n",
+    "# From constrained space to unconstrained\n",
+    "pullback_grads_f = pytensor.function(\n",
+    "    [*constrained_points, *constrained_grads],\n",
+    "    fgraph.outputs,\n",
+    "    # mode=get_mode(\"FAST_RUN\").excluding(\"fusion\"),\n",
+    ")"
+   ],
+   "outputs": [],
+   "execution_count": 8
+  },
+  {
+   "cell_type": "code",
+   "id": "7258a482fc56dc8",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-01-14T13:03:12.628077Z",
+     "start_time": "2025-01-14T13:03:12.623455Z"
+    }
+   },
+   "source": [
+    "x_hyper.set_value(np.array(0.5))\n",
+    "pullback_grads_f(*constrained_point_value.values(), *constrained_point_grad_value.values())"
+   ],
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[array(-0.3),\n",
+       " array([-1.20531121,  0.        ,  1.20531121]),\n",
+       " array(1.52373625),\n",
+       " array([ 2.19622022, -0.        , -2.19622022]),\n",
+       " array(0.8602134)]"
+      ]
+     },
+     "execution_count": 9,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "execution_count": 9
+  },
+  {
+   "cell_type": "code",
+   "id": "db61d7a3ce7d4ac9",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-01-14T13:03:12.679786Z",
+     "start_time": "2025-01-14T13:03:12.675221Z"
+    }
+   },
+   "source": [
+    "x_hyper.set_value(np.array(1.0))\n",
+    "pullback_grads_f(*constrained_point_value.values(), *constrained_point_grad_value.values())"
+   ],
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[array(-0.3),\n",
+       " array([-1.24522667,  0.        ,  1.24522667]),\n",
+       " array(1.52373625),\n",
+       " array([ 2.26895092, -0.        , -2.26895092]),\n",
+       " array(0.95795272)]"
+      ]
+     },
+     "execution_count": 10,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "execution_count": 10
+  },
+  {
+   "cell_type": "code",
+   "id": "b9d1d6dc031c7b14",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-01-14T13:03:12.956020Z",
+     "start_time": "2025-01-14T13:03:12.785112Z"
+    }
+   },
+   "source": [
+    "loss = pt.add(*[pt.sum((g + p) ** 2) for g, p in zip(unconstrained_grads, unconstrained_points)])\n",
+    "loss_grad = pt.grad(loss, wrt=x_hyper)\n",
+    "\n",
+    "loss_fn = pytensor.function([*constrained_points, *constrained_grads], [loss, loss_grad])"
+   ],
+   "outputs": [],
+   "execution_count": 11
+  },
+  {
+   "cell_type": "code",
+   "id": "cd2959c99811feea",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-01-14T13:03:12.973412Z",
+     "start_time": "2025-01-14T13:03:12.969257Z"
+    }
+   },
+   "source": [
+    "x_hyper.set_value(np.array(0.5))\n",
+    "loss_fn(*constrained_point_value.values(), *constrained_point_grad_value.values())"
+   ],
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[array(3.46133173), array(0.27690036)]"
+      ]
+     },
+     "execution_count": 12,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "execution_count": 12
+  },
+  {
+   "cell_type": "code",
+   "id": "e429821d68921466",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-01-14T13:03:13.079518Z",
+     "start_time": "2025-01-14T13:03:13.074282Z"
+    }
+   },
+   "source": [
+    "x_hyper.set_value(x_hyper.get_value() - 0.2769 * 10)\n",
+    "loss_fn(*constrained_point_value.values(), *constrained_point_grad_value.values())"
+   ],
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[array(2.92748161), array(0.07287526)]"
+      ]
+     },
+     "execution_count": 13,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "execution_count": 13
+  },
+  {
+   "cell_type": "markdown",
+   "id": "bd50ed41d2a830d1",
+   "metadata": {},
+   "source": [
+    "    fn inv_transform_normalize(\n",
+    "        &mut self,\n",
+    "        params: &Self::TransformParams,\n",
+    "        untransformed_position: &Self::Vector,\n",
+    "        untransformed_gradient: &Self::Vector,\n",
+    "        transformed_position: &mut Self::Vector,\n",
+    "        transformed_gradient: &mut Self::Vector,\n",
+    "    ) -> Result<f64, Self::LogpErr>;\n",
+    "\n",
+    "    fn init_from_untransformed_position(\n",
+    "        &mut self,\n",
+    "        params: &Self::TransformParams,\n",
+    "        untransformed_position: &Self::Vector,\n",
+    "        untransformed_gradient: &mut Self::Vector,\n",
+    "        transformed_position: &mut Self::Vector,\n",
+    "        transformed_gradient: &mut Self::Vector,\n",
+    "    ) -> Result<(f64, f64), Self::LogpErr>;\n",
+    "\n",
+    "    fn init_from_transformed_position(\n",
+    "        &mut self,\n",
+    "        params: &Self::TransformParams,\n",
+    "        untransformed_position: &mut Self::Vector,\n",
+    "        untransformed_gradient: &mut Self::Vector,\n",
+    "        transformed_position: &Self::Vector,\n",
+    "        transformed_gradient: &mut Self::Vector,\n",
+    "    ) -> Result<(f64, f64), Self::LogpErr>;\n",
+    "\n",
+    "    fn update_transformation<'a, R: rand::Rng + ?Sized>(\n",
+    "        &'a mut self,\n",
+    "        rng: &mut R,\n",
+    "        untransformed_positions: impl ExactSizeIterator,\n",
+    "        untransformed_gradients: impl ExactSizeIterator,\n",
+    "        untransformed_logps: impl ExactSizeIterator,\n",
+    "        params: &'a mut Self::TransformParams,\n",
+    "    ) -> Result<(), Self::LogpErr>;\n",
+    "\n",
+    "    fn new_transformation<R: rand::Rng + ?Sized>(\n",
+    "        &mut self,\n",
+    "        rng: &mut R,\n",
+    "        untransformed_position: &Self::Vector,\n",
+    "        untransformed_gradient: &Self::Vector,\n",
+    "        chain: u64,\n",
+    "    ) -> Result<Self::TransformParams, Self::LogpErr>;\n"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.8"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/pymc/logprob/transform_value.py b/pymc/logprob/transform_value.py
index 1b5d4cd817..3e76b86cc9 100644
--- a/pymc/logprob/transform_value.py
+++ b/pymc/logprob/transform_value.py
@@ -162,6 +162,7 @@ def transform_values(fgraph: FunctionGraph, node: Apply) -> list[Apply] | None:
     # This means that we need to replace all instance of the old value variable
     # with "inversely/un-" transformed versions of itself.
     replacements = {}
+    new_surprising_vars = set()
     for valued_node, transformed_rv, transform in zip(
         valued_nodes, transformed_rv_node.outputs, transforms
     ):
@@ -172,6 +173,9 @@
             transformed_val = value
         else:
+            if hasattr(transform, "_trafo_params"):
+                new_surprising_vars.update(transform._trafo_params)
+
             transformed_val = transformed_value(
                 transform.backward(value, *rv.owner.inputs),
                 value,
@@ -184,6 +188,10 @@
         replacements[val_rv] = valued_rv(transformed_rv, transformed_val)
 
+    for new_surprising_var in new_surprising_vars:
+        # print("Importing hyper parameter: ", new_surprising_var)
+        fgraph.import_var(new_surprising_var, import_missing=True)
+
     return replacements
diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py
index f665d5931c..b412576a4c 100644
--- a/pymc/pytensorf.py
+++ b/pymc/pytensorf.py
@@ -1128,7 +1128,7 @@ def _nested_toposort_index(var, fgraph_toposort) -> tuple[int]:
     if not var.owner:
         return (-1,)
 
-    index = fgraph_toposort[var.owner]
+    index = fgraph_toposort.get(var.owner, -1)
 
     # Recurse into OpFromGraphs
     # TODO: Could also recurse into Scans
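
Note on the arithmetic that `CenterTransform` implements: the standalone NumPy sketch below (the names `center_forward`, `center_backward`, and `center_log_jac_det` are illustrative, not part of the patch) shows the map the transform applies. With `hyper = sigmoid(x_hyper)` in (0, 1), it interpolates between the centered parametrization (`hyper` near 0, identity) and the non-centered one (`hyper` near 1, division by `sigma`). For `x_hyper = 0.5` and `sigma = exp(-0.3)` it reproduces the `[-1.20531121, 0., 1.20531121]` values in the `pullback_grads_f` output above.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def center_forward(x, sigma, trafo_param):
    # constrained -> unconstrained: interpolates between centered
    # (hyper ~ 0, identity) and non-centered (hyper ~ 1, x / sigma)
    hyper = sigmoid(trafo_param)
    return x / sigma**hyper


def center_backward(y, sigma, trafo_param):
    # unconstrained -> constrained: inverse of center_forward
    hyper = sigmoid(trafo_param)
    return y * sigma**hyper


def center_log_jac_det(y, sigma, trafo_param):
    # log |d backward / dy|: backward scales every element by
    # sigma**hyper, so each element contributes hyper * log(sigma)
    hyper = sigmoid(trafo_param)
    return y.size * hyper * np.log(sigma)


x = np.array([-1.0, 0.0, 1.0])
sigma, trafo_param = np.exp(-0.3), 0.5
y = center_forward(x, sigma, trafo_param)  # approx [-1.20531121, 0., 1.20531121]
assert np.allclose(center_backward(y, sigma, trafo_param), x)
```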