Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(types): better typing for view returns #530

Merged
merged 1 commit into from
Mar 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@
#### Typing changes

* Added Ellipsis support to typing. [#525][]
* Better typing for Views. [#530][]

[#526]: https://github.com/scikit-hep/boost-histogram/pull/526
[#529]: https://github.com/scikit-hep/boost-histogram/pull/529
[#530]: https://github.com/scikit-hep/boost-histogram/pull/530

### Version 1.0.0

Expand Down
6 changes: 4 additions & 2 deletions src/boost_histogram/_internal/hist.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from .storage import Double, Storage
from .typing import Accumulator, ArrayLike, CppHistogram, SupportsIndex
from .utils import cast, register, set_module
from .view import View, _to_view
from .view import MeanView, WeightedMeanView, WeightedSumView, _to_view

if TYPE_CHECKING:
from builtins import ellipsis
Expand Down Expand Up @@ -279,7 +279,9 @@ def ndim(self) -> int:
"""
return self._hist.rank() # type: ignore

def view(self, flow: bool = False) -> Union[np.ndarray, View]:
def view(
self, flow: bool = False
) -> Union[np.ndarray, WeightedSumView, WeightedMeanView, MeanView]:
"""
Return a view into the data, optionally with overflow turned on.
"""
Expand Down
27 changes: 21 additions & 6 deletions src/boost_histogram/_internal/view.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Mapping, MutableMapping, Type, Union
from typing import Any, Callable, ClassVar, Mapping, MutableMapping, Tuple, Type, Union

import numpy as np

Expand All @@ -8,6 +8,7 @@

class View(np.ndarray):
__slots__ = ()
_FIELDS: ClassVar[Tuple[str, ...]]

def __getitem__(self, ind: StrIndex) -> np.ndarray:
sliced = super().__getitem__(ind)
Expand All @@ -28,7 +29,7 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}(\n " + repr(self.view(np.ndarray))[6:]

def __str__(self) -> str:
fields = ", ".join(self._FIELDS) # type: ignore
fields = ", ".join(self._FIELDS)
return "{self.__class__.__name__}: ({fields})\n{arr}".format(
self=self, fields=fields, arr=self.view(np.ndarray)
)
Expand All @@ -42,7 +43,7 @@ def __setitem__(self, ind: StrIndex, value: ArrayLike) -> None:
array = np.asarray(value)
if (
array.ndim == super().__getitem__(ind).ndim + 1
and len(self._FIELDS) == array.shape[-1] # type: ignore
and len(self._FIELDS) == array.shape[-1]
):
self.__setitem__(ind, self._PARENT._array(*np.moveaxis(array, -1, 0))) # type: ignore
elif self.dtype == array.dtype:
Expand Down Expand Up @@ -96,6 +97,9 @@ class WeightedSumView(View):
__slots__ = ()
_PARENT = WeightedSum

value: np.ndarray
variance: np.ndarray

# Could be implemented on master View
def __array_ufunc__(
self, ufunc: Ufunc, method: str, *inputs: Any, **kwargs: Any
Expand Down Expand Up @@ -181,7 +185,7 @@ def __array_ufunc__(

# ufuncs that are allowed to reduce
if ufunc in {np.add} and method == "reduce" and len(inputs) == 1:
results = (ufunc.reduce(self[field], **kwargs) for field in self._FIELDS) # type: ignore
results = (ufunc.reduce(self[field], **kwargs) for field in self._FIELDS)
return self._PARENT._make(*results) # type: ignore

# If unsupported, just pass through (will return not implemented)
Expand All @@ -198,6 +202,11 @@ class WeightedMeanView(View):
__slots__ = ()
_PARENT = WeightedMean

sum_of_weights: np.ndarray
sum_of_weights_squared: np.ndarray
value: np.ndarray
_sum_of_weighted_deltas_squared: np.ndarray

@property
def variance(self) -> np.ndarray:
with np.errstate(divide="ignore", invalid="ignore"):
Expand All @@ -212,16 +221,22 @@ class MeanView(View):
__slots__ = ()
_PARENT = Mean

count: np.ndarray
value: np.ndarray
sum_of_deltas_squared: np.ndarray

# Variance is a computation
@property
def variance(self) -> np.ndarray:
with np.errstate(divide="ignore", invalid="ignore"):
return self["sum_of_deltas_squared"] / (self["count"] - 1) # type: ignore


def _to_view(item: np.ndarray, value: bool = False) -> Union[np.ndarray, View]:
def _to_view(
item: np.ndarray, value: bool = False
) -> Union[np.ndarray, WeightedSumView, WeightedMeanView, MeanView]:
for cls in View.__subclasses__():
if cls._FIELDS == item.dtype.names: # type: ignore
if cls._FIELDS == item.dtype.names:
ret = item.view(cls)
if value and ret.shape:
return ret.value # type: ignore
Expand Down