Skip to content

Commit

Permalink
fix: correct sorting
Browse files Browse the repository at this point in the history
  • Loading branch information
andrzejnovak committed Mar 7, 2025
1 parent 710f1e9 commit d97d0f2
Show file tree
Hide file tree
Showing 9 changed files with 46 additions and 22 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ test = [
"hist",
"pytest-mock",
"pytest-mpl",
"pytest-xdist",
"pytest>=6.0",
"scikit-hep-testdata",
"scipy>=1.1.0",
Expand Down
44 changes: 22 additions & 22 deletions src/mplhep/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,34 +198,21 @@ def histplot(
else get_histogram_axes_title(hists[0].axes[0])
)

plottables, flow_info = get_plottables(
hists,
bins=final_bins,
w2=w2,
w2method=w2method,
yerr=yerr,
stack=stack,
density=density,
binwnorm=binwnorm,
flow=flow,
)
flow_bins, underflow, overflow = flow_info

_labels: list[str | None]
if label is None:
_labels = [None] * len(plottables)
_labels = [None] * len(hists)
elif isinstance(label, str):
_labels = [label] * len(plottables)
_labels = [label] * len(hists)
elif not np.iterable(label):
_labels = [str(label)] * len(plottables)
_labels = [str(label)] * len(hists)
else:
_labels = [str(lab) for lab in label]

def iterable_not_string(arg):
return isinstance(arg, collections.abc.Iterable) and not isinstance(arg, str)

_chunked_kwargs: list[dict[str, Any]] = []
for _ in range(len(plottables)):
for _ in range(len(hists)):
_chunked_kwargs.append({})
for kwarg, kwarg_content in kwargs.items():
# Check if iterable
Expand All @@ -249,22 +236,35 @@ def iterable_not_string(arg):
if sort.split("_")[0] in ["l", "label"] and isinstance(_labels, list):
order = np.argsort(label) # [::-1]
elif sort.split("_")[0] in ["y", "yield"]:
_yields = [np.sum(_h.values) for _h in plottables] # type: ignore[var-annotated]
_yields = [np.sum(_h.values()) for _h in hists] # type: ignore[var-annotated]
order = np.argsort(_yields)
if len(sort.split("_")) == 2 and sort.split("_")[1] == "r":
order = order[::-1]
elif isinstance(sort, (list, np.ndarray)):
if len(sort) != len(plottables):
msg = f"Sort indexing array is of the wrong size - {len(sort)}, {len(plottables)} expected."
if len(sort) != len(hists):
msg = f"Sort indexing array is of the wrong size - {len(sort)}, {len(hists)} expected."
raise ValueError(msg)
order = np.asarray(sort)
else:
msg = f"Sort type: {sort} not understood."
raise ValueError(msg)
plottables = [plottables[ix] for ix in order]
hists = [hists[ix] for ix in order]
_chunked_kwargs = [_chunked_kwargs[ix] for ix in order]
_labels = [_labels[ix] for ix in order]

plottables, flow_info = get_plottables(
hists,
bins=final_bins,
w2=w2,
w2method=w2method,
yerr=yerr,
stack=stack,
density=density,
binwnorm=binwnorm,
flow=flow,
)
flow_bins, underflow, overflow = flow_info

##########
# Plotting
return_artists: list[StairsArtists | ErrorBarArtists] = []
Expand All @@ -274,7 +274,7 @@ def iterable_not_string(arg):
elif histtype == "barstep" and len(plottables) == 1:
histtype = "step"

# customize color cycle assignment when stacking to match legend
# # customize color cycle assignment when stacking to match legend
if stack:
plottables = plottables[::-1]
_chunked_kwargs = _chunked_kwargs[::-1]
Expand Down
Binary file added tests/baseline/test_histplot_sort_None.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/baseline/test_histplot_sort_label.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/baseline/test_histplot_sort_label_r.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/baseline/test_histplot_sort_sort5.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/baseline/test_histplot_sort_yield.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/baseline/test_histplot_sort_yield_r.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
23 changes: 23 additions & 0 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,3 +707,26 @@ def test_histplot_inputs_pass(h, yerr, htype):
fig, ax = plt.subplots()
hep.histplot(h, bins, yerr=yerr, histtype=htype)
plt.close(fig)


@pytest.mark.parametrize(
"sort", [None, "label", "label_r", "yield", "yield_r", [0, 2, 1]]
)
@pytest.mark.mpl_image_compare(style="default", remove_text=True)
def test_histplot_sort(sort):
np.random.seed(0)
h = hist.new.Reg(10, 0, 10).StrCat([], growth=True).Weight()
ixs = ["FOO", "BAR", "ZOO"]
for i, ix in enumerate(ixs):
h.fill(np.random.normal(2 + i * 1.5, 3, int(100 + 200 * i)), ix)

fig, ax = plt.subplots()
hep.histplot(
[h[:, ix] for ix in h.axes[1]],
label=h.axes[1],
stack=True,
histtype="fill",
sort=sort,
)
ax.legend()
return fig

0 comments on commit d97d0f2

Please sign in to comment.