Skip to content

Commit d7e592f

Browse files
authored
feat: many well demo (#85)
* UPDATE README.md * Created using Colaboratory * typo * add backwards compatability to checkpoints * add notebook * Created using Colaboratory * point to main branch * Add Colab link to README * Add Colab link to README
1 parent 710261a commit d7e592f

File tree

3 files changed

+208
-4
lines changed

3 files changed

+208
-4
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ We provide a [colab notebook](experiments/many_well/fab_many_well.ipynb) compari
6767
6 dimensional Many Well problem, where the difference between the two methods is apparent after a
6868
short (<5 min) training period. This experiment can be run locally on a laptop using just CPU.
6969

70+
Additionally, we provide the colab notebook
71+
<a href="https://colab.research.google.com/github/lollcat/fab-torch/blob/master/demo/many_well.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
72+
which demos inference with the flow trained with FAB (+prioritised buffer) on the 32 dim Many Well problem.
73+
7074
To run the experiment for the FAB with a prioritised replay buffer (for the first seed) on the
7175
32 dimensional Many Well problem, use the following command:
7276
```

demo/many_well.ipynb

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"source": [
6+
"<a href=\"https://colab.research.google.com/github/lollcat/fab-torch/blob/master/demo/many_well.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
7+
],
8+
"metadata": {
9+
"collapsed": false
10+
},
11+
"id": "a3e5dcf4be772f9b"
12+
},
13+
{
14+
"cell_type": "markdown",
15+
"source": [
16+
"# Install fab-torch repo"
17+
],
18+
"metadata": {
19+
"id": "-2z7-wbmQgVS"
20+
},
21+
"id": "-2z7-wbmQgVS"
22+
},
23+
{
24+
"cell_type": "code",
25+
"execution_count": null,
26+
"id": "94b5026c-6d96-464c-894e-c3d2be6ec58b",
27+
"metadata": {
28+
"id": "94b5026c-6d96-464c-894e-c3d2be6ec58b"
29+
},
30+
"outputs": [],
31+
"source": [
32+
"# If using colab then run this cell.\n",
33+
"!git clone https://github.com/lollcat/fab-torch\n",
34+
"\n",
35+
"import os\n",
36+
"os.chdir(\"fab-torch\")\n",
37+
"\n",
38+
"!pip install --upgrade ."
39+
]
40+
},
41+
{
42+
"cell_type": "markdown",
43+
"source": [
44+
"# Download weights from huggingface and run example of inference\n",
45+
"We can just use CPU as the model is not that expensive."
46+
],
47+
"metadata": {
48+
"id": "xy2GWTB7QlxO"
49+
},
50+
"id": "xy2GWTB7QlxO"
51+
},
52+
{
53+
"cell_type": "code",
54+
"execution_count": null,
55+
"id": "d861b2f8-00be-4e14-998c-348ec89d1c89",
56+
"metadata": {
57+
"id": "d861b2f8-00be-4e14-998c-348ec89d1c89"
58+
},
59+
"outputs": [],
60+
"source": [
61+
"# Restart after install, then run the below code\n",
62+
"import os\n",
63+
"import urllib\n",
64+
"\n",
65+
"import matplotlib.pyplot as plt\n",
66+
"from matplotlib import rc\n",
67+
"import matplotlib as mpl\n",
68+
"from hydra import compose, initialize\n",
69+
"import torch\n",
70+
"\n",
71+
"from fab.utils.plotting import plot_contours, plot_marginal_pair\n",
72+
"from fab.target_distributions.many_well import ManyWellEnergy\n",
73+
"from experiments.setup_run import setup_model\n",
74+
"from experiments.many_well.many_well_visualise_all_marginal_pairs import get_target_log_prob_marginal_pair"
75+
]
76+
},
77+
{
78+
"cell_type": "code",
79+
"execution_count": null,
80+
"id": "66a8cf07-5d35-4368-a320-fc63e9842d7d",
81+
"metadata": {
82+
"id": "66a8cf07-5d35-4368-a320-fc63e9842d7d"
83+
},
84+
"outputs": [],
85+
"source": [
86+
"with initialize(version_base=None, config_path=\"fab-torch/experiments/config/\", job_name=\"colab_app\"):\n",
87+
" cfg = compose(config_name=f\"many_well\")"
88+
]
89+
},
90+
{
91+
"cell_type": "code",
92+
"execution_count": null,
93+
"id": "8cc811fd-2208-4111-b428-dfcad4b0bf7c",
94+
"metadata": {
95+
"id": "8cc811fd-2208-4111-b428-dfcad4b0bf7c"
96+
},
97+
"outputs": [],
98+
"source": [
99+
"target = ManyWellEnergy(cfg.target.dim, a=-0.5, b=-6, use_gpu=False)\n",
100+
"model = setup_model(cfg, target)"
101+
]
102+
},
103+
{
104+
"cell_type": "code",
105+
"execution_count": null,
106+
"id": "c19b7e70-c0ba-49d0-8aa0-ed77001c8e95",
107+
"metadata": {
108+
"id": "c19b7e70-c0ba-49d0-8aa0-ed77001c8e95"
109+
},
110+
"outputs": [],
111+
"source": [
112+
"# Download weights from huggingface, and load them into the model\n",
113+
"urllib.request.urlretrieve('https://huggingface.co/VincentStimper/fab/resolve/main/many_well/model.pt', 'model.pt')\n",
114+
"model.load(\"model.pt\", map_location=\"cpu\")"
115+
]
116+
},
117+
{
118+
"cell_type": "code",
119+
"execution_count": null,
120+
"id": "f4077f8a-d5bc-45f5-b401-7b807c4e68ed",
121+
"metadata": {
122+
"id": "f4077f8a-d5bc-45f5-b401-7b807c4e68ed"
123+
},
124+
"outputs": [],
125+
"source": [
126+
"# Sample from the model\n",
127+
"n_samples: int = 200\n",
128+
"samples_flow = model.flow.sample((n_samples,)).detach()"
129+
]
130+
},
131+
{
132+
"cell_type": "code",
133+
"execution_count": null,
134+
"id": "f338bd95-3ea2-4df7-a700-cb8529a26914",
135+
"metadata": {
136+
"id": "f338bd95-3ea2-4df7-a700-cb8529a26914"
137+
},
138+
"outputs": [],
139+
"source": [
140+
"# Visualise samples\n",
141+
"alpha = 0.3\n",
142+
"plotting_bounds = (-3, 3)\n",
143+
"dim = cfg.target.dim\n",
144+
"fig, axs = plt.subplots(2, 2, sharex=\"row\", sharey=\"row\")\n",
145+
"\n",
146+
"for i in range(2):\n",
147+
" for j in range(2):\n",
148+
" target_log_prob = get_target_log_prob_marginal_pair(target.log_prob, i, j + 2, dim)\n",
149+
" plot_contours(target_log_prob, bounds=plotting_bounds, ax=axs[i, j],\n",
150+
" n_contour_levels=20, grid_width_n_points=100)\n",
151+
" plot_marginal_pair(samples_flow, marginal_dims=(i, j+2),\n",
152+
" ax=axs[i, j], bounds=plotting_bounds, alpha=alpha)\n",
153+
"\n",
154+
"\n",
155+
" if j == 0:\n",
156+
" axs[i, j].set_ylabel(f\"$x_{i + 1}$\")\n",
157+
" if i == 1:\n",
158+
" axs[i, j].set_xlabel(f\"$x_{j + 1 + 2}$\")"
159+
]
160+
},
161+
{
162+
"cell_type": "code",
163+
"source": [],
164+
"metadata": {
165+
"id": "MBK1xcr8UDui"
166+
},
167+
"id": "MBK1xcr8UDui",
168+
"execution_count": null,
169+
"outputs": []
170+
}
171+
],
172+
"metadata": {
173+
"kernelspec": {
174+
"display_name": "Python 3 (ipykernel)",
175+
"language": "python",
176+
"name": "python3"
177+
},
178+
"language_info": {
179+
"codemirror_mode": {
180+
"name": "ipython",
181+
"version": 3
182+
},
183+
"file_extension": ".py",
184+
"mimetype": "text/x-python",
185+
"name": "python",
186+
"nbconvert_exporter": "python",
187+
"pygments_lexer": "ipython3",
188+
"version": "3.9.18"
189+
},
190+
"colab": {
191+
"provenance": [],
192+
"include_colab_link": true
193+
}
194+
},
195+
"nbformat": 4,
196+
"nbformat_minor": 5
197+
}

fab/core.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -235,10 +235,13 @@ def load(self,
235235
checkpoint = torch.load(path, map_location=map_location)
236236
try:
237237
self.flow.load_state_dict(checkpoint['flow'])
238-
except RuntimeError:
239-
# If flow is incorretly loaded then this will mess up evaluation, so raise Error.
240-
raise RuntimeError('Flow could not be loaded. '
241-
'Perhaps there is a mismatch in the architectures.')
238+
except:
239+
try:
240+
self.flow._nf_model.load_state_dict(checkpoint['flow'])
241+
except RuntimeError:
242+
# If flow is incorretly loaded then this will mess up evaluation, so raise Error.
243+
raise RuntimeError('Flow could not be loaded. '
244+
'Perhaps there is a mismatch in the architectures.')
242245
try:
243246
self.transition_operator.load_state_dict(checkpoint['trans_op'])
244247
except RuntimeError:

0 commit comments

Comments
 (0)