Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add types for merge.py #1859

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 61 additions & 36 deletions src/anndata/_core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from functools import partial, reduce, singledispatch
from itertools import repeat
from operator import and_, or_, sub
from typing import Literal, TypeVar
from typing import Generic, Literal, TypeVar
from warnings import warn

import numpy as np
Expand All @@ -36,11 +36,16 @@
from .index import _subset, make_slice

if typing.TYPE_CHECKING:
from collections.abc import Collection, Iterable, Sequence
from typing import Any
from collections.abc import Collection, Iterable, Iterator, Sequence
from typing import Any, Self, TypeGuard

Check warning on line 40 in src/anndata/_core/merge.py

View check run for this annotation

Codecov / codecov/patch

src/anndata/_core/merge.py#L39-L40

Added lines #L39 - L40 were not covered by tests

from pandas.api.extensions import ExtensionDtype

from anndata._core.aligned_mapping import AlignedMappingBase

Check warning on line 44 in src/anndata/_core/merge.py

View check run for this annotation

Codecov / codecov/patch

src/anndata/_core/merge.py#L44

Added line #L44 was not covered by tests

_Array = SpArray | sparse.csr_matrix | sparse.csc_matrix | np.ndarray | DaskArray

Check warning on line 46 in src/anndata/_core/merge.py

View check run for this annotation

Codecov / codecov/patch

src/anndata/_core/merge.py#L46

Added line #L46 was not covered by tests

K = TypeVar("K")
T = TypeVar("T")

###################
Expand All @@ -49,40 +54,40 @@


# Pretty much just for maintaining order of keys
class OrderedSet(MutableSet):
def __init__(self, vals=()):
class OrderedSet(MutableSet, Generic[T]):
def __init__(self, vals: Iterable[T] = ()) -> None:
self.dict = OrderedDict(zip(vals, repeat(None)))

def __contains__(self, val):
def __contains__(self, val: object) -> bool:
return val in self.dict

def __iter__(self):
def __iter__(self) -> Iterator[T]:
return iter(self.dict)

def __len__(self):
def __len__(self) -> int:
return len(self.dict)

def __repr__(self):
def __repr__(self) -> str:
return "OrderedSet: {" + ", ".join(map(str, self)) + "}"

def copy(self):
return OrderedSet(self.dict.copy())
def copy(self) -> Self:
return type(self)(self.dict.copy())

Check warning on line 74 in src/anndata/_core/merge.py

View check run for this annotation

Codecov / codecov/patch

src/anndata/_core/merge.py#L74

Added line #L74 was not covered by tests

def add(self, val):
def add(self, val: T) -> None:
self.dict[val] = None

def union(self, *vals) -> OrderedSet:
def union(self, *vals: Iterable[T]) -> Self:
return reduce(or_, vals, self)

def discard(self, val):
def discard(self, val: T) -> None:
if val in self:
del self.dict[val]

def difference(self, *vals) -> OrderedSet:
def difference(self, *vals: Iterable[T]) -> Self:
return reduce(sub, vals, self)


def union_keys(ds: Collection) -> OrderedSet:
def union_keys(ds: Collection[Iterable[T]]) -> OrderedSet[T]:
return reduce(or_, ds, OrderedSet())


Expand All @@ -94,11 +99,11 @@
"""Represents a missing value."""


def is_missing(v) -> bool:
def is_missing(v: object | MissingVal) -> TypeGuard[MissingVal]:
return v is MissingVal


def not_missing(v) -> bool:
def not_missing(v: T | MissingVal) -> TypeGuard[T]:
return v is not MissingVal


Expand Down Expand Up @@ -327,7 +332,7 @@


# TODO: open PR or feature request to cupy
def _cp_block_diag(mats, format=None, dtype=None):
def _cp_block_diag(mats: Iterable[CupyArray], format=None, dtype=None):
"""
Modified version of scipy.sparse.block_diag for cupy sparse.
"""
Expand Down Expand Up @@ -363,7 +368,7 @@
).asformat(format)


def _dask_block_diag(mats):
def _dask_block_diag(mats: list[DaskArray]) -> DaskArray:
from itertools import permutations

import dask.array as da
Expand Down Expand Up @@ -511,7 +516,7 @@
Together with `old_pos` this forms a mapping.
"""

def __init__(self, old_idx, new_idx):
def __init__(self, old_idx: pd.Index, new_idx: pd.Index):
self.old_idx = old_idx
self.new_idx = new_idx
self.no_change = new_idx.equals(old_idx)
Expand All @@ -524,10 +529,14 @@
self.new_pos = new_pos[mask]
self.old_pos = old_pos[mask]

def __call__(self, el, *, axis=1, fill_value=None):
def __call__(
self, el: _Array, *, axis: Literal[0, 1] = 1, fill_value: object | None = None
) -> _Array:
return self.apply(el, axis=axis, fill_value=fill_value)

def apply(self, el, *, axis, fill_value=None):
def apply(
self, el: _Array, *, axis: Literal[0, 1], fill_value: object | None = None
) -> _Array:
"""
Reindex element so el[axis] is aligned to self.new_idx.

Expand Down Expand Up @@ -724,7 +733,7 @@
raise ValueError(msg)


def default_fill_value(els):
def default_fill_value(els: Iterable[_Array]) -> int | float:
"""Given some arrays, returns what the default fill value should be.

This is largely due to backwards compat, and might not be the ideal solution.
Expand All @@ -742,7 +751,7 @@
return np.nan


def gen_reindexer(new_var: pd.Index, cur_var: pd.Index):
def gen_reindexer(new_var: pd.Index, cur_var: pd.Index) -> Reindexer:
"""
Given a new set of var_names, and a current set, generates a function which will reindex
a matrix to be aligned with the new set.
Expand All @@ -763,14 +772,20 @@
return Reindexer(cur_var, new_var)


def np_bool_to_pd_bool_array(df: pd.DataFrame):
def np_bool_to_pd_bool_array(df: pd.DataFrame) -> pd.DataFrame:
for col_name, col_type in dict(df.dtypes).items():
if col_type is np.dtype(bool):
df[col_name] = pd.array(df[col_name].values)
return df


def concat_arrays(arrays, reindexers, axis=0, index=None, fill_value=None):
def concat_arrays(
arrays: Iterable[_Array],
reindexers: Sequence[Reindexer] | Sequence[Callable[[_Array], _Array]],
axis: Literal[0, 1] = 0,
index: pd.Index | None = None,
fill_value: object | None = None,
):
arrays = list(arrays)
if fill_value is None:
fill_value = default_fill_value(arrays)
Expand Down Expand Up @@ -897,7 +912,7 @@
return reindexers


def gen_outer_reindexers(els, shapes, new_index: pd.Index, *, axis=0):
def gen_outer_reindexers(els, shapes) -> list[Reindexer] | list[Callable[[T], T]]:
if all(isinstance(el, pd.DataFrame) for el in els if not_missing(el)):
reindexers = [
(lambda x: x)
Expand Down Expand Up @@ -941,7 +956,7 @@

def missing_element(
n: int,
els: list[SpArray | sparse.csr_matrix | sparse.csc_matrix | np.ndarray | DaskArray],
els: list[_Array],
axis: Literal[0, 1] = 0,
fill_value: Any | None = None,
off_axis_size: int = 0,
Expand All @@ -960,15 +975,20 @@


def outer_concat_aligned_mapping(
mappings, *, reindexers=None, index=None, axis=0, fill_value=None
):
result = {}
mappings: Collection[AlignedMappingBase],
*,
reindexers: Sequence[Reindexer] | Sequence[Callable[[T], T]] | None = None,
index: pd.Index | None = None,
axis: Literal[0, 1] = 0,
fill_value: object | None = None,
) -> dict[str, _Array]:
result: dict[str, _Array] = {}
ns = [m.parent.shape[axis] for m in mappings]

for k in union_keys(mappings):
els = [m.get(k, MissingVal) for m in mappings]
if reindexers is None:
cur_reindexers = gen_outer_reindexers(els, ns, new_index=index, axis=axis)
cur_reindexers = gen_outer_reindexers(els, ns)
else:
cur_reindexers = reindexers

Expand Down Expand Up @@ -1004,8 +1024,8 @@

def concat_pairwise_mapping(
mappings: Collection[Mapping], shapes: Collection[int], join_keys=intersect_keys
):
result = {}
) -> dict[str, _Array]:
result: dict[str, _Array] = {}
if any(any(isinstance(v, SpArray) for v in m.values()) for m in mappings):
sparse_class = sparse.csr_array
else:
Expand Down Expand Up @@ -1067,7 +1087,12 @@


# TODO: Resolve https://github.com/scverse/anndata/issues/678 and remove this function
def concat_Xs(adatas, reindexers, axis, fill_value):
def concat_Xs(
adatas: Iterable[AnnData],
reindexers: Sequence[Reindexer] | Sequence[Callable[[_Array], _Array]],
axis: Literal[0, 1],
fill_value: object | None,
):
"""
Shim until support for some missing X's is implemented.

Expand Down
Loading