Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: pull comparison plot #387

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 49 additions & 19 deletions src/cabinetry/visualize/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,22 +438,29 @@ def correlation_matrix(


def pulls(
fit_results: fit.FitResults,
fit_results: Union[fit.FitResults, List[fit.FitResults]],
*,
figure_folder: Union[str, pathlib.Path] = "figures",
exclude: Optional[Union[str, List[str], Tuple[str, ...]]] = None,
fit_labels: Optional[Union[str, List[str]]] = None,
close_figure: bool = True,
save_figure: bool = True,
) -> mpl.figure.Figure:
"""Draws a pull plot of parameter results and uncertainties.

Args:
fit_results (fit.FitResults): fit results, including correlation matrix and
parameter labels
fit_results (Union[fit.FitResults, List[fit.FitResults]]): fit results,
including correlation matrix and parameter labels, this can either be a
single result or a list of multiple results that are then compared with
each other (comparison of up to three results is supported)
figure_folder (Union[str, pathlib.Path], optional): path to the folder to save
figures in, defaults to "figures"
exclude (Optional[Union[str, List[str], Tuple[str, ...]]], optional): parameter
or parameters to exclude from plot, defaults to None (nothing excluded)
fit_labels (Optional[Union[str, List[str]]], optional): label(s) to identify fit
results with, has to be provided if more than one fit result is given,
defaults to None (then uses no label for the case of a single fit result
input)
close_figure (bool, optional): whether to close figure, defaults to True
save_figure (bool, optional): whether to save figure, defaults to True

Expand All @@ -462,7 +469,19 @@ def pulls(
"""
# path is None if figure should not be saved
figure_path = pathlib.Path(figure_folder) / "pulls.pdf" if save_figure else None
labels_np = np.asarray(fit_results.labels)

# handle single / multiple fit results
if not isinstance(fit_results, list):
fit_results = [fit_results]
if not isinstance(fit_labels, list) and fit_labels is not None:
fit_labels = [fit_labels]
if len(fit_results) > 1 and fit_labels is None:
raise ValueError("fit labels need to be provided when comparing fit results")
elif fit_labels is not None and len(fit_results) != len(fit_labels):
raise ValueError(
f"found {len(fit_results)} fit result(s) but {len(fit_labels)} "
"label(s), they need to match"
)

if exclude is None:
exclude_set = set()
Expand All @@ -471,28 +490,39 @@ def pulls(
else:
exclude_set = set(exclude)

# exclude fixed parameters from pull plot
exclude_set.update(
[
label
for i_np, label in enumerate(labels_np)
if fit_results.uncertainty[i_np] == 0.0
]
)
bestfit = []
uncertainty = []
labels_np = []

# perform filtering per instance of fit_results
for fit_results_inst in fit_results:
labels_np_inst = np.asarray(fit_results_inst.labels)

# exclude fixed parameters from pull plot
exclude_set.update(
[
label
for i_np, label in enumerate(labels_np_inst)
if fit_results_inst.uncertainty[i_np] == 0.0
]
)

# exclude staterror parameters from pull plot (they are centered at 1)
exclude_set.update([label for label in labels_np if label[0:10] == "staterror_"])
# exclude staterror parameters from pull plot (they are centered at 1)
exclude_set.update(
[label for label in labels_np_inst if label[0:10] == "staterror_"]
)

# filter out user-specified parameters
mask = [True if label not in exclude_set else False for label in labels_np]
bestfit = fit_results.bestfit[mask]
uncertainty = fit_results.uncertainty[mask]
labels_np = labels_np[mask]
# filter out user-specified parameters
mask = [True if label not in exclude_set else False for label in labels_np_inst]
bestfit.append(fit_results_inst.bestfit[mask])
uncertainty.append(fit_results_inst.uncertainty[mask])
labels_np.append(labels_np_inst[mask])

fig = plot_result.pulls(
bestfit,
uncertainty,
labels_np,
fit_labels,
figure_path=figure_path,
close_figure=close_figure,
)
Expand Down
66 changes: 53 additions & 13 deletions src/cabinetry/visualize/plot_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,19 +73,25 @@ def correlation_matrix(


def pulls(
bestfit: np.ndarray,
uncertainty: np.ndarray,
labels: Union[List[str], np.ndarray],
bestfit: List[np.ndarray],
uncertainty: List[np.ndarray],
labels: List[np.ndarray], # TODO: this also supported list of strings before
fit_labels: Optional[List[str]],
*,
figure_path: Optional[pathlib.Path] = None,
close_figure: bool = False,
) -> mpl.figure.Figure:
"""Draws a pull plot.
"""Draws a pull plot for one or multiple fits.

Parameters are sorted alphabetically.

Args:
bestfit (np.ndarray): best-fit parameter results
uncertainty (np.ndarray): parameter uncertainties
labels (Union[List[str], np.ndarray]): parameter names
bestfit (List[np.ndarray]): list of best-fit parameter results per fit
uncertainty (List[np.ndarray]): list of parameter uncertainties per fit
labels (List[np.ndarray]): list of parameter names per fit
fit_labels (Optional[List[str]]): list or fit labels, or None if
no labels are given (only supported for a single fit input, no check for
that performed)
figure_path (Optional[pathlib.Path], optional): path where figure should be
saved, or None to not save it, defaults to None
close_figure (bool, optional): whether to close each figure immediately after
Expand All @@ -95,24 +101,58 @@ def pulls(
Returns:
matplotlib.figure.Figure: the pull figure
"""
num_pars = len(bestfit)
# get the union of all parameters in the input
unique_labels = sorted(set.union(*[set(lab) for lab in labels]))

num_pars = len(unique_labels)
y_positions = np.arange(num_pars)[::-1]
# TODO: increase figure size if legend at top is used
fig, ax = plt.subplots(figsize=(6, 1 + num_pars / 4), dpi=100, layout="constrained")
ax.errorbar(bestfit, y_positions, xerr=uncertainty, fmt="o", color="black")

ax.fill_between([-2, 2], -0.5, len(bestfit) - 0.5, color="yellow")
ax.fill_between([-1, 1], -0.5, len(bestfit) - 0.5, color="limegreen")
ax.vlines(0, -0.5, len(bestfit) - 0.5, linestyles="dotted", color="black")
# https://matplotlib.org/stable/gallery/color/named_colors.html#sphx-glr-gallery-color-named-colors-py
colors = ["black", "crimson", "mediumblue"]

for i_fit, (bests, uncs, labs) in enumerate(zip(bestfit, uncertainty, labels)):
# for all labels in this specific fit results instance, find how they correspond
# to the sorted union that is used on the axis to get proper vertical positions
y_pos_reordered = []
for lab in labs:
idx = num_pars - unique_labels.index(lab) - 1
y_pos_reordered.append(
idx - 0.2 + i_fit * 0.2
) # TODO: customize depending on number of fits plotted
ax.errorbar(bests, y_pos_reordered, xerr=uncs, fmt="o", color=colors[i_fit])

ax.fill_between([-2, 2], -0.5, num_pars - 0.5, color="yellow")
ax.fill_between([-1, 1], -0.5, num_pars - 0.5, color="limegreen")
ax.vlines(0, -0.5, num_pars - 0.5, linestyles="dotted", color="black")

ax.set_xlim([-3, 3])
ax.set_xlabel(r"$\left(\hat{\theta} - \theta_0\right) / \Delta \theta$")
ax.set_ylim([-0.5, num_pars - 0.5])
ax.set_yticks(y_positions)
ax.set_yticklabels(labels)
ax.set_yticklabels(unique_labels)
ax.xaxis.set_minor_locator(mpl.ticker.AutoMinorLocator()) # minor ticks
ax.tick_params(axis="both", which="major", pad=8)
ax.tick_params(direction="in", top=True, right=True, which="both")

# legend
if fit_labels is not None:
custom_lines = [
mpl.lines.Line2D([0], [0], color=colors[i_fit], marker="o")
for i_fit in range(len(fit_labels))
]
# https://matplotlib.org/stable/tutorials/intermediate/legend_guide.html#legend-location
ax.legend(
custom_lines,
fit_labels,
frameon=False,
ncols=3,
bbox_to_anchor=(0.0, 1.0, 1.0, 0.05), # TODO: center legend for <3 inputs?
mode="expand",
bbox_transform=fig.transFigure,
)

utils._save_and_close(fig, figure_path, close_figure)
return fig

Expand Down