Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow wildcards in StrCat list indexing #601

Merged
merged 3 commits into from
Feb 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 34 additions & 7 deletions src/hist/basehist.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import fnmatch
import functools
import itertools
import operator
import typing
import warnings
Expand Down Expand Up @@ -346,11 +348,38 @@ def sort(
# This can only return Self, not float, etc., so we ignore the extra types here
return self[{axis: [bh.loc(x) for x in sorted_cats]}] # type: ignore[dict-item, return-value]

def _loc_shortcut(self, x: Any) -> Any:
def _convert_index_wildcards(self, x: Any, ax_id: str | int | None = None) -> Any:
"""
Convert some specific indices to location.
Convert wildcards to available indices before passing to bh.loc
"""

if not any(
isinstance(x, t) for t in [str, list]
): # Process only lists and strings
return x
_x = x if isinstance(x, list) else [x] # Convert to list if not already
if not all(isinstance(n, str) for n in _x):
return x
if any(any(special in pattern for special in ["*", "?"]) for pattern in _x):
available = [n for n in self.axes[ax_id] if isinstance(n, str)]
all_matches = []
for pattern in _x:
all_matches.append(
[k for k in available if fnmatch.fnmatch(k, pattern)]
)
matches = list(
dict.fromkeys(list(itertools.chain.from_iterable(all_matches)))
)
if len(matches) == 0:
raise ValueError(f"No matches found for {x}")
return matches
return x

def _loc_shortcut(self, x: Any, ax_id: str | int | None = None) -> Any:
"""
Convert some specific indices to location.
"""
x = self._convert_index_wildcards(x, ax_id)
if isinstance(x, list):
return [self._loc_shortcut(each) for each in x]
if isinstance(x, slice):
Expand Down Expand Up @@ -382,9 +411,7 @@ def _step_shortcut(x: Any) -> Any:
raise ValueError("The imaginary part should be an integer")
return bh.rebin(int(x.imag))

def _index_transform(
self, index: list[IndexingExpr] | IndexingExpr
) -> bh.IndexingExpr:
def _index_transform(self, index: list[IndexingExpr] | IndexingExpr) -> Any:
"""
Auxiliary function for __getitem__ and __setitem__.
"""
Expand All @@ -393,7 +420,7 @@ def _index_transform(
new_indices = {
(
self._name_to_index(k) if isinstance(k, str) else k
): self._loc_shortcut(v)
): self._loc_shortcut(v, k)
for k, v in index.items()
}
if len(new_indices) != len(index):
Expand All @@ -405,7 +432,7 @@ def _index_transform(
if not isinstance(index, tuple):
index = (index,) # type: ignore[assignment]

return tuple(self._loc_shortcut(v) for v in index) # type: ignore[union-attr, union-attr, union-attr, union-attr]
return tuple(self._loc_shortcut(v, i) for i, (v) in enumerate(index)) # type: ignore[arg-type]

def __getitem__( # type: ignore[override]
self, index: IndexingExpr
Expand Down
8 changes: 8 additions & 0 deletions tests/test_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -932,6 +932,14 @@ def test_select_by_index_imag():
assert tuple(h[[8j, 7j]].axes[0]) == (8, 7)


@pytest.mark.filterwarnings("ignore:List indexing selection is experimental")
def test_select_by_index_wildcards():
    """Wildcard selectors on the StrCat axis pick out the expected categories."""
    h = hist.new.Reg(10, 0, 10).StrCat(["ABC", "BCD", "CDE", "DEF"]).Weight()

    # (selector applied to axis 1, categories expected to remain on axis 1)
    cases = [
        ("*E*", ("CDE", "DEF")),
        (["*B*", "CDE"], ("ABC", "BCD", "CDE")),
        (["*B*", "?D?"], ("ABC", "BCD", "CDE")),
    ]
    for selector, expected in cases:
        selected = h[:, selector]
        assert tuple(selected.axes[1]) == expected


def test_sorted_simple():
h = Hist.new.IntCat([4, 1, 2]).StrCat(["AB", "BCC", "BC"]).Double()
assert tuple(h.sort(0).axes[0]) == (1, 2, 4)
Expand Down
Loading