Skip to content

Commit

Permalink
Add itertools.batched v3.13 function (#177)
Browse files Browse the repository at this point in the history
* Implement itertools.batched

* Match upstream args/example, add docs, better error tests

---------

Co-authored-by: Stanley Kudrow <[email protected]>
Co-authored-by: Amethyst Reese <[email protected]>
  • Loading branch information
3 people authored Sep 1, 2024
1 parent 538f169 commit ad505f3
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 3 deletions.
1 change: 1 addition & 0 deletions aioitertools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
)
from .itertools import (
accumulate,
batched,
chain,
combinations,
combinations_with_replacement,
Expand Down
30 changes: 27 additions & 3 deletions aioitertools/itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import operator
from typing import Any, AsyncIterator, List, Optional, overload, Tuple

from .builtins import enumerate, iter, list, next, zip
from .builtins import enumerate, iter, list, next, tuple, zip
from .helpers import maybe_await
from .types import (
Accumulator,
Expand Down Expand Up @@ -66,6 +66,30 @@ async def mul(a, b):
yield total


async def batched(
    iterable: AnyIterable[T],
    n: int,
    *,
    strict: bool = False,
) -> AsyncIterator[Tuple[T, ...]]:
    """
    Group items from the given iterable into tuples of ``n`` items each.

    Items are pulled from ``iterable`` in order, ``n`` at a time. Once the
    source runs dry, whatever is left over is yielded as a final, shorter
    tuple. With ``strict=True`` a shorter final tuple is treated as an error
    instead of being yielded.

    Raises ``ValueError`` if ``n`` is less than one, or if ``strict`` is set
    and a batch does not contain exactly ``n`` items. Because this is an
    async generator, both errors surface when the result is iterated, not at
    call time.

    Example::

        async for batch in batched(range(15), 5):
            ...  # (0, 1, 2, 3, 4), (5, 6, 7, 8, 9), (10, 11, 12, 13, 14)

    """
    if n < 1:
        raise ValueError("n must be at least one")

    # Share one iterator across every slice so each batch resumes where the
    # previous one stopped.
    source = iter(iterable)
    while True:
        chunk = await tuple(islice(source, n))
        if not chunk:
            # An empty tuple means the source is exhausted.
            return
        if strict and len(chunk) != n:
            raise ValueError("batched: incomplete batch")
        yield chunk


class Chain:
def __call__(self, *itrs: AnyIterable[T]) -> AsyncIterator[T]:
"""
Expand Down Expand Up @@ -517,7 +541,7 @@ async def gen(k: int, q: asyncio.Queue) -> AsyncIterator[T]:
break
yield value

return tuple(gen(k, q) for k, q in builtins.enumerate(queues))
return builtins.tuple(gen(k, q) for k, q in builtins.enumerate(queues))


async def zip_longest(
Expand Down Expand Up @@ -556,4 +580,4 @@ async def zip_longest(
raise value
if finished >= itr_count:
break
yield tuple(values)
yield builtins.tuple(values)
21 changes: 21 additions & 0 deletions aioitertools/tests/itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,27 @@ async def test_accumulate_empty(self):

self.assertEqual(values, [])

@async_test
async def test_batched(self):
    # Each case is (input iterable, batch size, expected list of batches).
    # The cases cover an empty input, single-item batches, a short trailing
    # batch, a batch size larger than the input, and an async iterable input.
    cases = [
        ([], 1, []),
        ([1, 2, 3], 1, [(1,), (2,), (3,)]),
        ([2, 3, 4], 2, [(2, 3), (4,)]),
        ([5, 6], 3, [(5, 6)]),
        (ait.iter([-2, -1, 0, 1, 2]), 2, [(-2, -1), (0, 1), (2,)]),
    ]
    for source, size, expected in cases:
        collected = []
        async for batch in ait.batched(source, size):
            collected.append(batch)

        self.assertEqual(collected, expected)

@async_test
async def test_batched_errors(self):
    # The generator must be iterated for either error to be raised, so each
    # case drains it inside the assertion context.

    # A batch size below one is rejected.
    with self.assertRaisesRegex(ValueError, "n must be at least one"):
        async for _ in ait.batched([1], 0):
            pass

    # In strict mode, three items in batches of two leave a short final
    # batch, which must raise instead of being yielded.
    with self.assertRaisesRegex(ValueError, "incomplete batch"):
        async for _ in ait.batched([1, 2, 3], 2, strict=True):
            pass

@async_test
async def test_chain_lists(self):
it = ait.chain(slist, srange)
Expand Down

0 comments on commit ad505f3

Please sign in to comment.