Skip to content

Commit

Permalink
fix: pass ax_id to loc_shortcut for determining available keys
Browse files Browse the repository at this point in the history
  • Loading branch information
andrzejnovak committed Feb 18, 2025
1 parent 0d982af commit 8a9566a
Showing 1 changed file with 7 additions and 13 deletions.
20 changes: 7 additions & 13 deletions src/hist/basehist.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ def sort(
# This can only return Self, not float, etc., so we ignore the extra types here
return self[{axis: [bh.loc(x) for x in sorted_cats]}] # type: ignore[dict-item, return-value]

def _convert_index_wildcards(self, x: Any) -> Any:
def _convert_index_wildcards(self, x: Any, ax_id: str | int | None = None) -> Any:
"""
Convert wildcards to available indices before passing to bh.loc
"""
Expand All @@ -361,11 +361,7 @@ def _convert_index_wildcards(self, x: Any) -> Any:
if not all(isinstance(n, str) for n in _x):
return x
if any(any(special in pattern for special in ["*", "?"]) for pattern in _x):
available = list(
itertools.chain.from_iterable(
[n for n in ax if isinstance(n, str)] for ax in self.axes
)
)
available = [n for n in self.axes[ax_id] if isinstance(n, str)]
all_matches = []
for pattern in _x:
all_matches.append(
Expand All @@ -379,11 +375,11 @@ def _convert_index_wildcards(self, x: Any) -> Any:
return matches
return x

def _loc_shortcut(self, x: Any) -> Any:
def _loc_shortcut(self, x: Any, ax_id: str | int | None = None) -> Any:
"""
Convert some specific indices to location.
"""
x = self._convert_index_wildcards(x)
x = self._convert_index_wildcards(x, ax_id)
if isinstance(x, list):
return [self._loc_shortcut(each) for each in x]
if isinstance(x, slice):
Expand Down Expand Up @@ -415,9 +411,7 @@ def _step_shortcut(x: Any) -> Any:
raise ValueError("The imaginary part should be an integer")
return bh.rebin(int(x.imag))

def _index_transform(
self, index: list[IndexingExpr] | IndexingExpr
) -> bh.IndexingExpr:
def _index_transform(self, index: list[IndexingExpr] | IndexingExpr) -> Any:
"""
Auxiliary function for __getitem__ and __setitem__.
"""
Expand All @@ -426,7 +420,7 @@ def _index_transform(
new_indices = {
(
self._name_to_index(k) if isinstance(k, str) else k
): self._loc_shortcut(v)
): self._loc_shortcut(v, k)
for k, v in index.items()
}
if len(new_indices) != len(index):
Expand All @@ -438,7 +432,7 @@ def _index_transform(
if not isinstance(index, tuple):
index = (index,) # type: ignore[assignment]

return tuple(self._loc_shortcut(v) for v in index) # type: ignore[union-attr, union-attr, union-attr, union-attr]
return tuple(self._loc_shortcut(v, i) for i, (v) in enumerate(index)) # type: ignore[arg-type]

def __getitem__( # type: ignore[override]
self, index: IndexingExpr
Expand Down

0 comments on commit 8a9566a

Please sign in to comment.