-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* UPDATE README.md * Created using Colaboratory * typo * add backwards compatability to checkpoints * add notebook * Created using Colaboratory * point to main branch * Add Colab link to README * Add Colab link to README
- Loading branch information
Showing
3 changed files
with
208 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,197 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"<a href=\"https://colab.research.google.com/github/lollcat/fab-torch/blob/master/demo/many_well.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | ||
], | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"id": "a3e5dcf4be772f9b" | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"# Install fab-torch repo" | ||
], | ||
"metadata": { | ||
"id": "-2z7-wbmQgVS" | ||
}, | ||
"id": "-2z7-wbmQgVS" | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "94b5026c-6d96-464c-894e-c3d2be6ec58b", | ||
"metadata": { | ||
"id": "94b5026c-6d96-464c-894e-c3d2be6ec58b" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# If using colab then run this cell.\n", | ||
"!git clone https://github.com/lollcat/fab-torch\n", | ||
"\n", | ||
"import os\n", | ||
"os.chdir(\"fab-torch\")\n", | ||
"\n", | ||
"!pip install --upgrade ." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"# Download weights from huggingface and run example of inference\n", | ||
"We can just use CPU as the model is not that expensive." | ||
], | ||
"metadata": { | ||
"id": "xy2GWTB7QlxO" | ||
}, | ||
"id": "xy2GWTB7QlxO" | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "d861b2f8-00be-4e14-998c-348ec89d1c89", | ||
"metadata": { | ||
"id": "d861b2f8-00be-4e14-998c-348ec89d1c89" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# Restart after install, then run the below code\n", | ||
"import os\n", | ||
"import urllib\n", | ||
"\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"from matplotlib import rc\n", | ||
"import matplotlib as mpl\n", | ||
"from hydra import compose, initialize\n", | ||
"import torch\n", | ||
"\n", | ||
"from fab.utils.plotting import plot_contours, plot_marginal_pair\n", | ||
"from fab.target_distributions.many_well import ManyWellEnergy\n", | ||
"from experiments.setup_run import setup_model\n", | ||
"from experiments.many_well.many_well_visualise_all_marginal_pairs import get_target_log_prob_marginal_pair" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "66a8cf07-5d35-4368-a320-fc63e9842d7d", | ||
"metadata": { | ||
"id": "66a8cf07-5d35-4368-a320-fc63e9842d7d" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"with initialize(version_base=None, config_path=\"fab-torch/experiments/config/\", job_name=\"colab_app\"):\n", | ||
" cfg = compose(config_name=f\"many_well\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "8cc811fd-2208-4111-b428-dfcad4b0bf7c", | ||
"metadata": { | ||
"id": "8cc811fd-2208-4111-b428-dfcad4b0bf7c" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"target = ManyWellEnergy(cfg.target.dim, a=-0.5, b=-6, use_gpu=False)\n", | ||
"model = setup_model(cfg, target)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "c19b7e70-c0ba-49d0-8aa0-ed77001c8e95", | ||
"metadata": { | ||
"id": "c19b7e70-c0ba-49d0-8aa0-ed77001c8e95" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# Download weights from huggingface, and load them into the model\n", | ||
"urllib.request.urlretrieve('https://huggingface.co/VincentStimper/fab/resolve/main/many_well/model.pt', 'model.pt')\n", | ||
"model.load(\"model.pt\", map_location=\"cpu\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "f4077f8a-d5bc-45f5-b401-7b807c4e68ed", | ||
"metadata": { | ||
"id": "f4077f8a-d5bc-45f5-b401-7b807c4e68ed" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# Sample from the model\n", | ||
"n_samples: int = 200\n", | ||
"samples_flow = model.flow.sample((n_samples,)).detach()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "f338bd95-3ea2-4df7-a700-cb8529a26914", | ||
"metadata": { | ||
"id": "f338bd95-3ea2-4df7-a700-cb8529a26914" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# Visualise samples\n", | ||
"alpha = 0.3\n", | ||
"plotting_bounds = (-3, 3)\n", | ||
"dim = cfg.target.dim\n", | ||
"fig, axs = plt.subplots(2, 2, sharex=\"row\", sharey=\"row\")\n", | ||
"\n", | ||
"for i in range(2):\n", | ||
" for j in range(2):\n", | ||
" target_log_prob = get_target_log_prob_marginal_pair(target.log_prob, i, j + 2, dim)\n", | ||
" plot_contours(target_log_prob, bounds=plotting_bounds, ax=axs[i, j],\n", | ||
" n_contour_levels=20, grid_width_n_points=100)\n", | ||
" plot_marginal_pair(samples_flow, marginal_dims=(i, j+2),\n", | ||
" ax=axs[i, j], bounds=plotting_bounds, alpha=alpha)\n", | ||
"\n", | ||
"\n", | ||
" if j == 0:\n", | ||
" axs[i, j].set_ylabel(f\"$x_{i + 1}$\")\n", | ||
" if i == 1:\n", | ||
" axs[i, j].set_xlabel(f\"$x_{j + 1 + 2}$\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [], | ||
"metadata": { | ||
"id": "MBK1xcr8UDui" | ||
}, | ||
"id": "MBK1xcr8UDui", | ||
"execution_count": null, | ||
"outputs": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.18" | ||
}, | ||
"colab": { | ||
"provenance": [], | ||
"include_colab_link": true | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters