Skip to content

Commit

Permalink
Merge pull request #64 from thomaswmorris/discrete
Browse files Browse the repository at this point in the history
Add support for discrete DOFs
  • Loading branch information
mrakitin authored Apr 17, 2024
2 parents 372292a + fd7c643 commit 5e0fc07
Show file tree
Hide file tree
Showing 16 changed files with 422 additions and 123 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
blop/_version.py
# setuptools_scm
src/*/_version.py

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
6 changes: 3 additions & 3 deletions docs/source/agent.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ The blop ``Agent`` takes care of the entire optimization loop, from data acquisi
from blop import DOF, Objective, Agent
dofs = [
DOF(name="x1", description="the first DOF", search_bounds=(-10, 10))
DOF(name="x2", description="another DOF", search_bounds=(-5, 5))
DOF(name="x3", description="ayet nother DOF", search_bounds=(0, 1))
DOF(name="x1", description="the first DOF", search_domain=(-10, 10))
DOF(name="x2", description="another DOF", search_domain=(-5, 5))
DOF(name="x3", description="yet another DOF", search_domain=(0, 1))
]
objective = [
Expand Down
31 changes: 27 additions & 4 deletions docs/source/dofs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ A degree of freedom is a variable that affects our optimization objective. We ca
from blop import DOF
dof = DOF(name="x1", description="my first DOF", search_bounds=(lower, upper))
dof = DOF(name="x1", description="my first DOF", search_domain=(lower, upper))
This will instantiate a bunch of stuff under the hood, so that our agent knows how to move things and where to search.
Typically, this will correspond to a real, physical device available in Python. In that case, we can pass the DOF an ophyd device in place of a name
Expand All @@ -16,7 +16,7 @@ Typically, this will correspond to a real, physical device available in Python.
from blop import DOF
dof = DOF(device=my_ophyd_device, description="a real piece of hardware", search_bounds=(lower, upper))
dof = DOF(device=my_ophyd_device, description="a real piece of hardware", search_domain=(lower, upper))
In this case, the agent will control the device as it sees fit, moving it between the search bounds.

Expand All @@ -27,7 +27,30 @@ In this case, we can define a read-only DOF as
from blop import DOF
dof = DOF(device=a_read_only_ophyd_device, description="a thermometer or something", read_only=True, trust_bounds=(lower, upper))
dof = DOF(device=a_read_only_ophyd_device, description="a thermometer or something", read_only=True, trust_domain=(lower, upper))
and the agent will use the received values to model its objective, but won't try to move it.
We can also pass a set of ``trust_bounds``, so that our agent will ignore experiments where the DOF value jumps outside of the interval.
We can also pass a set of ``trust_domain``, so that our agent will ignore experiments where the DOF value jumps outside of the interval.


Discrete degrees of freedom
---------------------------

In addition to degrees of freedom that vary continuously between a lower and upper bound, we can define discrete degrees of freedom.
One kind is a binary degree of freedom, where the input can take one of two values, e.g.

.. code-block:: python
discrete_dof = DOF(name="x1", description="A discrete DOF", type="discrete", search_domain={"in", "out"})
Another is an ordinal degree of freedom, which takes more than two discrete values but has some ordering, e.g.

.. code-block:: python
ordinal_dof = DOF(name="x1", description="An ordinal DOF", type="ordinal", search_domain={"low", "medium", "high"})
The last is a categorical degree of freedom, which can take many different discrete values with no ordering, e.g.

.. code-block:: python
categorical_dof = DOF(name="x1", description="A categorical DOF", type="categorical", search_domain={"banana", "mango", "papaya"})
4 changes: 2 additions & 2 deletions docs/source/tutorials/himmelblau.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@
"from blop import DOF\n",
"\n",
"dofs = [\n",
" DOF(name=\"x1\", search_bounds=(-6, 6)),\n",
" DOF(name=\"x2\", search_bounds=(-6, 6)),\n",
" DOF(name=\"x1\", search_domain=(-6, 6)),\n",
" DOF(name=\"x2\", search_domain=(-6, 6)),\n",
"]"
]
},
Expand Down
4 changes: 2 additions & 2 deletions docs/source/tutorials/hyperparameters.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@
"from blop import DOF, Objective, Agent\n",
"\n",
"dofs = [\n",
" DOF(name=\"x1\", search_bounds=(-6, 6)),\n",
" DOF(name=\"x2\", search_bounds=(-6, 6)),\n",
" DOF(name=\"x1\", search_domain=(-6, 6)),\n",
" DOF(name=\"x2\", search_domain=(-6, 6)),\n",
"]\n",
"\n",
"objectives = [\n",
Expand Down
6 changes: 3 additions & 3 deletions docs/source/tutorials/passive-dofs.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@
"\n",
"\n",
"dofs = [\n",
" DOF(name=\"x1\", search_bounds=(-5.0, 5.0)),\n",
" DOF(name=\"x2\", search_bounds=(-5.0, 5.0)),\n",
" DOF(name=\"x3\", search_bounds=(-5.0, 5.0), active=False),\n",
" DOF(name=\"x1\", search_domain=(-5.0, 5.0)),\n",
" DOF(name=\"x2\", search_domain=(-5.0, 5.0)),\n",
" DOF(name=\"x3\", search_domain=(-5.0, 5.0), active=False),\n",
" DOF(device=BrownianMotion(name=\"brownian1\"), read_only=True),\n",
" DOF(device=BrownianMotion(name=\"brownian2\"), read_only=True, active=False),\n",
"]\n",
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ napari = [
"napari"
]

gui = [
"nicegui"
]

dev = [
"black",
"pytest-codecov",
Expand Down
194 changes: 194 additions & 0 deletions scripts/gui.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
import asyncio

import databroker
import matplotlib as mpl
import numpy as np
from bluesky.callbacks import best_effort
from bluesky.run_engine import RunEngine
from databroker import Broker
from nicegui import ui

from blop import DOF, Agent, Objective
from blop.utils import functions

# MongoDB backend:
db = Broker.named("temp") # mongodb backend
try:
databroker.assets.utils.install_sentinels(db.reg.config, version=1)
except Exception:
pass

loop = asyncio.new_event_loop()
loop.set_debug(True)
RE = RunEngine({}, loop=loop)
RE.subscribe(db.insert)

bec = best_effort.BestEffortCallback()
RE.subscribe(bec)

bec.disable_baseline()
bec.disable_heading()
bec.disable_table()
bec.disable_plots()


dofs = [
DOF(name="x1", description="x1", search_domain=(-5.0, 5.0)),
DOF(name="x2", description="x2", search_domain=(-5.0, 5.0)),
]

objectives = [Objective(name="himmelblau", target="min")]

agent = Agent(
dofs=dofs,
objectives=objectives,
digestion=functions.himmelblau_digestion,
db=db,
verbose=True,
tolerate_acquisition_errors=False,
)

agent.acqf_index = 0

agent.acqf_number = 2


with ui.pyplot(figsize=(10, 4), dpi=160) as obj_plt:
extent = [*agent.dofs[0].search_domain, *agent.dofs[1].search_domain]

ax1 = obj_plt.fig.add_subplot(131)
ax1.set_title("Samples")
im1 = ax1.scatter([], [], cmap="magma")

ax2 = obj_plt.fig.add_subplot(132, sharex=ax1, sharey=ax1)
ax2.set_title("Posterior mean")
im2 = ax2.imshow(np.random.standard_normal(size=(32, 32)), extent=extent, cmap="magma")

ax3 = obj_plt.fig.add_subplot(133, sharex=ax1, sharey=ax1)
ax3.set_title("Posterior error")
im3 = ax3.imshow(np.random.standard_normal(size=(32, 32)), extent=extent, cmap="magma")

data_cbar = obj_plt.fig.colorbar(mappable=im1, ax=[ax1, ax2], location="bottom", aspect=32)
err_cbar = obj_plt.fig.colorbar(mappable=im3, ax=[ax3], location="bottom", aspect=16)

for ax in [ax1, ax2, ax3]:
ax.set_xlabel(agent.dofs[0].label)
ax.set_ylabel(agent.dofs[1].label)


acqf_configs = {
0: {"name": "qr", "long_name": r"quasi-random sampling"},
1: {"name": "qei", "long_name": r"$q$-expected improvement"},
2: {"name": "qpi", "long_name": r"$q$-probability of improvement"},
3: {"name": "qucb", "long_name": r"$q$-upper confidence bound"},
}

with ui.pyplot(figsize=(10, 3), dpi=160) as acq_plt:
extent = [*agent.dofs[0].search_domain, *agent.dofs[1].search_domain]

acqf_plt_objs = {}

for iax, config in acqf_configs.items():
if iax == 0:
continue

acqf = config["name"]

acqf_plt_objs[acqf] = {}

acqf_plt_objs[acqf]["ax"] = ax = acq_plt.fig.add_subplot(1, len(acqf_configs) - 1, iax)

ax.set_title(config["long_name"])
acqf_plt_objs[acqf]["im"] = ax.imshow([[]], extent=extent, cmap="gray_r")
acqf_plt_objs[acqf]["hist"] = ax.scatter([], [])
acqf_plt_objs[acqf]["best"] = ax.scatter([], [])

ax.set_xlabel(agent.dofs[0].label)
ax.set_ylabel(agent.dofs[1].label)


acqf_button_options = {index: config["name"] for index, config in acqf_configs.items()}

v = ui.checkbox("visible", value=True)
with ui.column().bind_visibility_from(v, "value"):
ui.toggle(acqf_button_options).bind_value(agent, "acqf_index")
ui.number().bind_value(agent, "acqf_number")


def reset():
agent.reset()

print(agent.table)


def learn():
acqf_config = acqf_configs[agent.acqf_index]

acqf = acqf_config["name"]

n = int(agent.acqf_number) if acqf != "qr" else 16

ui.notify(f"sampling {n} points with acquisition function \"{acqf_config['long_name']}\"")

RE(agent.learn(acqf, n=n))

with obj_plt:
obj = agent.objectives[0]

x_samples = agent.train_inputs().detach().numpy()
y_samples = agent.train_targets(obj.name).detach().numpy()[..., 0]

x = agent.sample(method="grid", n=20000) # (n, n, 1, d)
p = obj.model.posterior(x)

m = p.mean.squeeze(-1, -2).detach().numpy()
e = p.variance.sqrt().squeeze(-1, -2).detach().numpy()

im1.set_offsets(x_samples)
im1.set_array(y_samples)
im1.set_cmap("magma")

im2.set_data(m.T[::-1])
im3.set_data(e.T[::-1])

obj_norm = mpl.colors.Normalize(vmin=np.nanmin(y_samples), vmax=np.nanmax(y_samples))
err_norm = mpl.colors.LogNorm(vmin=np.nanmin(e), vmax=np.nanmax(e))

im1.set_norm(obj_norm)
im2.set_norm(obj_norm)
im3.set_norm(err_norm)

for ax in [ax1, ax2, ax3]:
ax.set_xlim(*agent.dofs[0].search_domain)
ax.set_ylim(*agent.dofs[1].search_domain)

with acq_plt:
x = agent.sample(method="grid", n=20000) # (n, n, 1, d)
x_samples = agent.train_inputs().detach().numpy()

for acqf in acqf_plt_objs.keys():
ax = acqf_plt_objs[acqf]["ax"]

acqf_obj = getattr(agent, acqf)(x).detach().numpy()

acqf_norm = mpl.colors.Normalize(vmin=np.nanmin(acqf_obj), vmax=np.nanmax(acqf_obj))
acqf_plt_objs[acqf]["im"].set_data(acqf_obj.T[::-1])
acqf_plt_objs[acqf]["im"].set_norm(acqf_norm)

res = agent.ask(acqf, n=int(agent.acqf_number))

acqf_plt_objs[acqf]["hist"].remove()
acqf_plt_objs[acqf]["hist"] = ax.scatter(*x_samples.T, ec="b", fc="none", marker="o")

acqf_plt_objs[acqf]["best"].remove()
acqf_plt_objs[acqf]["best"] = ax.scatter(*res["points"].T, c="r", marker="x", s=64)

ax.set_xlim(*agent.dofs[0].search_domain)
ax.set_ylim(*agent.dofs[1].search_domain)


ui.button("Learn", on_click=learn)

ui.button("Reset", on_click=reset)

ui.run(port=8004)
17 changes: 0 additions & 17 deletions src/blop/_version.py

This file was deleted.

Loading

0 comments on commit 5e0fc07

Please sign in to comment.