diff --git a/README.md b/README.md index 6c3a7fe..5624008 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,10 @@ We provide a [colab notebook](experiments/many_well/fab_many_well.ipynb) compari 6 dimensional Many Well problem, where the difference between the two methods is apparent after a short (<5 min) training period. This experiment can be run locally on a laptop using just CPU. +Additionally, we provide the colab notebook +Open In Colab +which demos inference with the flow trained with FAB (+prioritised buffer) on the 32 dim Many Well problem. + To run the experiment for the FAB with a prioritised replay buffer (for the first seed) on the 32 dimensional Many Well problem, use the following command: ``` diff --git a/demo/many_well.ipynb b/demo/many_well.ipynb new file mode 100644 index 0000000..a584dbd --- /dev/null +++ b/demo/many_well.ipynb @@ -0,0 +1,197 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "\"Open" + ], + "metadata": { + "collapsed": false + }, + "id": "a3e5dcf4be772f9b" + }, + { + "cell_type": "markdown", + "source": [ + "# Install fab-torch repo" + ], + "metadata": { + "id": "-2z7-wbmQgVS" + }, + "id": "-2z7-wbmQgVS" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "94b5026c-6d96-464c-894e-c3d2be6ec58b", + "metadata": { + "id": "94b5026c-6d96-464c-894e-c3d2be6ec58b" + }, + "outputs": [], + "source": [ + "# If using colab then run this cell.\n", + "!git clone https://github.com/lollcat/fab-torch\n", + "\n", + "import os\n", + "os.chdir(\"fab-torch\")\n", + "\n", + "!pip install --upgrade ." + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Download weights from huggingface and run example of inference\n", + "We can just use CPU as the model is not that expensive." + ], + "metadata": { + "id": "xy2GWTB7QlxO" + }, + "id": "xy2GWTB7QlxO" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d861b2f8-00be-4e14-998c-348ec89d1c89", + "metadata": { + "id": "d861b2f8-00be-4e14-998c-348ec89d1c89" + }, + "outputs": [], + "source": [ + "# Restart after install, then run the below code\n", + "import os\n", + "import urllib\n", + "\n", + "import matplotlib.pyplot as plt\n", + "from matplotlib import rc\n", + "import matplotlib as mpl\n", + "from hydra import compose, initialize\n", + "import torch\n", + "\n", + "from fab.utils.plotting import plot_contours, plot_marginal_pair\n", + "from fab.target_distributions.many_well import ManyWellEnergy\n", + "from experiments.setup_run import setup_model\n", + "from experiments.many_well.many_well_visualise_all_marginal_pairs import get_target_log_prob_marginal_pair" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "66a8cf07-5d35-4368-a320-fc63e9842d7d", + "metadata": { + "id": "66a8cf07-5d35-4368-a320-fc63e9842d7d" + }, + "outputs": [], + "source": [ + "with initialize(version_base=None, config_path=\"fab-torch/experiments/config/\", job_name=\"colab_app\"):\n", + " cfg = compose(config_name=f\"many_well\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8cc811fd-2208-4111-b428-dfcad4b0bf7c", + "metadata": { + "id": "8cc811fd-2208-4111-b428-dfcad4b0bf7c" + }, + "outputs": [], + "source": [ + "target = ManyWellEnergy(cfg.target.dim, a=-0.5, b=-6, use_gpu=False)\n", + "model = setup_model(cfg, target)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c19b7e70-c0ba-49d0-8aa0-ed77001c8e95", + "metadata": { + "id": "c19b7e70-c0ba-49d0-8aa0-ed77001c8e95" + }, + "outputs": [], + "source": [ + "# Download weights from huggingface, and load them into the model\n", + "urllib.request.urlretrieve('https://huggingface.co/VincentStimper/fab/resolve/main/many_well/model.pt', 'model.pt')\n", + "model.load(\"model.pt\", map_location=\"cpu\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f4077f8a-d5bc-45f5-b401-7b807c4e68ed", + "metadata": { + "id": "f4077f8a-d5bc-45f5-b401-7b807c4e68ed" + }, + "outputs": [], + "source": [ + "# Sample from the model\n", + "n_samples: int = 200\n", + "samples_flow = model.flow.sample((n_samples,)).detach()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f338bd95-3ea2-4df7-a700-cb8529a26914", + "metadata": { + "id": "f338bd95-3ea2-4df7-a700-cb8529a26914" + }, + "outputs": [], + "source": [ + "# Visualise samples\n", + "alpha = 0.3\n", + "plotting_bounds = (-3, 3)\n", + "dim = cfg.target.dim\n", + "fig, axs = plt.subplots(2, 2, sharex=\"row\", sharey=\"row\")\n", + "\n", + "for i in range(2):\n", + " for j in range(2):\n", + " target_log_prob = get_target_log_prob_marginal_pair(target.log_prob, i, j + 2, dim)\n", + " plot_contours(target_log_prob, bounds=plotting_bounds, ax=axs[i, j],\n", + " n_contour_levels=20, grid_width_n_points=100)\n", + " plot_marginal_pair(samples_flow, marginal_dims=(i, j+2),\n", + " ax=axs[i, j], bounds=plotting_bounds, alpha=alpha)\n", + "\n", + "\n", + " if j == 0:\n", + " axs[i, j].set_ylabel(f\"$x_{i + 1}$\")\n", + " if i == 1:\n", + " axs[i, j].set_xlabel(f\"$x_{j + 1 + 2}$\")" + ] + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "MBK1xcr8UDui" + }, + "id": "MBK1xcr8UDui", + "execution_count": null, + "outputs": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.18" + }, + "colab": { + "provenance": [], + "include_colab_link": true + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/fab/core.py b/fab/core.py index c5f2cc1..9765d06 100644 --- a/fab/core.py +++ b/fab/core.py @@ -235,10 +235,13 @@ def load(self, checkpoint = torch.load(path, map_location=map_location) try: self.flow.load_state_dict(checkpoint['flow']) - except RuntimeError: - # If flow is incorretly loaded then this will mess up evaluation, so raise Error. - raise RuntimeError('Flow could not be loaded. ' - 'Perhaps there is a mismatch in the architectures.') + except: + try: + self.flow._nf_model.load_state_dict(checkpoint['flow']) + except RuntimeError: + # If flow is incorretly loaded then this will mess up evaluation, so raise Error. + raise RuntimeError('Flow could not be loaded. ' + 'Perhaps there is a mismatch in the architectures.') try: self.transition_operator.load_state_dict(checkpoint['trans_op']) except RuntimeError: