Skip to content

Commit 359928a

Browse files
committed
feat : implemented broken-axis signed
- add additional space for "show" option, change the tick labels - using slash for "hint", "show" options
1 parent e832601 commit 359928a

File tree

5 files changed

+256
-2
lines changed

5 files changed

+256
-2
lines changed

.pre-commit-config.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ ci:
33

44
repos:
55
- repo: https://github.com/pre-commit/pre-commit-hooks
6-
rev: v4.3.0
6+
rev: v4.4.0
77
hooks:
88
- id: check-added-large-files
99
- id: check-case-conflict
@@ -23,7 +23,7 @@ repos:
2323
args: ["--include-version-classifiers", "--max-py-version=3.11"]
2424

2525
- repo: https://github.com/PyCQA/isort
26-
rev: 5.10.1
26+
rev: 5.12.0
2727
hooks:
2828
- id: isort
2929
args: ["-a", "from __future__ import annotations"]

src/mplhep/plot.py

+206
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def histplot(
7070
edges=True,
7171
binticks=False,
7272
ax=None,
73+
flow=None,
7374
**kwargs,
7475
):
7576
"""
@@ -130,6 +131,8 @@ def histplot(
130131
Attempts to draw x-axis ticks coinciding with bin boundaries if feasible.
131132
ax : matplotlib.axes.Axes, optional
132133
Axes object (if None, last one is fetched or one is created)
134+
flow : str, optional {None, "show", "sum", "hint"}
135+
Whether plot the under/overflow bin. If "show", add additional under/overflow bin. If "sum", add the under/overflow bin content to first/last bin.
133136
**kwargs :
134137
Keyword arguments passed to underlying matplotlib functions -
135138
{'step', 'fill_between', 'errorbar'}.
@@ -171,6 +174,42 @@ def histplot(
171174
plottables = [
172175
Plottable(h.values(), edges=final_bins, variances=h.variances()) for h in hists
173176
]
177+
# Show under/overflow bins
178+
# "show": Add additional bin with 5 times bin width
179+
if flow == "show":
180+
plottables = []
181+
final_bins = np.array(
182+
[
183+
final_bins[0] - (final_bins[-1] - final_bins[0]) * 0.08,
184+
final_bins[0] - (final_bins[-1] - final_bins[0]) * 0.03,
185+
*final_bins,
186+
final_bins[-1] + (final_bins[-1] - final_bins[0]) * 0.03,
187+
final_bins[-1] + (final_bins[-1] - final_bins[0]) * 0.08,
188+
]
189+
)
190+
for h in hists:
191+
value, variance = h.view(flow=True)["value"], h.view(flow=True)["variance"]
192+
value, variance = np.insert(value, -1, np.nan), np.insert(
193+
variance, -1, np.nan
194+
)
195+
value, variance = np.insert(value, 1, np.nan), np.insert(
196+
variance, 1, np.nan
197+
)
198+
plottables.append(Plottable(value, edges=final_bins, variances=variance))
199+
# "sum": Add under/overflow bin to first/last bin
200+
elif flow == "sum":
201+
plottables = []
202+
for h in hists:
203+
value, variance = h.view()["value"], h.view()["variance"]
204+
value[0], value[-1] = (
205+
value[0] + h.view(flow=True)["value"][0],
206+
value[-1] + h.view(flow=True)["value"][-1],
207+
)
208+
variance[0], variance[-1] = (
209+
variance[0] + h.view(flow=True)["variance"][0],
210+
variance[-1] + h.view(flow=True)["variance"][-1],
211+
)
212+
plottables.append(Plottable(value, edges=final_bins, variances=variance))
174213

175214
if w2 is not None:
176215
for _w2, _plottable in zip(
@@ -397,12 +436,73 @@ def iterable_not_string(arg):
397436
if binticks:
398437
_slice = int(round(float(len(final_bins)) / len(ax.get_xticks()))) + 1
399438
ax.set_xticks(final_bins[::_slice])
439+
elif flow == "show":
440+
if binticks:
441+
_slice = int(round(float(len(final_bins)) / len(ax.get_xticks()))) + 1
442+
ax.set_xticks(final_bins[::_slice])
400443
else:
401444
ax.set_xticks(_bin_centers)
402445
ax.set_xticklabels(xtick_labels)
403446

404447
if x_axes_label:
405448
ax.set_xlabel(x_axes_label)
449+
if flow == "hint" or flow == "show":
450+
underflow, overflow = 0.0, 0.0
451+
for h in hists:
452+
underflow = underflow + h.view(flow=True)["value"][0]
453+
overflow = overflow + h.view(flow=True)["value"][-1]
454+
d = 0.9 # proportion of vertical to horizontal extent of the slanted line
455+
trans = mpl.transforms.blended_transform_factory(ax.transData, ax.transAxes)
456+
kwargs = dict(
457+
marker=[(-0.5, -d), (0.5, d)],
458+
markersize=15,
459+
linestyle="none",
460+
color="k",
461+
mec="k",
462+
mew=1,
463+
clip_on=False,
464+
transform=trans,
465+
)
466+
xticks = ax.get_xticks().tolist()
467+
if underflow > 0.0:
468+
if flow == "hint":
469+
ax.plot(
470+
[
471+
final_bins[0] - (final_bins[-3] - final_bins[2]) * 0.03,
472+
final_bins[0],
473+
],
474+
[0, 0],
475+
**kwargs,
476+
)
477+
if flow == "show":
478+
ax.plot(
479+
[final_bins[1], final_bins[2]],
480+
[0, 0],
481+
**kwargs,
482+
)
483+
xticks[0] = ""
484+
xticks[1] = f"<{final_bins[2]}"
485+
486+
ax.set_xticklabels(xticks)
487+
if overflow > 0.0:
488+
if flow == "hint":
489+
ax.plot(
490+
[
491+
final_bins[-1],
492+
final_bins[-1] + (final_bins[-3] - final_bins[2]) * 0.03,
493+
],
494+
[0, 0],
495+
**kwargs,
496+
)
497+
if flow == "show":
498+
ax.plot(
499+
[final_bins[-3], final_bins[-2]],
500+
[0, 0],
501+
**kwargs,
502+
)
503+
xticks[-1] = ""
504+
xticks[-2] = f">{final_bins[-3]}"
505+
ax.set_xticklabels(xticks)
406506

407507
return return_artists
408508

@@ -420,6 +520,7 @@ def hist2dplot(
420520
cmin=None,
421521
cmax=None,
422522
ax=None,
523+
flow=None,
423524
**kwargs,
424525
):
425526
"""
@@ -460,6 +561,8 @@ def hist2dplot(
460561
Colorbar maximum.
461562
ax : matplotlib.axes.Axes, optional
462563
Axes object (if None, last one is fetched or one is created)
564+
flow : str, optional {None, "show", "sum","hint"}
565+
Whether plot the under/overflow bin. If "show", add additional under/overflow bin. If "sum", add the under/overflow bin content to first/last bin. "hint" would highlight the bins with under/overflow contents
463566
**kwargs :
464567
Keyword arguments passed to underlying matplotlib function - pcolormesh.
465568
@@ -482,6 +585,39 @@ def hist2dplot(
482585
H = hist.values()
483586
xbins, xtick_labels = get_plottable_protocol_bins(hist.axes[0])
484587
ybins, ytick_labels = get_plottable_protocol_bins(hist.axes[1])
588+
# Show under/overflow bins
589+
# "show": Add additional bin with 2 times bin width
590+
if flow == "show":
591+
H = hist.view(flow=True)["value"]
592+
593+
xbins = np.array(
594+
[
595+
xbins[0] - (xbins[-1] - xbins[0]) * 0.08,
596+
xbins[0] - (xbins[-1] - xbins[0]) * 0.03,
597+
*xbins,
598+
xbins[-1] + (xbins[-1] - xbins[0]) * 0.03,
599+
xbins[-1] + (xbins[-1] - xbins[0]) * 0.08,
600+
]
601+
)
602+
ybins = np.array(
603+
[
604+
ybins[0] - (ybins[-1] - ybins[0]) * 0.08,
605+
ybins[0] - (ybins[-1] - ybins[0]) * 0.03,
606+
*ybins,
607+
ybins[-1] + (ybins[-1] - ybins[0]) * 0.03,
608+
ybins[-1] + (ybins[-1] - ybins[0]) * 0.08,
609+
]
610+
)
611+
H = np.insert(H, (1, -1), np.nan, axis=-1)
612+
H = np.insert(H, (1, -1), np.full(np.shape(H)[1], np.nan), axis=0)
613+
614+
if flow == "sum":
615+
H[0, 0], H[-1, -1], H[0, -1], H[-1, 0] = (
616+
hist.view(flow=True)["value"][0, 0] + H[0, 0],
617+
hist.view(flow=True)["value"][-1, -1] + H[-1, -1],
618+
hist.view(flow=True)["value"][0, -1] + H[0, -1],
619+
hist.view(flow=True)["value"][-1, 0] + H[-1, 0],
620+
)
485621
xbin_centers = xbins[1:] - np.diff(xbins) / float(2)
486622
ybin_centers = ybins[1:] - np.diff(ybins) / float(2)
487623

@@ -536,6 +672,76 @@ def hist2dplot(
536672
cb_obj = None
537673

538674
plt.sca(ax)
675+
if flow == "hint" or flow == "show":
676+
d = 0.9 # proportion of vertical to horizontal extent of the slanted line
677+
trans = mpl.transforms.blended_transform_factory(ax.transData, ax.transAxes)
678+
kwargs = dict(
679+
marker=[(-0.5, -d), (0.5, d)],
680+
markersize=15,
681+
linestyle="none",
682+
color="k",
683+
mec="k",
684+
mew=1,
685+
clip_on=False,
686+
)
687+
xticks = ax.get_xticks().tolist()
688+
yticks = ax.get_yticks().tolist()
689+
if hist.view(flow=True)["value"][0, 0] > 0.0:
690+
if flow == "hint":
691+
ax.plot(
692+
[xbins[0] - (xbins[-3] - xbins[2]) * 0.03, xbins[0]],
693+
[0, 0],
694+
transform=trans,
695+
**kwargs,
696+
)
697+
if flow == "show":
698+
ax.plot([xbins[1], xbins[2]], [0, 0], transform=trans, **kwargs)
699+
ax.plot([xbins[0], xbins[0]], [ybins[1], ybins[2]], **kwargs)
700+
xticks[0] = ""
701+
xticks[1] = f"<{xbins[1]}"
702+
ax.set_xticklabels(xticks)
703+
if hist.view(flow=True)["value"][-1, 0] > 0.0:
704+
if flow == "hint":
705+
ax.plot(
706+
[xbins[-1] + (xbins[-3] - xbins[2]) * 0.03, xbins[-1]],
707+
[0, 0],
708+
transform=trans,
709+
**kwargs,
710+
)
711+
if flow == "show":
712+
ax.plot([xbins[-3], xbins[-2]], [0, 0], transform=trans, **kwargs)
713+
ax.plot([xbins[-1], xbins[-1]], [ybins[1], ybins[2]], **kwargs)
714+
xticks[-1] = ""
715+
xticks[-2] = f">{xbins[-2]}"
716+
ax.set_xticklabels(xticks)
717+
if hist.view(flow=True)["value"][0, -1] > 0.0:
718+
if flow == "hint":
719+
ax.plot(
720+
[xbins[0], xbins[0] - (xbins[-3] - xbins[2]) * 0.03],
721+
[1, 1],
722+
transform=trans,
723+
**kwargs,
724+
)
725+
if flow == "show":
726+
ax.plot([xbins[1], xbins[2]], [1, 1], transform=trans, **kwargs)
727+
ax.plot([xbins[0], xbins[0]], [ybins[-3], ybins[-2]], **kwargs)
728+
yticks[0] = ""
729+
yticks[1] = f"<{ybins[1]}"
730+
ax.set_yticklabels(yticks)
731+
if hist.view(flow=True)["value"][-1, -1] > 0.0:
732+
if flow == "hint":
733+
ax.plot(
734+
[xbins[-1] + (xbins[-3] - xbins[2]) * 0.03, xbins[-1]],
735+
[1, 1],
736+
transform=trans,
737+
**kwargs,
738+
)
739+
if flow == "show":
740+
ax.plot([xbins[-3], xbins[-2]], [1, 1], transform=trans, **kwargs)
741+
ax.plot([xbins[-1], xbins[-1]], [ybins[-3], ybins[-2]], **kwargs)
742+
yticks[-1] = ""
743+
yticks[-2] = f">{ybins[-2]}"
744+
ax.set_yticklabels(yticks)
539745

540746
_labels: np.ndarray | None = None
541747
if isinstance(labels, bool):
47.1 KB
Loading

tests/baseline/test_histplot_flow.png

31.3 KB
Loading

tests/test_basic.py

+48
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import os
44

5+
import hist
56
import matplotlib.pyplot as plt
67
import numpy as np
78
import pytest
@@ -110,6 +111,27 @@ def test_histplot_density():
110111
return fig
111112

112113

114+
@pytest.mark.mpl_image_compare(style="default")
115+
def test_histplot_flow():
116+
np.random.seed(0)
117+
h = hist.Hist(hist.axis.Regular(20, 5, 15, name="x"), hist.storage.Weight())
118+
h.fill(np.random.normal(10, 3, 400))
119+
fig, axs = plt.subplots(2, 2, sharex=True, sharey=True, figsize=(10, 10))
120+
axs = axs.flatten()
121+
122+
hep.histplot(h, ax=axs[0], flow="hint")
123+
hep.histplot(h, ax=axs[1], flow="None")
124+
hep.histplot(h, ax=axs[2], flow="show")
125+
hep.histplot(h, ax=axs[3], flow="sum")
126+
127+
axs[0].set_title("Default", fontsize=18)
128+
axs[1].set_title("None", fontsize=18)
129+
axs[2].set_title("Show", fontsize=18)
130+
axs[3].set_title("Sum", fontsize=18)
131+
fig.subplots_adjust(hspace=0.1, wspace=0.1)
132+
return fig
133+
134+
113135
@pytest.mark.mpl_image_compare(style="default", remove_text=True)
114136
def test_histplot_multiple():
115137
np.random.seed(0)
@@ -175,6 +197,32 @@ def test_hist2dplot():
175197
return fig
176198

177199

200+
@pytest.mark.mpl_image_compare(style="default")
201+
def test_hist2dplot_flow():
202+
np.random.seed(0)
203+
h = hist.Hist(
204+
hist.axis.Regular(20, 5, 15, name="x"),
205+
hist.axis.Regular(20, -5, 5, name="y"),
206+
hist.storage.Weight(),
207+
)
208+
h.fill(np.random.normal(10, 3, 400), np.random.normal(0, 4, 400))
209+
fig, axs = plt.subplots(2, 2, figsize=(10, 10))
210+
axs = axs.flatten()
211+
212+
hep.hist2dplot(h, ax=axs[0], flow="hint")
213+
hep.hist2dplot(h, ax=axs[1], flow="None")
214+
hep.hist2dplot(h, ax=axs[2], flow="show")
215+
hep.hist2dplot(h, ax=axs[3], flow="sum")
216+
217+
axs[0].set_title("Default", fontsize=18)
218+
axs[1].set_title("None", fontsize=18)
219+
axs[2].set_title("Show", fontsize=18)
220+
axs[3].set_title("Sum", fontsize=18)
221+
fig.subplots_adjust(hspace=0.1, wspace=0.1)
222+
223+
return fig
224+
225+
178226
@pytest.mark.mpl_image_compare(style="default", remove_text=True)
179227
def test_hist2dplot_inputs_nobin():
180228
np.random.seed(0)

0 commit comments

Comments
 (0)