Last active
December 14, 2023 13:11
-
-
Save CrabNejonas/32e7bf4bed2a5644ff0f23b4840d0234 to your computer and use it in GitHub Desktop.
Ports the `intersection` algorithm defined between two `BTreeSet<T>`s to a `BTreeMap<K, V>` and a `BTreeSet<K>`, essentially retaining only the key-value pairs whose keys are part of the set.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::cmp::Ordering; | |
use std::collections::{btree_map, btree_set, BTreeMap, BTreeSet}; | |
use std::iter::FusedIterator; | |
trait IntersectionExt<K, V> { | |
fn pick<'a>(&'a self, other: &'a BTreeSet<K>) -> Pick<'a, K, V> | |
where | |
K: Ord; | |
} | |
struct Pick<'a, K, V> { | |
inner: IntersectionInner<'a, K, V>, | |
} | |
/// Iteration state of a [`Pick`], one variant per traversal strategy.
enum IntersectionInner<'a, K, V> {
    /// The result is already known: at most one entry, handed out once.
    Answer(Option<(&'a K, &'a V)>),
    /// Walk the (much smaller) map and look every key up in the set.
    SearchMap {
        /// Remaining entries of the small map.
        small_iter: btree_map::Iter<'a, K, V>,
        /// Set that each map key is tested against.
        large_set: &'a BTreeSet<K>,
    },
    /// Walk the (much smaller) set and look every key up in the map.
    SearchSet {
        /// Remaining keys of the small set.
        small_iter: btree_set::Iter<'a, K>,
        /// Map that each set key is looked up in.
        large_map: &'a BTreeMap<K, V>,
    },
    /// Advance both ordered iterators in lockstep, yielding entries whose keys match.
    Stitch {
        /// Remaining entries of the map.
        map: btree_map::Iter<'a, K, V>,
        /// Remaining keys of the set.
        set: btree_set::Iter<'a, K>,
    },
}
/// Size ratio at which `pick` stops walking the map and the set in lockstep and instead
/// iterates the smaller collection, looking each key up in the larger one.
///
/// It estimates the relative size at which searching performs better than iterating,
/// based on the benchmarks in
/// <https://github.com/ssomers/rust_bench_btreeset_intersection>.
/// It is used to divide rather than multiply sizes, to rule out overflow,
/// and it is a power of two to make that division cheap.
const ITER_PERFORMANCE_TIPPING_SIZE_DIFF: usize = 16;
impl<K, V> IntersectionExt<K, V> for BTreeMap<K, V> { | |
fn pick<'a>(&'a self, set: &'a BTreeSet<K>) -> Pick<'a, K, V> | |
where | |
K: Ord, | |
{ | |
let (self_min, self_max) = if let (Some(self_min), Some(self_max)) = | |
(self.first_key_value(), self.last_key_value()) | |
{ | |
(self_min, self_max) | |
} else { | |
return Pick { | |
inner: IntersectionInner::Answer(None), | |
}; | |
}; | |
let (set_min, set_max) = if let (Some(set_min), Some(set_max)) = (set.first(), set.last()) { | |
(set_min, set_max) | |
} else { | |
return Pick { | |
inner: IntersectionInner::Answer(None), | |
}; | |
}; | |
Pick { | |
inner: match (self_min.0.cmp(set_max), self_max.0.cmp(set_min)) { | |
(Ordering::Greater, _) | (_, Ordering::Less) => IntersectionInner::Answer(None), | |
(Ordering::Equal, _) => IntersectionInner::Answer(Some(self_min)), | |
(_, Ordering::Equal) => IntersectionInner::Answer(Some(self_max)), | |
_ if self.len() <= set.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF => { | |
IntersectionInner::SearchMap { | |
small_iter: self.iter(), | |
large_set: set, | |
} | |
} | |
_ if set.len() <= self.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF => { | |
IntersectionInner::SearchSet { | |
small_iter: set.iter(), | |
large_map: self, | |
} | |
} | |
_ => IntersectionInner::Stitch { | |
map: self.iter(), | |
set: set.iter(), | |
}, | |
}, | |
} | |
} | |
} | |
impl<'a, K: Ord, V> Iterator for Pick<'a, K, V> { | |
type Item = (&'a K, &'a V); | |
fn next(&mut self) -> Option<Self::Item> { | |
match &mut self.inner { | |
IntersectionInner::Answer(answer) => answer.take(), | |
IntersectionInner::SearchMap { | |
small_iter, | |
large_set, | |
} => loop { | |
let small_next = small_iter.next()?; | |
if large_set.contains(&small_next.0) { | |
return Some(small_next); | |
} | |
}, | |
IntersectionInner::SearchSet { | |
small_iter, | |
large_map, | |
} => loop { | |
let small_next = small_iter.next()?; | |
if let Some(v) = large_map.get(small_next) { | |
return Some((small_next, v)); | |
} | |
}, | |
IntersectionInner::Stitch { map, set } => { | |
let mut map_next = map.next()?; | |
let mut set_next = set.next()?; | |
loop { | |
match map_next.0.cmp(set_next) { | |
Ordering::Less => map_next = map.next()?, | |
Ordering::Greater => set_next = set.next()?, | |
Ordering::Equal => return Some(map_next), | |
} | |
} | |
} | |
} | |
} | |
fn size_hint(&self) -> (usize, Option<usize>) { | |
match &self.inner { | |
IntersectionInner::Stitch { map, set } => { | |
(0, Some(std::cmp::min(map.len(), set.len()))) | |
} | |
IntersectionInner::SearchMap { small_iter, .. } => (0, Some(small_iter.len())), | |
IntersectionInner::SearchSet { small_iter, .. } => (0, Some(small_iter.len())), | |
IntersectionInner::Answer(None) => (0, Some(0)), | |
IntersectionInner::Answer(Some(_)) => (1, Some(1)), | |
} | |
} | |
fn min(mut self) -> Option<(&'a K, &'a V)> { | |
self.next() | |
} | |
} | |
impl<K: Ord, V> FusedIterator for Pick<'_, K, V> {} | |
#[cfg(test)]
mod test {
    use super::*;

    /// Runs `pick` and copies the selected entries out for easy comparison.
    fn picked(map: &BTreeMap<u32, u32>, set: &BTreeSet<u32>) -> Vec<(u32, u32)> {
        map.pick(set).map(|(k, v)| (*k, *v)).collect()
    }

    /// Collections of comparable size (walked in lockstep).
    #[test]
    fn pick() {
        let mut map = BTreeMap::new();
        map.insert("a", "a");
        map.insert("b", "b");
        map.insert("c", "c");
        map.insert("d", "d");

        let mut set = BTreeSet::new();
        set.insert("a");
        set.insert("b");

        let pick = map.pick(&set);
        let pick: Vec<_> = pick.collect();

        assert_eq!(pick.len(), 2);
        assert_eq!(pick[0], (&"a", &"a"));
        assert_eq!(pick[1], (&"b", &"b"));
    }

    /// An empty map or an empty set selects nothing.
    #[test]
    fn pick_with_empty_operand() {
        let map = BTreeMap::from([(1, 10), (2, 20)]);
        let set = BTreeSet::from([1]);

        assert!(picked(&BTreeMap::new(), &set).is_empty());
        assert!(picked(&map, &BTreeSet::new()).is_empty());
    }

    /// Key ranges that do not overlap select nothing, whichever side is lower.
    #[test]
    fn pick_disjoint_ranges() {
        let low_map = BTreeMap::from([(1, 10), (2, 20), (3, 30)]);
        let high_set = BTreeSet::from([10, 11, 12]);
        assert!(picked(&low_map, &high_set).is_empty());

        let high_map = BTreeMap::from([(10, 100), (11, 110)]);
        let low_set = BTreeSet::from([1, 2, 3]);
        assert!(picked(&high_map, &low_set).is_empty());
    }

    /// Ranges that touch in a single boundary key select exactly that entry.
    #[test]
    fn pick_single_shared_boundary() {
        let map = BTreeMap::from([(1, 10), (2, 20), (3, 30)]);
        let set = BTreeSet::from([3, 4, 5]);
        assert_eq!(picked(&map, &set), vec![(3, 30)]);

        let map = BTreeMap::from([(3, 30), (4, 40), (5, 50)]);
        let set = BTreeSet::from([1, 2, 3]);
        assert_eq!(picked(&map, &set), vec![(3, 30)]);
    }

    /// A map far smaller than the set (the set is searched per map key).
    #[test]
    fn pick_small_map_in_large_set() {
        let set: BTreeSet<u32> = (0..64).collect();
        let map = BTreeMap::from([(5, 50), (100, 1000)]);
        assert_eq!(picked(&map, &set), vec![(5, 50)]);
    }

    /// A set far smaller than the map (the map is searched per set key).
    #[test]
    fn pick_small_set_in_large_map() {
        let map: BTreeMap<u32, u32> = (0..64).map(|k| (k, k * 10)).collect();
        let set = BTreeSet::from([7, 200]);
        assert_eq!(picked(&map, &set), vec![(7, 70)]);
    }

    /// Once exhausted, the iterator keeps returning `None`.
    #[test]
    fn pick_stays_exhausted() {
        let map = BTreeMap::from([(1, 10), (2, 20), (3, 30), (4, 40)]);
        let set = BTreeSet::from([2, 3]);

        let mut pick = map.pick(&set);
        assert_eq!(pick.next(), Some((&2, &20)));
        assert_eq!(pick.next(), Some((&3, &30)));
        assert_eq!(pick.next(), None);
        assert_eq!(pick.next(), None);
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment