Last active
October 10, 2023 06:14
-
-
Save altescy/8338a0874b6ce5b78bb65bdc025ceb5f to your computer and use it in GitHub Desktop.
Utility functions for iteration in Python.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import collections | |
import itertools | |
import math | |
from collections import abc | |
from typing import (Any, Callable, Generic, Iterable, Iterator, List, Optional, | |
TypeVar) | |
T = TypeVar("T") | |
class SizedIterator(Generic[T]):
    """
    A wrapper for an iterator that knows its size.

    The wrapper is itself an iterator: ``iter(wrapper)`` returns the wrapper,
    so the known length is not lost when the object is passed through
    ``iter()``. ``len()`` reports the size given at construction time; it is
    not decremented as items are consumed.

    Args:
        iterator: The iterator.
        size: The size of the iterator.
    """

    def __init__(self, iterator: Iterator[T], size: int):
        self.iterator = iterator
        self.size = size

    def __iter__(self) -> Iterator[T]:
        # The iterator protocol requires an iterator's __iter__ to return the
        # iterator object itself, not the object it wraps.
        return self

    def __next__(self) -> T:
        # Delegate to the wrapped iterator; StopIteration propagates unchanged.
        return next(self.iterator)

    def __len__(self) -> int:
        return self.size
def batched(
    iterable: Iterable[T], batch_size: int, drop_last: bool = False
) -> Iterator[List[T]]:
    """
    Batch an iterable into lists of the given size.

    Args:
        iterable: The iterable.
        batch_size: The size of each batch. Must be at least 1.
        drop_last: Whether to drop the last batch if it is smaller than the given size.

    Returns:
        An iterator over batches. If the iterable is sized, the returned
        iterator also supports ``len()`` and reports the number of batches.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    # Validate eagerly: a non-positive size would otherwise either divide by
    # zero below or silently produce a single oversized batch.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    def iterator() -> Iterator[List[T]]:
        batch: List[T] = []
        for item in iterable:
            batch.append(item)
            if len(batch) == batch_size:
                yield batch
                # Start a fresh list so the batch already handed out is not mutated.
                batch = []
        # Emit the trailing partial batch unless the caller asked to drop it.
        if batch and not drop_last:
            yield batch

    if isinstance(iterable, abc.Sized):
        # Integer arithmetic keeps the count exact for arbitrarily large lengths.
        full_batches, remainder = divmod(len(iterable), batch_size)
        if drop_last or remainder == 0:
            num_batches = full_batches
        else:
            num_batches = full_batches + 1
        return SizedIterator(iterator(), num_batches)
    return iterator()
def batched_iterator(iterable: Iterable[T], batch_size: int) -> Iterator[Iterator[T]]:
    """
    Batch an iterable into iterators of the given size.

    Each batch is a lazy iterator drawing from one shared underlying iterator.
    When the next batch is requested, any items left unread in the previous
    batch are consumed (discarded) first, so each batch should be used before
    advancing to the next one.

    Args:
        iterable: The iterable.
        batch_size: The size of each batch. Must be at least 1.

    Returns:
        An iterator over batches. If the iterable is sized, the returned
        iterator also supports ``len()`` and reports the number of batches.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    # Validate eagerly so that sized and unsized inputs fail the same way,
    # instead of ZeroDivisionError here versus a late islice() error there.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    def iterator() -> Iterator[Iterator[T]]:
        source = iter(iterable)
        while True:
            # Pull the first item up front: its absence is the only signal
            # that the source is exhausted and no further batch exists.
            try:
                head = next(source)
            except StopIteration:
                return
            subiterator = itertools.chain(
                [head], itertools.islice(source, batch_size - 1)
            )
            yield subiterator
            # Skip whatever the caller left unread so the next batch starts
            # at the correct offset in the shared source.
            consume(subiterator)

    if isinstance(iterable, abc.Sized):
        # Ceiling division in integer arithmetic (exact for any length).
        num_batches = -(-len(iterable) // batch_size)
        return SizedIterator(iterator(), num_batches)
    return iterator()
def iter_with_callback(
    iterable: Iterable[T],
    callback: Callable[[T], Any],
) -> Iterator[T]:
    """
    Yield every item of an iterable, invoking a callback for each one.

    The callback for an item runs only after that item has been yielded and
    the consumer asks for the next one (or exhausts the iterator).

    Args:
        iterable: The iterable to walk over.
        callback: Called with each item once the consumer has moved past it.

    Returns:
        An iterator over the items of the iterable. If the iterable is sized,
        the returned iterator reports the same length.
    """

    def _yield_then_notify() -> Iterator[T]:
        for element in iterable:
            yield element
            callback(element)

    wrapped = _yield_then_notify()
    if not isinstance(iterable, abc.Sized):
        return wrapped
    return SizedIterator(wrapped, len(iterable))
def consume(iterator: Iterator, n: Optional[int] = None) -> None:
    """
    Move an iterator forward by n items, or exhaust it when n is None.

    Args:
        iterator: The iterator to advance.
        n: How many items to skip. If None, the iterator is drained completely.
    """
    if n is not None:
        # An empty slice starting at position n skips exactly n items;
        # the default of None absorbs the resulting StopIteration.
        next(itertools.islice(iterator, n, n), None)
        return
    # A zero-length deque pulls every item and stores none of them.
    collections.deque(iterator, maxlen=0)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment