diff --git a/.gitignore b/.gitignore index dbea649..2416f3a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ -blop/_version.py +# setuptools_scm +src/*/_version.py # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/docs/source/agent.rst b/docs/source/agent.rst index 27ec35d..1034c2f 100644 --- a/docs/source/agent.rst +++ b/docs/source/agent.rst @@ -8,9 +8,9 @@ The blop ``Agent`` takes care of the entire optimization loop, from data acquisi from blop import DOF, Objective, Agent dofs = [ - DOF(name="x1", description="the first DOF", search_bounds=(-10, 10)) - DOF(name="x2", description="another DOF", search_bounds=(-5, 5)) - DOF(name="x3", description="ayet nother DOF", search_bounds=(0, 1)) + DOF(name="x1", description="the first DOF", search_domain=(-10, 10)) + DOF(name="x2", description="another DOF", search_domain=(-5, 5)) + DOF(name="x3", description="yet another DOF", search_domain=(0, 1)) ] objective = [ diff --git a/docs/source/dofs.rst b/docs/source/dofs.rst index 42ec90a..a77981c 100644 --- a/docs/source/dofs.rst +++ b/docs/source/dofs.rst @@ -7,7 +7,7 @@ A degree of freedom is a variable that affects our optimization objective. We ca from blop import DOF - dof = DOF(name="x1", description="my first DOF", search_bounds=(lower, upper)) + dof = DOF(name="x1", description="my first DOF", search_domain=(lower, upper)) This will instantiate a bunch of stuff under the hood, so that our agent knows how to move things and where to search. Typically, this will correspond to a real, physical device available in Python. In that case, we can pass the DOF an ophyd device in place of a name @@ -16,7 +16,7 @@ Typically, this will correspond to a real, physical device available in Python. from blop import DOF - dof = DOF(device=my_ophyd_device, description="a real piece of hardware", search_bounds=(lower, upper)) + dof = DOF(device=my_ophyd_device, description="a real piece of hardware", search_domain=(lower, upper)) In this case, the agent will control the device as it sees fit, moving it between the search bounds. @@ -27,7 +27,30 @@ In this case, we can define a read-only DOF as from blop import DOF - dof = DOF(device=a_read_only_ophyd_device, description="a thermometer or something", read_only=True, trust_bounds=(lower, upper)) + dof = DOF(device=a_read_only_ophyd_device, description="a thermometer or something", read_only=True, trust_domain=(lower, upper)) and the agent will use the received values to model its objective, but won't try to move it. -We can also pass a set of ``trust_bounds``, so that our agent will ignore experiments where the DOF value jumps outside of the interval. +We can also pass a set of ``trust_domain``, so that our agent will ignore experiments where the DOF value jumps outside of the interval. + + +Discrete degrees of freedom +--------------------------- + +In addition to degrees of freedom that vary continuously between a lower and upper bound, we can define discrete degrees of freedom. +One kind is a binary degree of freedom, where the input can take one of two values, e.g. + +.. code-block:: python + + discrete_dof = DOF(name="x1", description="A discrete DOF", type="discrete", search_domain={"in", "out"}) + +Another is an ordinal degree of freedom, which takes more than two discrete values but has some ordering, e.g. + +.. 
code-block:: python + + ordinal_dof = DOF(name="x1", description="An ordinal DOF", type="ordinal", search_domain={"low", "medium", "high"}) + +The last is a categorical degree of freedom, which can take many different discrete values with no ordering, e.g. + +.. code-block:: python + + categorical_dof = DOF(name="x1", description="A categorical DOF", type="categorical", search_domain={"banana", "mango", "papaya"}) diff --git a/docs/source/tutorials/himmelblau.ipynb b/docs/source/tutorials/himmelblau.ipynb index 8b1a435..7925a58 100644 --- a/docs/source/tutorials/himmelblau.ipynb +++ b/docs/source/tutorials/himmelblau.ipynb @@ -72,8 +72,8 @@ "from blop import DOF\n", "\n", "dofs = [\n", - " DOF(name=\"x1\", search_bounds=(-6, 6)),\n", - " DOF(name=\"x2\", search_bounds=(-6, 6)),\n", + " DOF(name=\"x1\", search_domain=(-6, 6)),\n", + " DOF(name=\"x2\", search_domain=(-6, 6)),\n", "]" ] }, diff --git a/docs/source/tutorials/hyperparameters.ipynb b/docs/source/tutorials/hyperparameters.ipynb index f5f0ae2..9170ffb 100644 --- a/docs/source/tutorials/hyperparameters.ipynb +++ b/docs/source/tutorials/hyperparameters.ipynb @@ -75,8 +75,8 @@ "from blop import DOF, Objective, Agent\n", "\n", "dofs = [\n", - " DOF(name=\"x1\", search_bounds=(-6, 6)),\n", - " DOF(name=\"x2\", search_bounds=(-6, 6)),\n", + " DOF(name=\"x1\", search_domain=(-6, 6)),\n", + " DOF(name=\"x2\", search_domain=(-6, 6)),\n", "]\n", "\n", "objectives = [\n", diff --git a/docs/source/tutorials/passive-dofs.ipynb b/docs/source/tutorials/passive-dofs.ipynb index 42a0948..8e53358 100644 --- a/docs/source/tutorials/passive-dofs.ipynb +++ b/docs/source/tutorials/passive-dofs.ipynb @@ -44,9 +44,9 @@ "\n", "\n", "dofs = [\n", - " DOF(name=\"x1\", search_bounds=(-5.0, 5.0)),\n", - " DOF(name=\"x2\", search_bounds=(-5.0, 5.0)),\n", - " DOF(name=\"x3\", search_bounds=(-5.0, 5.0), active=False),\n", + " DOF(name=\"x1\", search_domain=(-5.0, 5.0)),\n", + " DOF(name=\"x2\", search_domain=(-5.0, 5.0)),\n", + " DOF(name=\"x3\", search_domain=(-5.0, 5.0), active=False),\n", " DOF(device=BrownianMotion(name=\"brownian1\"), read_only=True),\n", " DOF(device=BrownianMotion(name=\"brownian2\"), read_only=True, active=False),\n", "]\n", diff --git a/pyproject.toml b/pyproject.toml index 9b66912..4cd8a5a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,10 @@ napari = [ "napari" ] +gui = [ + "nicegui" +] + dev = [ "black", "pytest-codecov", diff --git a/scripts/gui.py b/scripts/gui.py new file mode 100644 index 0000000..9955180 --- /dev/null +++ b/scripts/gui.py @@ -0,0 +1,194 @@ +import asyncio + +import databroker +import matplotlib as mpl +import numpy as np +from bluesky.callbacks import best_effort +from bluesky.run_engine import RunEngine +from databroker import Broker +from nicegui import ui + +from blop import DOF, Agent, Objective +from blop.utils import functions + +# MongoDB backend: +db = Broker.named("temp") # mongodb backend +try: + databroker.assets.utils.install_sentinels(db.reg.config, version=1) +except Exception: + pass + +loop = asyncio.new_event_loop() +loop.set_debug(True) +RE = RunEngine({}, loop=loop) +RE.subscribe(db.insert) + +bec = best_effort.BestEffortCallback() +RE.subscribe(bec) + +bec.disable_baseline() +bec.disable_heading() +bec.disable_table() +bec.disable_plots() + + +dofs = [ + DOF(name="x1", description="x1", search_domain=(-5.0, 5.0)), + DOF(name="x2", description="x2", search_domain=(-5.0, 5.0)), +] + +objectives = [Objective(name="himmelblau", target="min")] + +agent = Agent( + dofs=dofs, + 
objectives=objectives, + digestion=functions.himmelblau_digestion, + db=db, + verbose=True, + tolerate_acquisition_errors=False, +) + +agent.acqf_index = 0 + +agent.acqf_number = 2 + + +with ui.pyplot(figsize=(10, 4), dpi=160) as obj_plt: + extent = [*agent.dofs[0].search_domain, *agent.dofs[1].search_domain] + + ax1 = obj_plt.fig.add_subplot(131) + ax1.set_title("Samples") + im1 = ax1.scatter([], [], cmap="magma") + + ax2 = obj_plt.fig.add_subplot(132, sharex=ax1, sharey=ax1) + ax2.set_title("Posterior mean") + im2 = ax2.imshow(np.random.standard_normal(size=(32, 32)), extent=extent, cmap="magma") + + ax3 = obj_plt.fig.add_subplot(133, sharex=ax1, sharey=ax1) + ax3.set_title("Posterior error") + im3 = ax3.imshow(np.random.standard_normal(size=(32, 32)), extent=extent, cmap="magma") + + data_cbar = obj_plt.fig.colorbar(mappable=im1, ax=[ax1, ax2], location="bottom", aspect=32) + err_cbar = obj_plt.fig.colorbar(mappable=im3, ax=[ax3], location="bottom", aspect=16) + + for ax in [ax1, ax2, ax3]: + ax.set_xlabel(agent.dofs[0].label) + ax.set_ylabel(agent.dofs[1].label) + + +acqf_configs = { + 0: {"name": "qr", "long_name": r"quasi-random sampling"}, + 1: {"name": "qei", "long_name": r"$q$-expected improvement"}, + 2: {"name": "qpi", "long_name": r"$q$-probability of improvement"}, + 3: {"name": "qucb", "long_name": r"$q$-upper confidence bound"}, +} + +with ui.pyplot(figsize=(10, 3), dpi=160) as acq_plt: + extent = [*agent.dofs[0].search_domain, *agent.dofs[1].search_domain] + + acqf_plt_objs = {} + + for iax, config in acqf_configs.items(): + if iax == 0: + continue + + acqf = config["name"] + + acqf_plt_objs[acqf] = {} + + acqf_plt_objs[acqf]["ax"] = ax = acq_plt.fig.add_subplot(1, len(acqf_configs) - 1, iax) + + ax.set_title(config["long_name"]) + acqf_plt_objs[acqf]["im"] = ax.imshow([[]], extent=extent, cmap="gray_r") + acqf_plt_objs[acqf]["hist"] = ax.scatter([], []) + acqf_plt_objs[acqf]["best"] = ax.scatter([], []) + + ax.set_xlabel(agent.dofs[0].label) + ax.set_ylabel(agent.dofs[1].label) + + +acqf_button_options = {index: config["name"] for index, config in acqf_configs.items()} + +v = ui.checkbox("visible", value=True) +with ui.column().bind_visibility_from(v, "value"): + ui.toggle(acqf_button_options).bind_value(agent, "acqf_index") + ui.number().bind_value(agent, "acqf_number") + + +def reset(): + agent.reset() + + print(agent.table) + + +def learn(): + acqf_config = acqf_configs[agent.acqf_index] + + acqf = acqf_config["name"] + + n = int(agent.acqf_number) if acqf != "qr" else 16 + + ui.notify(f"sampling {n} points with acquisition function \"{acqf_config['long_name']}\"") + + RE(agent.learn(acqf, n=n)) + + with obj_plt: + obj = agent.objectives[0] + + x_samples = agent.train_inputs().detach().numpy() + y_samples = agent.train_targets(obj.name).detach().numpy()[..., 0] + + x = agent.sample(method="grid", n=20000) # (n, n, 1, d) + p = obj.model.posterior(x) + + m = p.mean.squeeze(-1, -2).detach().numpy() + e = p.variance.sqrt().squeeze(-1, -2).detach().numpy() + + im1.set_offsets(x_samples) + im1.set_array(y_samples) + im1.set_cmap("magma") + + im2.set_data(m.T[::-1]) + im3.set_data(e.T[::-1]) + + obj_norm = mpl.colors.Normalize(vmin=np.nanmin(y_samples), vmax=np.nanmax(y_samples)) + err_norm = mpl.colors.LogNorm(vmin=np.nanmin(e), vmax=np.nanmax(e)) + + im1.set_norm(obj_norm) + im2.set_norm(obj_norm) + im3.set_norm(err_norm) + + for ax in [ax1, ax2, ax3]: + ax.set_xlim(*agent.dofs[0].search_domain) + ax.set_ylim(*agent.dofs[1].search_domain) + + with acq_plt: + x = 
agent.sample(method="grid", n=20000) # (n, n, 1, d) + x_samples = agent.train_inputs().detach().numpy() + + for acqf in acqf_plt_objs.keys(): + ax = acqf_plt_objs[acqf]["ax"] + + acqf_obj = getattr(agent, acqf)(x).detach().numpy() + + acqf_norm = mpl.colors.Normalize(vmin=np.nanmin(acqf_obj), vmax=np.nanmax(acqf_obj)) + acqf_plt_objs[acqf]["im"].set_data(acqf_obj.T[::-1]) + acqf_plt_objs[acqf]["im"].set_norm(acqf_norm) + + res = agent.ask(acqf, n=int(agent.acqf_number)) + + acqf_plt_objs[acqf]["hist"].remove() + acqf_plt_objs[acqf]["hist"] = ax.scatter(*x_samples.T, ec="b", fc="none", marker="o") + + acqf_plt_objs[acqf]["best"].remove() + acqf_plt_objs[acqf]["best"] = ax.scatter(*res["points"].T, c="r", marker="x", s=64) + + ax.set_xlim(*agent.dofs[0].search_domain) + ax.set_ylim(*agent.dofs[1].search_domain) + + +ui.button("Learn", on_click=learn) + +ui.button("Reset", on_click=reset) + +ui.run(port=8004) diff --git a/src/blop/_version.py b/src/blop/_version.py deleted file mode 100644 index 934a424..0000000 --- a/src/blop/_version.py +++ /dev/null @@ -1,17 +0,0 @@ -# file generated by setuptools_scm -# don't change, don't track in version control -TYPE_CHECKING = False -if TYPE_CHECKING: - from typing import Tuple, Union - - VERSION_TUPLE = Tuple[Union[int, str], ...] -else: - VERSION_TUPLE = object - -version: str -__version__: str -__version_tuple__: VERSION_TUPLE -version_tuple: VERSION_TUPLE - -__version__ = version = "0.6.2.dev0" -__version_tuple__ = version_tuple = (0, 6, 2, "dev0") diff --git a/src/blop/agent.py b/src/blop/agent.py index 079d07d..3b9d4c8 100644 --- a/src/blop/agent.py +++ b/src/blop/agent.py @@ -1,4 +1,6 @@ import logging +import os +import pathlib import time as ttime import warnings from collections import OrderedDict @@ -139,6 +141,12 @@ def __init__( self.n_last_trained = 0 + def unpack_run(self): + return + + def measurement_plan(self): + return + @property def active_dofs(self): return self.dofs.subset(active=True) @@ -239,7 +247,7 @@ def ask(self, acq_func_identifier="qei", n=1, route=True, sequential=True, upsam candidates, acqf_obj = botorch.optim.optimize_acqf( acq_function=acq_func, - bounds=self._sample_bounds, + bounds=self._sample_domain, q=n, sequential=sequential, num_restarts=NUM_RESTARTS, @@ -383,7 +391,7 @@ def learn( """ if self.sample_center_on_init and not self.initialized: - center_inputs = np.atleast_2d(self.dofs.subset(active=True, read_only=False).search_bounds.mean(axis=1)) + center_inputs = np.atleast_2d(self.dofs.subset(active=True, read_only=False).search_domain.mean(axis=1)) new_table = yield from self.acquire(center_inputs) new_table.loc[:, "acq_func"] = "sample_center_on_init" @@ -497,7 +505,10 @@ def reset(self): self.n_last_trained = 0 def benchmark( - self, output_dir="./", runs=16, n_init=64, learning_kwargs_list=[{"acq_func": "qei", "n": 4, "iterations": 16}] + self, + output_dir="./", + iterations=16, + per_iter_learn_kwargs_list=[{"acq_func": "qr", "n": 32}, {"acq_func": "qei", "n": 4, "iterations": 4}], ): """Iterate over having the agent learn from scratch, and save the results to an output directory. @@ -505,19 +516,19 @@ def benchmark( ---------- output_dir : Where to save the agent output. - runs : int + iterations : int How many benchmarks to run - learning_kwargs_list: - A list of kwargs to pass to the learn method which the agent will run sequentially for each run. + per_iter_learn_kwargs_list: + A list of kwargs to pass to the agent.learn() method that the agent will run sequentially for each iteration. 
""" - for run in range(runs): + for _ in range(iterations): self.reset() - for kwargs in learning_kwargs_list: + for kwargs in per_iter_learn_kwargs_list: yield from self.learn(**kwargs) - self.save_data(output_dir + f"benchmark-{int(ttime.time())}.h5") + self.save_data(f"{output_dir}/blop_benchmark_{int(ttime.time())}.h5") @property def model(self): @@ -668,17 +679,17 @@ def _latent_dim_tuples(self, obj_index=None): return [tuple(np.where(uinv == i)[0]) for i in range(len(u))] @property - def _sample_bounds(self): - return torch.tensor(self.active_dofs.search_bounds, dtype=torch.double).T + def _sample_domain(self): + return torch.tensor(self.active_dofs.search_domain, dtype=torch.double).T @property def _sample_input_transform(self): tf1 = Log10(indices=list(np.where(self.active_dofs.log)[0])) - transformed_sample_bounds = tf1.transform(self._sample_bounds) + transformed_sample_domain = tf1.transform(self._sample_domain) - offset = transformed_sample_bounds.min(dim=0).values - coefficient = (transformed_sample_bounds.max(dim=0).values - offset).clamp(min=1e-16) + offset = transformed_sample_domain.min(dim=0).values + coefficient = (transformed_sample_domain.max(dim=0).values - offset).clamp(min=1e-16) tf2 = AffineInputTransform(d=len(offset), coefficient=coefficient, offset=offset) @@ -691,7 +702,7 @@ def _model_input_transform(self): For modeling: - Always normalize between min and max values. This is always inside the trust bounds, sometimes smaller. + Always normalize between min and max values. This is always inside the trust domain, sometimes smaller. For sampling: @@ -704,13 +715,15 @@ def _model_input_transform(self): return ChainedInputTransform(tf1=tf1, tf2=tf2) - def save_data(self, filepath="./self_data.h5"): + def save_data(self, path="./data.h5"): """ Save the sampled inputs and targets of the agent to a file, which can be used - to initialize a future self. + to initialize a future agent. 
""" - self.table.to_hdf(filepath, key="table") + save_dir, _ = os.path.split(path) + pathlib.Path(save_dir).mkdir(parents=True, exist_ok=True) + self.table.to_hdf(path, key="table") def forget(self, last=None, index=None, train=True): """ @@ -823,7 +836,7 @@ def train_inputs(self, index=None, **subset_kwargs): inputs = self.table.loc[:, dof.name].values.copy() # check that inputs values are inside acceptable values - valid = (inputs >= dof._trust_bounds[0]) & (inputs <= dof._trust_bounds[1]) + valid = (inputs >= dof._trust_domain[0]) & (inputs <= dof._trust_domain[1]) inputs = np.where(valid, inputs, np.nan) return torch.tensor(inputs, dtype=torch.double).unsqueeze(-1) @@ -838,7 +851,7 @@ def train_targets(self, index=None, **subset_kwargs): targets = self.table.loc[:, obj.name].values.copy() # check that targets values are inside acceptable values - valid = (targets >= obj._trust_bounds[0]) & (targets <= obj._trust_bounds[1]) + valid = (targets >= obj._trust_domain[0]) & (targets <= obj._trust_domain[1]) targets = np.where(valid, targets, np.nan) # transform if needed diff --git a/src/blop/bayesian/plotting.py b/src/blop/bayesian/plotting.py index a1c8917..568ad6e 100644 --- a/src/blop/bayesian/plotting.py +++ b/src/blop/bayesian/plotting.py @@ -50,7 +50,7 @@ def _plot_objs_one_dof(agent, size=16, lw=1e0): alpha=0.5**z, ) - agent.obj_axes[obj_index].set_xlim(*x_dof.search_bounds) + agent.obj_axes[obj_index].set_xlim(*x_dof.search_domain) agent.obj_axes[obj_index].set_xlabel(x_dof.label) agent.obj_axes[obj_index].set_ylabel(obj.label) @@ -179,8 +179,8 @@ def _plot_objs_many_dofs(agent, axes=(0, 1), shading="nearest", cmap=DEFAULT_COL for ax in agent.obj_axes.ravel(): ax.set_xlabel(x_dof.label) ax.set_ylabel(y_dof.label) - ax.set_xlim(*x_dof.search_bounds) - ax.set_ylim(*y_dof.search_bounds) + ax.set_xlim(*x_dof.search_domain) + ax.set_ylim(*y_dof.search_domain) if x_dof.log: ax.set_xscale("log") if y_dof.log: @@ -209,7 +209,7 @@ def _plot_acqf_one_dof(agent, acq_funcs, lw=1e0, **kwargs): agent.acq_axes[iacq_func].plot(test_inputs.squeeze(-2), test_acqf, lw=lw, color=color) - agent.acq_axes[iacq_func].set_xlim(*x_dof.search_bounds) + agent.acq_axes[iacq_func].set_xlim(*x_dof.search_domain) agent.acq_axes[iacq_func].set_xlabel(x_dof.label) agent.acq_axes[iacq_func].set_ylabel(acq_func_meta["name"]) @@ -271,8 +271,8 @@ def _plot_acqf_many_dofs( for ax in agent.acq_axes.ravel(): ax.set_xlabel(x_dof.label) ax.set_ylabel(y_dof.label) - ax.set_xlim(*x_dof.search_bounds) - ax.set_ylim(*y_dof.search_bounds) + ax.set_xlim(*x_dof.search_domain) + ax.set_ylim(*y_dof.search_domain) if x_dof.log: ax.set_xscale("log") if y_dof.log: @@ -290,7 +290,7 @@ def _plot_valid_one_dof(agent, size=16, lw=1e0): agent.valid_ax.scatter(x_values, agent.all_objectives_valid, s=size) agent.valid_ax.plot(test_inputs.squeeze(-2), constraint, lw=lw) - agent.valid_ax.set_xlim(*x_dof.search_bounds) + agent.valid_ax.set_xlim(*x_dof.search_domain) def _plot_valid_many_dofs(agent, axes=[0, 1], shading="nearest", cmap=DEFAULT_COLORMAP, size=16, gridded=None): @@ -335,8 +335,8 @@ def _plot_valid_many_dofs(agent, axes=[0, 1], shading="nearest", cmap=DEFAULT_CO for ax in agent.valid_axes.ravel(): ax.set_xlabel(x_dof.label) ax.set_ylabel(y_dof.label) - ax.set_xlim(*x_dof.search_bounds) - ax.set_ylim(*y_dof.search_bounds) + ax.set_xlim(*x_dof.search_domain) + ax.set_ylim(*y_dof.search_domain) if x_dof.log: ax.set_xscale("log") if y_dof.log: diff --git a/src/blop/dofs.py b/src/blop/dofs.py index 7e2ae6f..a78a8f0 100644 --- 
a/src/blop/dofs.py +++ b/src/blop/dofs.py @@ -10,9 +10,10 @@ DOF_FIELD_TYPES = { "description": "str", - "readback": "float", - "search_bounds": "object", - "trust_bounds": "object", + "readback": "object", + "type": "str", + "search_domain": "object", + "trust_domain": "object", "units": "str", "active": "bool", "read_only": "bool", @@ -20,6 +21,8 @@ "tags": "object", } +SUPPORTED_DOF_TYPES = ["continuous", "binary", "ordinal", "categorical"] + class ReadOnlyError(Exception): ... @@ -46,13 +49,19 @@ class DOF: name: str The name of the DOF. This is used as a key to index observed data. description: str, optional - A longer name for the DOF. + A longer, more descriptive name for the DOF. + type: str + What kind of DOF it is. A DOF can be: + - Continuous, meaning that it can vary to any point between a lower and upper bound. + - Binary, meaning that it can take one of two values (e.g. [on, off]) + - Ordinal, meaning ordered categories (e.g. [low, medium, high]) + - Categorical, meaning non-ordered categories (e.g. [mango, banana, papaya]) + search_domain: tuple + A tuple of the lower and upper limit of the DOF for the agent to search. + trust_domain: tuple, optional + The agent will reject all data where the DOF value is outside the trust domain. Must be larger than search domain. units: str The units of the DOF (e.g. mm or deg). This is just for plotting and general sanity checking. - search_bounds: tuple - A tuple of the lower and upper limit of the DOF for the agent to search. - trust_bounds: tuple, optional - The agent will reject all data where the DOF value is outside the trust bounds. Must be larger than search bounds. read_only: bool If True, the agent will not try to set the DOF. Must be set to True if the supplied ophyd device is read-only. @@ -69,8 +78,9 @@ class DOF: name: str = None description: str = "" - search_bounds: Tuple[float, float] = None - trust_bounds: Tuple[float, float] = None + type: str = "continuous" + search_domain: Tuple[float, float] = None + trust_domain: Tuple[float, float] = None units: str = "" read_only: bool = False active: bool = True @@ -80,57 +90,80 @@ class DOF: # Some post-processing. 
This is specific to dataclasses def __post_init__(self): - if self.search_bounds is None: - if not self.read_only: - raise ValueError("You must specify search_bounds if the device is not read-only.") - else: - self.search_bounds = tuple(self.search_bounds) - if len(self.search_bounds) != 2: - raise ValueError("'search_bounds' must be a 2-tuple of floats.") - if self.search_bounds[0] > self.search_bounds[1]: - raise ValueError("The lower search bound must be less than the upper search bound.") - - if self.trust_bounds is not None: - self.trust_bounds = tuple(self.trust_bounds) - if not self.read_only: - if (self.search_bounds[0] < self.trust_bounds[0]) or (self.search_bounds[1] > self.trust_bounds[1]): - raise ValueError("Trust bounds must be larger than search bounds.") + if self.type not in SUPPORTED_DOF_TYPES: + raise ValueError(f"'type' must be one of {SUPPORTED_DOF_TYPES}") if (self.name is None) ^ (self.device is None): if self.name is None: self.name = self.device.name - if self.device is None: - self.device = Signal(name=self.name) else: raise ValueError("DOF() accepts exactly one of either a name or an ophyd device.") + # if our input is continuous + if self.type == "continuous": + if self.search_domain is None: + if not self.read_only: + raise ValueError("You must specify search_domain if the device is not read-only.") + else: + self.search_domain = tuple(self.search_domain) + if len(self.search_domain) != 2: + raise ValueError("'search_domain' must be a 2-tuple of floats.") + if self.search_domain[0] > self.search_domain[1]: + raise ValueError("The lower search bound must be less than the upper search bound.") + + if self.trust_domain is not None: + self.trust_domain = tuple(self.trust_domain) + if not self.read_only: + if (self.search_domain[0] < self.trust_domain[0]) or (self.search_domain[1] > self.trust_domain[1]): + raise ValueError("Trust domain must be larger than search domain.") + + if self.log: + if not self.search_domain[0] > 0: + raise ValueError("Search domain must be strictly positive if log=True.") + + if self.device is None: + center_value = np.mean(np.log(self.search_domain)) if self.log else np.mean(self.search_domain) + self.device = Signal(name=self.name, value=center_value) + + # otherwise it must be discrete + else: + if self.type == "binary": + if self.search_domain is None: + self.search_domain = [False, True] + if len(self.search_domain) != 2: + raise ValueError("A binary DOF must have a domain of 2.") + else: + if self.search_domain is None: + raise ValueError("Discrete domain must be supplied for ordinal and categorical degrees of freedom.") + + self.search_domain = set(self.search_domain) + + self.device = Signal(name=self.name, value=list(self.search_domain)[0]) + if not self.read_only: # check that the device has a put method if isinstance(self.device, SignalRO): raise ValueError("You must specify read_only=True for a read-only device.") - if self.log: - if not self.search_bounds[0] > 0: - raise ValueError("Search bounds must be strictly positive if log=True.") - # all dof degrees of freedom are hinted self.device.kind = "hinted" @property - def _search_bounds(self): + def _search_domain(self): if self.read_only: _readback = self.readback return (_readback, _readback) - return self.search_bounds + return self.search_domain @property - def _trust_bounds(self): - if self.trust_bounds is None: + def _trust_domain(self): + if self.trust_domain is None: return (0, np.inf) if self.log else (-np.inf, np.inf) - return self.trust_bounds + return 
self.trust_domain @property def readback(self): + # there is probably a better way to do this return self.device.read()[self.device.name]["value"] @property @@ -181,10 +214,17 @@ def __len__(self): return len(self.dofs) def __repr__(self): - return self.summary.__repr__() + return self.summary.T.__repr__() + + def _repr_html_(self): + return self.summary.T._repr_html_() - def __repr_html__(self): - return self.summary.__repr_html__() + @property + def readback(self): + """ + Return the readback from each DOF as a list. It is a list because they might be different types. + """ + return [dof.readback for dof in self.dofs] @property def summary(self) -> pd.DataFrame: @@ -208,18 +248,18 @@ def devices(self) -> list: return [dof.device for dof in self.dofs] @property - def search_bounds(self) -> np.array: + def search_domain(self) -> np.array: """ - Returns a (n_dof, 2) array of bounds. + Returns a (n_dof, 2) array of domain. """ - return np.array([dof._search_bounds for dof in self.dofs]) + return np.array([dof._search_domain for dof in self.dofs]) @property - def trust_bounds(self) -> np.array: + def trust_domain(self) -> np.array: """ - Returns a (n_dof, 2) array of bounds. + Returns a (n_dof, 2) array of domain. """ - return np.array([dof._trust_bounds for dof in self.dofs]) + return np.array([dof._trust_domain for dof in self.dofs]) def add(self, dof): _validate_dofs([*self.dofs, dof]) diff --git a/src/blop/objectives.py b/src/blop/objectives.py index 8c29d73..97aa1a4 100644 --- a/src/blop/objectives.py +++ b/src/blop/objectives.py @@ -12,9 +12,10 @@ OBJ_FIELD_TYPES = { "description": "object", + "type": "str", "target": "object", "active": "bool", - "trust_bounds": "object", + "trust_domain": "object", "active": "bool", "weight": "bool", "units": "object", @@ -26,6 +27,8 @@ "latent_groups": "object", } +ALLOWED_OBJ_TYPES = ["continuous", "binary", "ordinal", "categorical"] + class DuplicateNameError(ValueError): ... 
@@ -73,11 +76,12 @@ class Objective:
 
     name: str
     description: str = ""
+    type: str = "continuous"
     target: Union[Tuple[float, float], float, str] = "max"
     log: bool = False
     weight: float = 1.0
     active: bool = True
-    trust_bounds: Tuple[float, float] or None = None
+    trust_domain: Tuple[float, float] or None = None
     min_noise: float = DEFAULT_MIN_NOISE_LEVEL
     max_noise: float = DEFAULT_MAX_NOISE_LEVEL
     units: str = None
@@ -95,10 +99,10 @@ def __post_init__(self):
         self.use_as_constraint = True if isinstance(self.target, tuple) else False
 
     @property
-    def _trust_bounds(self):
-        if self.trust_bounds is None:
+    def _trust_domain(self):
+        if self.trust_domain is None:
             return (0, np.inf) if self.log else (-np.inf, np.inf)
-        return self.trust_bounds
+        return self.trust_domain
 
     @property
     def label(self) -> str:
@@ -109,7 +113,7 @@ def summary(self) -> pd.Series:
         series = pd.Series(index=list(OBJ_FIELD_TYPES.keys()), dtype="object")
         for attr in series.index:
             value = getattr(self, attr)
-            if attr == "trust_bounds":
+            if attr == "trust_domain":
                 if value is None:
                     value = (0, np.inf) if self.log else (-np.inf, np.inf)
             series[attr] = value
@@ -117,15 +121,15 @@ def summary(self) -> pd.Series:
 
     @property
     def trust_lower_bound(self):
-        if self.trust_bounds is None:
+        if self.trust_domain is None:
             return 0 if self.log else -np.inf
-        return float(self.trust_bounds[0])
+        return float(self.trust_domain[0])
 
     @property
     def trust_upper_bound(self):
-        if self.trust_bounds is None:
+        if self.trust_domain is None:
             return np.inf
-        return float(self.trust_bounds[1])
+        return float(self.trust_domain[1])
 
     @property
     def noise(self) -> float:
diff --git a/src/blop/tests/conftest.py b/src/blop/tests/conftest.py
index 6d3e781..99558ea 100644
--- a/src/blop/tests/conftest.py
+++ b/src/blop/tests/conftest.py
@@ -51,8 +51,8 @@ def agent(db):
     """
 
     dofs = [
-        DOF(name="x1", search_bounds=(-8.0, 8.0)),
-        DOF(name="x2", search_bounds=(-8.0, 8.0)),
+        DOF(name="x1", search_domain=(-8.0, 8.0)),
+        DOF(name="x2", search_domain=(-8.0, 8.0)),
     ]
 
     objectives = [Objective(name="himmelblau", target="min")]
@@ -85,8 +85,8 @@ def digestion(db, uid):
         return products
 
     dofs = [
-        DOF(name="x1", search_bounds=(-5.0, 5.0)),
-        DOF(name="x2", search_bounds=(-5.0, 5.0)),
+        DOF(name="x1", search_domain=(-5.0, 5.0)),
+        DOF(name="x2", search_domain=(-5.0, 5.0)),
     ]
 
     objectives = [Objective(name="obj1", target="min"), Objective(name="obj2", target="min")]
@@ -110,9 +110,9 @@ def agent_with_passive_dofs(db):
     """
 
     dofs = [
-        DOF(name="x1", search_bounds=(-5.0, 5.0)),
-        DOF(name="x2", search_bounds=(-5.0, 5.0)),
-        DOF(name="x3", search_bounds=(-5.0, 5.0), active=False),
+        DOF(name="x1", search_domain=(-5.0, 5.0)),
+        DOF(name="x2", search_domain=(-5.0, 5.0)),
+        DOF(name="x3", search_domain=(-5.0, 5.0), active=False),
         DOF(device=BrownianMotion(name="brownian1"), read_only=True),
         DOF(device=BrownianMotion(name="brownian2"), read_only=True, active=False),
     ]
diff --git a/src/blop/tests/test_agent.py b/src/blop/tests/test_agent.py
index 22446c3..9789635 100644
--- a/src/blop/tests/test_agent.py
+++ b/src/blop/tests/test_agent.py
@@ -8,3 +8,8 @@ def test_agent(agent, RE):
 def test_forget(agent, RE):
     RE(agent.learn("qr", n=4))
     agent.forget(last=2)
+
+
+def test_benchmark(agent, RE):
+    per_iter_learn_kwargs_list = [{"acq_func": "qr", "n": 64}, {"acq_func": "qei", "n": 2, "iterations": 2}]
+    RE(agent.benchmark(output_dir="/tmp/blop", iterations=1, per_iter_learn_kwargs_list=per_iter_learn_kwargs_list))
diff --git a/src/blop/tests/test_dofs.py b/src/blop/tests/test_dofs.py
new file mode 100644
index 0000000..5ff3664
--- /dev/null
+++ b/src/blop/tests/test_dofs.py
@@ -0,0 +1,32 @@
+import pytest  # noqa F401
+
+from blop.dofs import DOF, DOFList
+
+
+def test_dof_types():
+    dof1 = DOF(description="A continuous DOF", type="continuous", name="x1", search_domain=[0, 5], units="mm")
+    dof2 = DOF(
+        description="A binary DOF",
+        type="binary",
+        name="x2",
+        search_domain=["in", "out"],
+        trust_domain=["in"],
+        units="is it in or out?",
+    )
+    dof3 = DOF(
+        description="An ordinal DOF",
+        type="ordinal",
+        name="x3",
+        search_domain=["low", "medium", "high"],
+        trust_domain=["low", "medium"],
+        units="noise level",
+    )
+    dof4 = DOF(
+        description="A categorical DOF",
+        type="categorical",
+        name="x4",
+        search_domain=["mango", "orange", "banana", "papaya"],
+        units="fruit",
+    )
+
+    dofs = DOFList([dof1, dof2, dof3, dof4])  # noqa
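
The renamed ``search_domain``/``trust_domain`` arguments and the new ``type`` field introduced above can be exercised together as in the sketch below. This is a minimal sketch (not itself part of the patch), assuming the same wiring used in ``scripts/gui.py``: a temporary databroker ``Broker``, a bluesky ``RunEngine`` subscribed to it, and the ``himmelblau_digestion`` helper from ``blop.utils.functions``. The patch only adds construction and validation of discrete DOFs, so the binary DOF is constructed here but not handed to the agent.

.. code-block:: python

    from bluesky.run_engine import RunEngine
    from databroker import Broker

    from blop import DOF, Agent, Objective
    from blop.utils import functions

    db = Broker.named("temp")
    RE = RunEngine({})
    RE.subscribe(db.insert)

    # continuous DOFs use the renamed search_domain / trust_domain arguments
    dofs = [
        DOF(name="x1", search_domain=(-6.0, 6.0)),
        DOF(name="x2", search_domain=(-6.0, 6.0), trust_domain=(-8.0, 8.0)),
    ]

    # a DOF of the new "binary" type; the patch adds construction and validation for
    # discrete DOFs, so it is only instantiated here rather than optimized over
    shutter = DOF(name="shutter", type="binary", search_domain={"in", "out"})

    objectives = [Objective(name="himmelblau", target="min")]

    agent = Agent(dofs=dofs, objectives=objectives, digestion=functions.himmelblau_digestion, db=db)

    RE(agent.learn("qr", n=16))                # quasi-random initial sampling
    RE(agent.learn("qei", n=2, iterations=2))  # then q-expected improvement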
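
The reworked ``benchmark()`` plan in ``src/blop/agent.py`` resets the agent each iteration, runs the listed ``learn()`` calls in order, and writes ``blop_benchmark_<timestamp>.h5`` into ``output_dir``. A minimal sketch, assuming an ``agent`` and ``RE`` built as in ``src/blop/tests/conftest.py``:

.. code-block:: python

    RE(
        agent.benchmark(
            output_dir="/tmp/blop",
            iterations=2,
            per_iter_learn_kwargs_list=[
                {"acq_func": "qr", "n": 32},                   # quasi-random initialization
                {"acq_func": "qei", "n": 4, "iterations": 4},  # q-expected improvement refinement
            ],
        )
    )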