From 33937c9138b32f06fa837125350869e649d634b4 Mon Sep 17 00:00:00 2001 From: Alexander Held Date: Mon, 16 Jan 2023 20:37:54 +0100 Subject: [PATCH] support pull comparisons --- src/cabinetry/visualize/__init__.py | 68 +++++++++++++++++++------- src/cabinetry/visualize/plot_result.py | 66 ++++++++++++++++++++----- 2 files changed, 102 insertions(+), 32 deletions(-) diff --git a/src/cabinetry/visualize/__init__.py b/src/cabinetry/visualize/__init__.py index 3f53b86f..be6a42d7 100644 --- a/src/cabinetry/visualize/__init__.py +++ b/src/cabinetry/visualize/__init__.py @@ -438,22 +438,29 @@ def correlation_matrix( def pulls( - fit_results: fit.FitResults, + fit_results: Union[fit.FitResults, List[fit.FitResults]], *, figure_folder: Union[str, pathlib.Path] = "figures", exclude: Optional[Union[str, List[str], Tuple[str, ...]]] = None, + fit_labels: Optional[Union[str, List[str]]] = None, close_figure: bool = True, save_figure: bool = True, ) -> mpl.figure.Figure: """Draws a pull plot of parameter results and uncertainties. Args: - fit_results (fit.FitResults): fit results, including correlation matrix and - parameter labels + fit_results (Union[fit.FitResults, List[fit.FitResults]]): fit results, + including correlation matrix and parameter labels, this can either be a + single result or a list of multiple results that are then compared with + each other (comparison of up to three results is supported) figure_folder (Union[str, pathlib.Path], optional): path to the folder to save figures in, defaults to "figures" exclude (Optional[Union[str, List[str], Tuple[str, ...]]], optional): parameter or parameters to exclude from plot, defaults to None (nothing excluded) + fit_labels (Optional[Union[str, List[str]]], optional): label(s) to identify fit + results with, has to be provided if more than one fit result is given, + defaults to None (then uses no label for the case of a single fit result + input) close_figure (bool, optional): whether to close figure, defaults to True save_figure (bool, optional): whether to save figure, defaults to True @@ -462,7 +469,19 @@ def pulls( """ # path is None if figure should not be saved figure_path = pathlib.Path(figure_folder) / "pulls.pdf" if save_figure else None - labels_np = np.asarray(fit_results.labels) + + # handle single / multiple fit results + if not isinstance(fit_results, list): + fit_results = [fit_results] + if not isinstance(fit_labels, list) and fit_labels is not None: + fit_labels = [fit_labels] + if len(fit_results) > 1 and fit_labels is None: + raise ValueError("fit labels need to be provided when comparing fit results") + elif fit_labels is not None and len(fit_results) != len(fit_labels): + raise ValueError( + f"found {len(fit_results)} fit result(s) but {len(fit_labels)} " + "label(s), they need to match" + ) if exclude is None: exclude_set = set() @@ -471,28 +490,39 @@ def pulls( else: exclude_set = set(exclude) - # exclude fixed parameters from pull plot - exclude_set.update( - [ - label - for i_np, label in enumerate(labels_np) - if fit_results.uncertainty[i_np] == 0.0 - ] - ) + bestfit = [] + uncertainty = [] + labels_np = [] + + # perform filtering per instance of fit_results + for fit_results_inst in fit_results: + labels_np_inst = np.asarray(fit_results_inst.labels) + + # exclude fixed parameters from pull plot + exclude_set.update( + [ + label + for i_np, label in enumerate(labels_np_inst) + if fit_results_inst.uncertainty[i_np] == 0.0 + ] + ) - # exclude staterror parameters from pull plot (they are centered at 1) - exclude_set.update([label for label in labels_np if label[0:10] == "staterror_"]) + # exclude staterror parameters from pull plot (they are centered at 1) + exclude_set.update( + [label for label in labels_np_inst if label[0:10] == "staterror_"] + ) - # filter out user-specified parameters - mask = [True if label not in exclude_set else False for label in labels_np] - bestfit = fit_results.bestfit[mask] - uncertainty = fit_results.uncertainty[mask] - labels_np = labels_np[mask] + # filter out user-specified parameters + mask = [True if label not in exclude_set else False for label in labels_np_inst] + bestfit.append(fit_results_inst.bestfit[mask]) + uncertainty.append(fit_results_inst.uncertainty[mask]) + labels_np.append(labels_np_inst[mask]) fig = plot_result.pulls( bestfit, uncertainty, labels_np, + fit_labels, figure_path=figure_path, close_figure=close_figure, ) diff --git a/src/cabinetry/visualize/plot_result.py b/src/cabinetry/visualize/plot_result.py index 82a3bde0..a012551c 100644 --- a/src/cabinetry/visualize/plot_result.py +++ b/src/cabinetry/visualize/plot_result.py @@ -73,19 +73,25 @@ def correlation_matrix( def pulls( - bestfit: np.ndarray, - uncertainty: np.ndarray, - labels: Union[List[str], np.ndarray], + bestfit: List[np.ndarray], + uncertainty: List[np.ndarray], + labels: List[np.ndarray], # TODO: this also supported list of strings before + fit_labels: Optional[List[str]], *, figure_path: Optional[pathlib.Path] = None, close_figure: bool = False, ) -> mpl.figure.Figure: - """Draws a pull plot. + """Draws a pull plot for one or multiple fits. + + Parameters are sorted alphabetically. Args: - bestfit (np.ndarray): best-fit parameter results - uncertainty (np.ndarray): parameter uncertainties - labels (Union[List[str], np.ndarray]): parameter names + bestfit (List[np.ndarray]): list of best-fit parameter results per fit + uncertainty (List[np.ndarray]): list of parameter uncertainties per fit + labels (List[np.ndarray]): list of parameter names per fit + fit_labels (Optional[List[str]]): list or fit labels, or None if + no labels are given (only supported for a single fit input, no check for + that performed) figure_path (Optional[pathlib.Path], optional): path where figure should be saved, or None to not save it, defaults to None close_figure (bool, optional): whether to close each figure immediately after @@ -95,24 +101,58 @@ def pulls( Returns: matplotlib.figure.Figure: the pull figure """ - num_pars = len(bestfit) + # get the union of all parameters in the input + unique_labels = sorted(set.union(*[set(lab) for lab in labels])) + + num_pars = len(unique_labels) y_positions = np.arange(num_pars)[::-1] + # TODO: increase figure size if legend at top is used fig, ax = plt.subplots(figsize=(6, 1 + num_pars / 4), dpi=100, layout="constrained") - ax.errorbar(bestfit, y_positions, xerr=uncertainty, fmt="o", color="black") - ax.fill_between([-2, 2], -0.5, len(bestfit) - 0.5, color="yellow") - ax.fill_between([-1, 1], -0.5, len(bestfit) - 0.5, color="limegreen") - ax.vlines(0, -0.5, len(bestfit) - 0.5, linestyles="dotted", color="black") + # https://matplotlib.org/stable/gallery/color/named_colors.html#sphx-glr-gallery-color-named-colors-py + colors = ["black", "crimson", "mediumblue"] + + for i_fit, (bests, uncs, labs) in enumerate(zip(bestfit, uncertainty, labels)): + # for all labels in this specific fit results instance, find how they correspond + # to the sorted union that is used on the axis to get proper vertical positions + y_pos_reordered = [] + for lab in labs: + idx = num_pars - unique_labels.index(lab) - 1 + y_pos_reordered.append( + idx - 0.2 + i_fit * 0.2 + ) # TODO: customize depending on number of fits plotted + ax.errorbar(bests, y_pos_reordered, xerr=uncs, fmt="o", color=colors[i_fit]) + + ax.fill_between([-2, 2], -0.5, num_pars - 0.5, color="yellow") + ax.fill_between([-1, 1], -0.5, num_pars - 0.5, color="limegreen") + ax.vlines(0, -0.5, num_pars - 0.5, linestyles="dotted", color="black") ax.set_xlim([-3, 3]) ax.set_xlabel(r"$\left(\hat{\theta} - \theta_0\right) / \Delta \theta$") ax.set_ylim([-0.5, num_pars - 0.5]) ax.set_yticks(y_positions) - ax.set_yticklabels(labels) + ax.set_yticklabels(unique_labels) ax.xaxis.set_minor_locator(mpl.ticker.AutoMinorLocator()) # minor ticks ax.tick_params(axis="both", which="major", pad=8) ax.tick_params(direction="in", top=True, right=True, which="both") + # legend + if fit_labels is not None: + custom_lines = [ + mpl.lines.Line2D([0], [0], color=colors[i_fit], marker="o") + for i_fit in range(len(fit_labels)) + ] + # https://matplotlib.org/stable/tutorials/intermediate/legend_guide.html#legend-location + ax.legend( + custom_lines, + fit_labels, + frameon=False, + ncols=3, + bbox_to_anchor=(0.0, 1.0, 1.0, 0.05), # TODO: center legend for <3 inputs? + mode="expand", + bbox_transform=fig.transFigure, + ) + utils._save_and_close(fig, figure_path, close_figure) return fig