Skip to content

Instantly share code, notes, and snippets.

@victor-iyi
Created August 25, 2022 19:29
Show Gist options
  • Save victor-iyi/d5c834777036a354e5dca11f67cd99aa to your computer and use it in GitHub Desktop.
Save victor-iyi/d5c834777036a354e5dca11f67cd99aa to your computer and use it in GitHub Desktop.
Implementation of flattening nested iterators
/// Flattens an iterable whose items are themselves iterable.
///
/// The returned [`Flatten`] yields every item of the first inner iterable,
/// then every item of the second, and so on. It can also be consumed from the
/// back when the outer and inner iterators are double-ended.
pub fn flatten<I>(iter: I) -> Flatten<I::IntoIter>
where
    I: IntoIterator,
    I::Item: IntoIterator,
{
    let outer = iter.into_iter();
    Flatten::new(outer)
}
/// Iterator adapter that flattens an iterator of iterables into a single
/// iterator over the inner items.
///
/// Forward iteration ([`Iterator::next`]) drains inner iterators taken from
/// the front of the outer iterator. Backward iteration
/// ([`DoubleEndedIterator::next_back`]) drains inner iterators taken from the
/// back. Once the outer iterator is empty, each end falls back to the inner
/// iterator currently held by the opposite end.
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct Flatten<O>
where
    O: Iterator,
    O::Item: IntoIterator,
{
    /// Outer iterator producing the iterables to flatten.
    outer: O,
    /// Inner iterator currently being drained by `next`.
    front_iter: Option<<O::Item as IntoIterator>::IntoIter>,
    /// Inner iterator currently being drained by `next_back`.
    back_iter: Option<<O::Item as IntoIterator>::IntoIter>,
}

impl<O> Flatten<O>
where
    O: Iterator,
    O::Item: IntoIterator,
{
    /// Creates a new [`Flatten`] wrapping `iter`, with no inner iterator in
    /// progress at either end.
    fn new(iter: O) -> Self {
        Self {
            outer: iter,
            front_iter: None,
            back_iter: None,
        }
    }
}

impl<O> Iterator for Flatten<O>
where
    O: Iterator,
    O::Item: IntoIterator,
{
    type Item = <O::Item as IntoIterator>::Item;

    fn next(&mut self) -> Option<Self::Item> {
        loop {
            if let Some(front_iter) = &mut self.front_iter {
                if let Some(item) = front_iter.next() {
                    return Some(item);
                }
                // The current inner iterator is exhausted; drop it so the
                // next one is pulled from `outer`.
                self.front_iter = None;
            }
            match self.outer.next() {
                Some(next_inner) => self.front_iter = Some(next_inner.into_iter()),
                // `outer` is empty: the only items left live in the inner
                // iterator `next_back` may have started; read it from its front.
                None => return self.back_iter.as_mut()?.next(),
            }
        }
    }

    /// Lower bound: items remaining in the in-progress inner iterators.
    /// Upper bound: known only when `outer` reports no further iterables.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let (front_lo, front_hi) = self
            .front_iter
            .as_ref()
            .map_or((0, Some(0)), |it| it.size_hint());
        let (back_lo, back_hi) = self
            .back_iter
            .as_ref()
            .map_or((0, Some(0)), |it| it.size_hint());
        let lower = front_lo.saturating_add(back_lo);
        match (self.outer.size_hint(), front_hi, back_hi) {
            ((0, Some(0)), Some(front), Some(back)) => (lower, front.checked_add(back)),
            // More inner iterables may follow, each of unknown length.
            _ => (lower, None),
        }
    }
}

impl<O> DoubleEndedIterator for Flatten<O>
where
    O: DoubleEndedIterator,
    O::Item: IntoIterator,
    <O::Item as IntoIterator>::IntoIter: DoubleEndedIterator,
{
    fn next_back(&mut self) -> Option<Self::Item> {
        loop {
            if let Some(back_iter) = &mut self.back_iter {
                if let Some(item) = back_iter.next_back() {
                    return Some(item);
                }
                // The current inner iterator is exhausted; drop it so the
                // previous one is pulled from `outer`.
                self.back_iter = None;
            }
            match self.outer.next_back() {
                Some(next_inner) => self.back_iter = Some(next_inner.into_iter()),
                // `outer` is empty: the remaining items live in the inner
                // iterator `next` started. We are walking backwards, so they
                // must be taken from its *tail* (`next_back`), not its head.
                None => return self.front_iter.as_mut()?.next_back(),
            }
        }
    }
}
/// Extend Iterators to call `.flatten_ext(...)` directly.
trait FlattenExt: Iterator + Sized {
/// Returns a `Flatten` object that flattens a given iterable.
fn flatten_ext(self) -> Flatten<Self>
where
Self::Item: IntoIterator;
}
impl<T> FlattenExt for T
where
T: Iterator,
{
fn flatten_ext(self) -> Flatten<Self>
where
Self::Item: IntoIterator,
{
flatten(self)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty() {
        let nothing = std::iter::empty::<Vec<()>>();
        let total = flatten(nothing).count();
        assert_eq!(total, 0);
    }

    #[test]
    fn empty_wide() {
        let all_empty: Vec<Vec<()>> = vec![Vec::new(), Vec::new(), Vec::new()];
        let total = flatten(all_empty.into_iter()).count();
        assert_eq!(total, 0);
    }

    #[test]
    fn one() {
        let single = std::iter::once(vec!['a']);
        assert_eq!(1, flatten(single).count());
    }

    #[test]
    fn two() {
        let pair = std::iter::once(vec!['a', 'b']);
        assert_eq!(2, flatten(pair).count());
    }

    #[test]
    fn two_wide() {
        let columns = vec![vec!['a'], vec!['b']];
        assert_eq!(2, flatten(columns).count());
    }

    #[test]
    fn reverse() {
        let pair = std::iter::once(vec!['a', 'b']);
        let reversed: Vec<char> = flatten(pair).rev().collect();
        assert_eq!(reversed, vec!['b', 'a']);
    }

    #[test]
    fn reverse_wide() {
        let singles = vec![vec!['a'], vec!['b']];
        let reversed: Vec<char> = flatten(singles).rev().collect();
        assert_eq!(reversed, vec!['b', 'a']);

        let pairs = vec![vec!['a', 'b'], vec!['c', 'd']];
        let reversed: Vec<char> = flatten(pairs).rev().collect();
        assert_eq!(reversed, vec!['d', 'c', 'b', 'a']);
    }

    #[test]
    fn both_ends() {
        let mut chars = flatten(vec![vec!['a', 'b'], vec!['c', 'd']]);
        assert_eq!(chars.next(), Some('a'));
        assert_eq!(chars.next_back(), Some('d'));
        assert_eq!(chars.next(), Some('b'));
        assert_eq!(chars.next_back(), Some('c'));
        assert_eq!(chars.next(), None);
        assert_eq!(chars.next_back(), None);
    }

    #[test]
    fn double_ended() {
        let mut words = flatten(vec![vec!["a1", "a2", "a3"], vec!["b1", "b2", "b3"]]);
        // Alternate ends: each round takes one item from the front, one from the back.
        for &(front, back) in &[("a1", "b3"), ("a2", "b2"), ("a3", "b1")] {
            assert_eq!(words.next(), Some(front));
            assert_eq!(words.next_back(), Some(back));
        }
        assert_eq!(words.next(), None);
        assert_eq!(words.next_back(), None);
    }

    #[test]
    fn inf() {
        // The i-th inner range is 0..i: [], [0], [0, 1], ... so the first three
        // flattened items are 0, 0, 1 even though the outer iterator never ends.
        let head: Vec<i32> = flatten((0..).map(|i| 0..i)).take(3).collect();
        assert_eq!(head, vec![0, 0, 1]);
    }

    #[test]
    fn deep_flatten() {
        let triple_nest = vec![vec![vec![0, 1]]];
        let total = flatten(flatten(triple_nest)).count();
        assert_eq!(total, 2);
    }

    #[test]
    fn flatten_ext() {
        let double_nest = vec![vec![0, 1]];
        let total = double_nest.into_iter().flatten_ext().count();
        assert_eq!(total, 2);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment