Skip to content

Commit

Permalink
feat: many well demo (#85)
Browse files Browse the repository at this point in the history
* UPDATE README.md

* Created using Colaboratory

* typo

* add backwards compatability to checkpoints

* add notebook

* Created using Colaboratory

* point to main branch

* Add Colab link to README

* Add Colab link to README
  • Loading branch information
lollcat authored Jan 30, 2024
1 parent 710261a commit d7e592f
Show file tree
Hide file tree
Showing 3 changed files with 208 additions and 4 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ We provide a [colab notebook](experiments/many_well/fab_many_well.ipynb) compari
6 dimensional Many Well problem, where the difference between the two methods is apparent after a
short (<5 min) training period. This experiment can be run locally on a laptop using just CPU.

Additionally, we provide the colab notebook
<a href="https://colab.research.google.com/github/lollcat/fab-torch/blob/master/demo/many_well.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
which demos inference with the flow trained with FAB (+prioritised buffer) on the 32 dim Many Well problem.

To run the experiment for the FAB with a prioritised replay buffer (for the first seed) on the
32 dimensional Many Well problem, use the following command:
```
Expand Down
197 changes: 197 additions & 0 deletions demo/many_well.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
{
"cells": [
{
"cell_type": "markdown",
"source": [
"<a href=\"https://colab.research.google.com/github/lollcat/fab-torch/blob/master/demo/many_well.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
],
"metadata": {
"collapsed": false
},
"id": "a3e5dcf4be772f9b"
},
{
"cell_type": "markdown",
"source": [
"# Install fab-torch repo"
],
"metadata": {
"id": "-2z7-wbmQgVS"
},
"id": "-2z7-wbmQgVS"
},
{
"cell_type": "code",
"execution_count": null,
"id": "94b5026c-6d96-464c-894e-c3d2be6ec58b",
"metadata": {
"id": "94b5026c-6d96-464c-894e-c3d2be6ec58b"
},
"outputs": [],
"source": [
"# If using colab then run this cell.\n",
"!git clone https://github.com/lollcat/fab-torch\n",
"\n",
"import os\n",
"os.chdir(\"fab-torch\")\n",
"\n",
"!pip install --upgrade ."
]
},
{
"cell_type": "markdown",
"source": [
"# Download weights from huggingface and run example of inference\n",
"We can just use CPU as the model is not that expensive."
],
"metadata": {
"id": "xy2GWTB7QlxO"
},
"id": "xy2GWTB7QlxO"
},
{
"cell_type": "code",
"execution_count": null,
"id": "d861b2f8-00be-4e14-998c-348ec89d1c89",
"metadata": {
"id": "d861b2f8-00be-4e14-998c-348ec89d1c89"
},
"outputs": [],
"source": [
"# Restart after install, then run the below code\n",
"import os\n",
"import urllib\n",
"\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib import rc\n",
"import matplotlib as mpl\n",
"from hydra import compose, initialize\n",
"import torch\n",
"\n",
"from fab.utils.plotting import plot_contours, plot_marginal_pair\n",
"from fab.target_distributions.many_well import ManyWellEnergy\n",
"from experiments.setup_run import setup_model\n",
"from experiments.many_well.many_well_visualise_all_marginal_pairs import get_target_log_prob_marginal_pair"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "66a8cf07-5d35-4368-a320-fc63e9842d7d",
"metadata": {
"id": "66a8cf07-5d35-4368-a320-fc63e9842d7d"
},
"outputs": [],
"source": [
"with initialize(version_base=None, config_path=\"fab-torch/experiments/config/\", job_name=\"colab_app\"):\n",
" cfg = compose(config_name=f\"many_well\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8cc811fd-2208-4111-b428-dfcad4b0bf7c",
"metadata": {
"id": "8cc811fd-2208-4111-b428-dfcad4b0bf7c"
},
"outputs": [],
"source": [
"target = ManyWellEnergy(cfg.target.dim, a=-0.5, b=-6, use_gpu=False)\n",
"model = setup_model(cfg, target)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c19b7e70-c0ba-49d0-8aa0-ed77001c8e95",
"metadata": {
"id": "c19b7e70-c0ba-49d0-8aa0-ed77001c8e95"
},
"outputs": [],
"source": [
"# Download weights from huggingface, and load them into the model\n",
"urllib.request.urlretrieve('https://huggingface.co/VincentStimper/fab/resolve/main/many_well/model.pt', 'model.pt')\n",
"model.load(\"model.pt\", map_location=\"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f4077f8a-d5bc-45f5-b401-7b807c4e68ed",
"metadata": {
"id": "f4077f8a-d5bc-45f5-b401-7b807c4e68ed"
},
"outputs": [],
"source": [
"# Sample from the model\n",
"n_samples: int = 200\n",
"samples_flow = model.flow.sample((n_samples,)).detach()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f338bd95-3ea2-4df7-a700-cb8529a26914",
"metadata": {
"id": "f338bd95-3ea2-4df7-a700-cb8529a26914"
},
"outputs": [],
"source": [
"# Visualise samples\n",
"alpha = 0.3\n",
"plotting_bounds = (-3, 3)\n",
"dim = cfg.target.dim\n",
"fig, axs = plt.subplots(2, 2, sharex=\"row\", sharey=\"row\")\n",
"\n",
"for i in range(2):\n",
" for j in range(2):\n",
" target_log_prob = get_target_log_prob_marginal_pair(target.log_prob, i, j + 2, dim)\n",
" plot_contours(target_log_prob, bounds=plotting_bounds, ax=axs[i, j],\n",
" n_contour_levels=20, grid_width_n_points=100)\n",
" plot_marginal_pair(samples_flow, marginal_dims=(i, j+2),\n",
" ax=axs[i, j], bounds=plotting_bounds, alpha=alpha)\n",
"\n",
"\n",
" if j == 0:\n",
" axs[i, j].set_ylabel(f\"$x_{i + 1}$\")\n",
" if i == 1:\n",
" axs[i, j].set_xlabel(f\"$x_{j + 1 + 2}$\")"
]
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "MBK1xcr8UDui"
},
"id": "MBK1xcr8UDui",
"execution_count": null,
"outputs": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
},
"colab": {
"provenance": [],
"include_colab_link": true
}
},
"nbformat": 4,
"nbformat_minor": 5
}
11 changes: 7 additions & 4 deletions fab/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,10 +235,13 @@ def load(self,
checkpoint = torch.load(path, map_location=map_location)
try:
self.flow.load_state_dict(checkpoint['flow'])
except RuntimeError:
# If flow is incorretly loaded then this will mess up evaluation, so raise Error.
raise RuntimeError('Flow could not be loaded. '
'Perhaps there is a mismatch in the architectures.')
except:
try:
self.flow._nf_model.load_state_dict(checkpoint['flow'])
except RuntimeError:
# If flow is incorretly loaded then this will mess up evaluation, so raise Error.
raise RuntimeError('Flow could not be loaded. '
'Perhaps there is a mismatch in the architectures.')
try:
self.transition_operator.load_state_dict(checkpoint['trans_op'])
except RuntimeError:
Expand Down

0 comments on commit d7e592f

Please sign in to comment.