Created
February 25, 2020 13:57
-
-
Save AngelicosPhosphoros/93ecc4e88771de874b9909ce73730383 to your computer and use it in GitHub Desktop.
Rust circular move using unsafe Hole
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::cmp::{Ord, Ordering}; | |
use std::collections::HashMap; | |
use std::fmt::Debug; | |
use std::vec::Vec; | |
/// A single slot of the heap: a caller-supplied key together with the
/// priority that determines where the slot sits in the heap.
struct HeapEntry<TKey, TPriority> {
    /// Identifier handed back to the change handler whenever the slot moves.
    key: TKey,
    /// Value the heap is ordered by.
    priority: TPriority,
}
/// A temporarily vacated slot inside a mutable slice.
///
/// Creating a `Hole` bitwise-moves one element out of the slice and remembers
/// where it came from. Other elements can then be shifted into the vacated
/// slot with [`Hole::move_to`], which costs one copy per step instead of a
/// full swap. When the `Hole` is dropped the element that was taken out is
/// written back into the slot that is vacant at that moment, so the slice is
/// fully initialised again (this also happens during unwinding).
struct Hole<'a, T: 'a> {
    /// The slice being rearranged; the slot at `position` holds stale bits.
    data: &'a mut [T],
    /// The element taken out of the slice. It is `ManuallyDrop` because its
    /// ownership goes back to the slice in `Drop` and it must not be dropped
    /// a second time here.
    taken: std::mem::ManuallyDrop<T>,
    /// Index of the currently vacant slot.
    position: usize,
}

impl<'a, T> Hole<'a, T> {
    /// Creates a new hole at `position`, taking the element out of the slice.
    ///
    /// # Safety
    ///
    /// `position` must be strictly less than `data.len()`.
    #[inline(always)]
    unsafe fn new(data: &'a mut [T], position: usize) -> Self {
        debug_assert!(position < data.len());
        // SAFETY: the caller guarantees `position` is in bounds. The value is
        // read out bitwise and the slot is treated as vacant until `Drop`
        // writes the value back, so the element is never duplicated.
        let taken = std::mem::ManuallyDrop::new(unsafe {
            std::ptr::read(data.get_unchecked(position))
        });
        Self {
            data,
            taken,
            position,
        }
    }

    /// Returns the element that was taken out of the slice.
    #[inline(always)]
    fn element(&self) -> &T {
        &self.taken
    }

    /// Returns the index of the currently vacant slot.
    #[inline(always)]
    fn position(&self) -> usize {
        self.position
    }

    /// Returns the length of the underlying slice (the vacant slot included).
    #[inline(always)]
    fn data_len(&self) -> usize {
        self.data.len()
    }

    /// Returns a reference to the element stored at `index`.
    ///
    /// # Safety
    ///
    /// `index` must be in bounds and must not be the vacant slot.
    #[inline(always)]
    unsafe fn get(&self, index: usize) -> &T {
        debug_assert!(index != self.position);
        debug_assert!(index < self.data.len());
        // SAFETY: the caller guarantees `index` is in bounds and refers to an
        // initialised slot.
        unsafe { self.data.get_unchecked(index) }
    }

    /// Moves the element at `new_position` into the vacant slot; afterwards
    /// `new_position` is the vacant slot.
    ///
    /// # Safety
    ///
    /// `new_position` must be strictly less than the slice length.
    #[inline(always)]
    unsafe fn move_to(&mut self, new_position: usize) {
        debug_assert!(new_position < self.data.len());
        // Both pointers are derived from one mutable base pointer. Taking a
        // shared reference for the source and then a mutable reference for
        // the destination would invalidate the source pointer before it is
        // read.
        let base: *mut T = self.data.as_mut_ptr();
        // SAFETY: `new_position` is in bounds per the caller's contract and
        // `self.position` is in bounds by construction. `ptr::copy` permits
        // source and destination to coincide.
        unsafe {
            let source: *const T = base.add(new_position);
            let vacated: *mut T = base.add(self.position);
            std::ptr::copy(source, vacated, 1);
        }
        self.position = new_position;
    }
}

impl<'a, T> Drop for Hole<'a, T> {
    /// Writes the taken element back into the vacant slot.
    #[inline(always)]
    fn drop(&mut self) {
        let pos = self.position;
        // SAFETY: `pos` is always an in-bounds index of `data`, and `taken`
        // lives inside the `Hole`, so it cannot overlap the slice.
        unsafe {
            std::ptr::copy_nonoverlapping(&*self.taken, self.data.get_unchecked_mut(pos), 1);
        }
    }
}
pub(crate) struct BinaryHeap<TKey, TPriority> | |
where | |
TPriority: Ord, | |
{ | |
data: Vec<HeapEntry<TKey, TPriority>>, | |
} | |
impl<TKey, TPriority: Ord> BinaryHeap<TKey, TPriority> { | |
pub(crate) fn new() -> Self { | |
Self { data: Vec::new() } | |
} | |
pub(crate) fn with_capacity(capacity: usize) -> Self { | |
Self { | |
data: Vec::with_capacity(capacity), | |
} | |
} | |
#[inline(always)] | |
pub fn reserve(&mut self, additional: usize) { | |
self.data.reserve(additional); | |
} | |
/// Puts key and priority in queue, returns its final position | |
/// Calls change_handler for every move of old values | |
#[inline(always)] | |
pub(crate) fn push<TChangeHandler: std::ops::FnMut(&TKey, usize)>( | |
&mut self, | |
key: TKey, | |
priority: TPriority, | |
change_handler: TChangeHandler, | |
) { | |
self.data.push(HeapEntry { key, priority }); | |
self.heapify_up(self.data.len() - 1, change_handler); | |
} | |
/// Removes item with the biggest priority | |
/// Time complexity - O(log n) swaps and change_handler calls | |
#[inline(always)] | |
pub(crate) fn pop<TChangeHandler: std::ops::FnMut(&TKey, usize)>( | |
&mut self, | |
change_handler: TChangeHandler, | |
) -> Option<(TKey, TPriority)> { | |
self.remove(0, change_handler) | |
} | |
#[inline(always)] | |
pub(crate) fn peek(&self) -> Option<(&TKey, &TPriority)> { | |
self.look_into(0) | |
} | |
/// Removes item at position and returns it | |
/// Time complexity - O(log n) swaps and change_handler calls | |
pub(crate) fn remove<TChangeHandler: std::ops::FnMut(&TKey, usize)>( | |
&mut self, | |
position: usize, | |
change_handler: TChangeHandler, | |
) -> Option<(TKey, TPriority)> { | |
if self.data.len() <= position { | |
return None; | |
} | |
if position == self.data.len() - 1 { | |
let result = self.data.pop().unwrap(); | |
return Some((result.key, result.priority)); | |
} | |
self.swap_items(position, self.data.len() - 1); | |
let result = self.data.pop().unwrap(); | |
self.heapify_down(position, change_handler); | |
Some((result.key, result.priority)) | |
} | |
#[inline(always)] | |
pub(crate) fn look_into(&self, position: usize) -> Option<(&TKey, &TPriority)> { | |
let entry = self.data.get(position)?; | |
Some((&entry.key, &entry.priority)) | |
} | |
/// Changes priority of queue item | |
pub(crate) fn change_priority<TChangeHandler: std::ops::FnMut(&TKey, usize)>( | |
&mut self, | |
position: usize, | |
updated: TPriority, | |
change_handler: TChangeHandler, | |
) { | |
if position >= self.data.len() { | |
panic!("Out of index during changing priority"); | |
} | |
let old = std::mem::replace(&mut self.data[position].priority, updated); | |
match old.cmp(&self.data[position].priority) { | |
Ordering::Less => { | |
self.heapify_up(position, change_handler); | |
} | |
Ordering::Equal => {} | |
Ordering::Greater => { | |
self.heapify_down(position, change_handler); | |
} | |
} | |
} | |
#[inline(always)] | |
pub(crate) fn len(&self) -> usize { | |
self.data.len() | |
} | |
#[inline(always)] | |
pub(crate) fn is_empty(&self) -> bool { | |
self.data.is_empty() | |
} | |
#[inline(always)] | |
pub(crate) fn clear(&mut self) { | |
self.data.clear() | |
} | |
fn heapify_up<TChangeHandler: std::ops::FnMut(&TKey, usize)>( | |
&mut self, | |
position: usize, | |
mut change_handler: TChangeHandler, | |
) { | |
debug_assert!(position < self.data.len(), "Out of index in heapify_up"); | |
if position == 0 { | |
change_handler(&self.data[position].key, position); | |
return; | |
} | |
let final_position = unsafe { | |
if position >= self.data.len() { | |
panic!( | |
"Call of heapify up with data len {} and position {}", | |
self.data.len(), | |
position | |
); | |
} | |
let mut hole = Hole::new(&mut self.data[..], position); | |
while hole.position() > 0 { | |
let old_pos = hole.position(); | |
let parent_pos = (hole.position() - 1) / 2; | |
if hole.get(parent_pos).priority < hole.element().priority { | |
hole.move_to(parent_pos); | |
change_handler(&hole.get(old_pos).key, old_pos); | |
} else { | |
break; | |
} | |
} | |
hole.position() | |
}; | |
change_handler(&self.data[final_position].key, final_position); | |
} | |
fn heapify_down<TChangeHandler: std::ops::FnMut(&TKey, usize)>( | |
&mut self, | |
position: usize, | |
mut change_handler: TChangeHandler, | |
) { | |
if position + 1 == self.data.len() { | |
change_handler(&self.data[position].key, position); | |
return; | |
} | |
let final_position = unsafe { | |
if position >= self.data.len() { | |
panic!( | |
"Call of heapify down with data len {} and position {}", | |
self.data.len(), | |
position | |
); | |
} | |
let mut hole = Hole::new(&mut self.data[..], position); | |
loop { | |
let max_child_idx = { | |
let child1 = hole.position() * 2 + 1; | |
let child2 = child1 + 1; | |
if child1 >= hole.data_len() { | |
break; | |
} | |
if child2 >= hole.data_len() | |
|| hole.get(child2).priority < hole.get(child1).priority | |
{ | |
child1 | |
} else { | |
child2 | |
} | |
}; | |
if hole.element().priority >= hole.get(max_child_idx).priority { | |
break; | |
} | |
let old_pos = hole.position(); | |
hole.move_to(max_child_idx); | |
change_handler(&hole.get(old_pos).key, old_pos); | |
} | |
hole.position() | |
}; | |
change_handler(&self.data[final_position].key, final_position); | |
} | |
#[inline(always)] | |
fn swap_items(&mut self, pos1: usize, pos2: usize) { | |
debug_assert!(pos1 < self.data.len(), "Out of index in first pos in swap"); | |
debug_assert!(pos2 < self.data.len(), "Out of index in second pos in swap"); | |
self.data.swap(pos1, pos2); | |
} | |
} | |
impl<TKey: std::hash::Hash + Clone + Eq, TPriority: Ord> BinaryHeap<TKey, TPriority> { | |
pub(crate) fn build_from_iterator<T: Iterator<Item = (TKey, TPriority)>>( | |
iter: T, | |
) -> (Self, HashMap<TKey, usize>) { | |
let minimal_size = iter.size_hint().0; | |
let mut vec: Vec<HeapEntry<TKey, TPriority>> = if minimal_size > 0 { | |
Vec::with_capacity(minimal_size) | |
} else { | |
Vec::new() | |
}; | |
let mut map: HashMap<TKey, usize> = if minimal_size > 0 { | |
HashMap::with_capacity(minimal_size) | |
} else { | |
HashMap::new() | |
}; | |
for (key, priority) in iter { | |
if let Some(&pos) = map.get(&key) { | |
vec[pos].priority = priority; | |
} else { | |
map.insert(key.clone(), vec.len()); | |
vec.push(HeapEntry { key, priority }); | |
} | |
} | |
let mut res = Self { data: vec }; | |
let heapify_start = std::cmp::min(res.data.len() / 2 + 2, res.data.len()); | |
for i in (0..heapify_start).rev() { | |
res.heapify_down(i, |_, _| {}); | |
} | |
for (i, entry) in res.data.iter().enumerate() { | |
*map.get_mut(&entry.key).unwrap() = i; | |
} | |
(res, map) | |
} | |
} | |
// Default implementations | |
impl<TKey: Clone, TPriority: Clone> Clone for HeapEntry<TKey, TPriority> { | |
fn clone(&self) -> Self { | |
Self { | |
key: self.key.clone(), | |
priority: self.priority.clone(), | |
} | |
} | |
} | |
impl<TKey: Copy, TPriority: Copy> Copy for HeapEntry<TKey, TPriority> {} | |
impl<TKey: Debug, TPriority: Debug> Debug for HeapEntry<TKey, TPriority> { | |
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { | |
write!( | |
f, | |
"{{key: {:?}, priority: {:?}}}", | |
&self.key, &self.priority | |
) | |
} | |
} | |
impl<TKey: Clone, TPriority: Clone + Ord> Clone for BinaryHeap<TKey, TPriority> { | |
fn clone(&self) -> Self { | |
Self { | |
data: self.data.clone(), | |
} | |
} | |
} | |
impl<TKey: Debug, TPriority: Debug + Ord> Debug for BinaryHeap<TKey, TPriority> { | |
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { | |
self.data.fmt(f) | |
} | |
} | |
#[cfg(test)]
mod tests {
    use super::*;
    use std::cmp::Reverse;
    use std::collections::{HashMap, HashSet};

    /// Returns `true` when no entry has a greater priority than its parent.
    fn is_valid_heap<TK, TP: Ord>(heap: &BinaryHeap<TK, TP>) -> bool {
        for (i, current) in heap.data.iter().enumerate().skip(1) {
            let parent = &heap.data[(i - 1) / 2];
            if parent.priority < current.priority {
                return false;
            }
        }
        true
    }

    /// Pushing keeps the heap valid and `peek` always shows the maximum.
    #[test]
    fn test_heap_fill() {
        let items = [
            70, 50, 0, 1, 2, 4, 6, 7, 9, 72, 4, 4, 87, 78, 72, 6, 7, 9, 2, -50, -72, -50, -42, -1,
            -3, -13,
        ];
        let mut maximum = std::i32::MIN;
        let mut heap = BinaryHeap::<(), i32>::new();
        assert!(heap.peek().is_none());
        assert!(is_valid_heap(&heap), "Heap state is invalid");
        for &x in items.iter() {
            if x > maximum {
                maximum = x;
            }
            heap.push((), x, |_, _| {});
            assert!(
                is_valid_heap(&heap),
                "Heap state is invalid after pushing {}",
                x
            );
            assert!(heap.peek().is_some());
            let (_, &heap_max) = heap.peek().unwrap();
            assert_eq!(maximum, heap_max)
        }
    }

    /// The change handler is told the current index of every key, both while
    /// pushing and while popping.
    #[test]
    fn test_change_logger() {
        let items = [
            2, 3, 21, 22, 25, 29, 36, -90, -89, -88, -87, -83, 48, 50, 52, -69, -65, -55, 73, 75,
            76, -53, 78, 81, -45, -41, 91, -34, -33, -31, -27, -22, -19, -8, -5, -3,
        ];
        let mut last_positions = HashMap::<i32, usize>::new();
        let mut heap = BinaryHeap::<i32, i32>::new();
        // NOTE(review): the handlers below read the heap through this raw
        // pointer while the heap is mutably borrowed by `push`/`pop`. This
        // sidesteps the borrow checker and is kept only to allow in-flight
        // assertions; it is not a pattern to copy.
        let heap_ptr: *const BinaryHeap<i32, i32> = &heap;
        let mut on_pos_change = |key: &i32, position: usize| {
            // Hack to avoid borrow checker
            let heap_local = unsafe { &*heap_ptr };
            assert_eq!(*heap_local.look_into(position).unwrap().0, *key);
            // Keys equal priorities in this test.
            assert_eq!(
                heap_local.look_into(position).unwrap().0,
                heap_local.look_into(position).unwrap().1
            );
            last_positions.insert(*key, position);
        };
        for &x in items.iter() {
            heap.push(x, x, &mut on_pos_change);
        }
        for &x in items.iter() {
            assert!(
                last_positions.contains_key(&x),
                "Not for all items change_handler called"
            );
            let position = last_positions[&x];
            assert_eq!(
                heap.look_into(position).unwrap().0,
                heap.look_into(position).unwrap().1
            );
            assert_eq!(*heap.look_into(position).unwrap().0, x);
        }
        let mut removed = HashSet::<i32>::new();
        loop {
            let mut on_pos_change = |key: &i32, position: usize| {
                // Hack to avoid borrow checker
                let heap_local = unsafe { &*heap_ptr };
                assert_eq!(*heap_local.look_into(position).unwrap().0, *key);
                assert_eq!(
                    heap_local.look_into(position).unwrap().0,
                    heap_local.look_into(position).unwrap().1
                );
                last_positions.insert(*key, position);
            };
            let (key, _) = match heap.pop(&mut on_pos_change) {
                Some(popped) => popped,
                None => break,
            };
            last_positions.remove(&key);
            removed.insert(key);
            // Every key still in the heap must be tracked at its real index.
            for x in items.iter().cloned().filter(|x| !removed.contains(x)) {
                assert!(
                    last_positions.contains_key(&x),
                    "Not for all items change_handler called"
                );
                let position = last_positions[&x];
                assert_eq!(
                    heap.look_into(position).unwrap().0,
                    heap.look_into(position).unwrap().1
                );
                assert_eq!(*heap.look_into(position).unwrap().0, x);
            }
        }
    }

    /// Popping yields the entries in descending priority order.
    #[test]
    fn test_pop() {
        let mut items = [
            -16, 5, 11, -1, -34, -42, -5, -6, 25, -35, 11, 35, -2, 40, 42, 40, -45, -48, 48, -38,
            -28, -33, -31, 34, -18, 25, 16, -33, -11, -6, -35, -38, 35, -41, -38, 31, -38, -23, 26,
            44, 38, 11, -49, 30, 7, 13, 12, -4, -11, -24, -49, 26, 42, 46, -25, -22, -6, -42, 28,
            45, -47, 8, 8, 21, 49, -12, -5, -33, -37, 24, -3, -26, 6, -13, 16, -40, -14, -39, -26,
            12, -44, 47, 45, -41, -22, -11, 20, 43, -44, 24, 47, 40, 43, 9, 19, 12, -17, 30, -36,
            -50, 24, -2, 1, 1, 5, -19, 21, -38, 47, 34, -14, 12, -30, 24, -2, -32, -10, 40, 34, 2,
            -33, 9, -31, -3, -15, 28, 50, -37, 35, 19, 35, 13, -2, 46, 28, 35, -40, -19, -1, -33,
            -42, -35, -12, 19, 29, 10, -31, -4, -9, 24, 15, -27, 13, 20, 15, 19, -40, -41, 40, -25,
            45, -11, -7, -19, 11, -44, -37, 35, 2, -49, 11, -37, -14, 13, 41, 10, 3, 19, -32, -12,
            -12, 33, -26, -49, -45, 24, 47, -29, -25, -45, -36, 40, 24, -29, 15, 36, 0, 47, 3, -45,
        ];
        let mut heap = BinaryHeap::<i32, i32>::new();
        for &x in items.iter() {
            heap.push(x, x, |_, _| {});
        }
        assert!(is_valid_heap(&heap), "Heap is invalid before pops");
        items.sort_unstable_by_key(|&x| Reverse(x));
        for &x in items.iter() {
            assert_eq!(heap.pop(|_, _| {}), Some((x, x)));
            assert!(is_valid_heap(&heap), "Heap is invalid after {}", x);
        }
        assert_eq!(heap.pop(|_, _| {}), None);
    }

    /// Raising and lowering a priority both leave a valid heap.
    #[test]
    fn test_change_priority() {
        let pairs = [
            ("first", 0),
            ("second", 1),
            ("third", 2),
            ("fourth", 3),
            ("fifth", 4),
        ];
        let mut heap = BinaryHeap::new();
        for (key, priority) in pairs.iter().cloned() {
            heap.push(key, priority, |_, _| {});
        }
        assert!(is_valid_heap(&heap), "Invalid before change");
        heap.change_priority(3, 10, |_, _| {});
        assert!(is_valid_heap(&heap), "Invalid after upping");
        heap.change_priority(2, -10, |_, _| {});
        assert!(is_valid_heap(&heap), "Invalid after lowering");
    }

    /// Bulk construction yields a valid heap and an accurate index map for
    /// every prefix of the data (which contains duplicate keys).
    #[test]
    fn build_from_iterator() {
        let data = [
            16, 16, 5, 20, 10, 12, 10, 8, 12, 2, 20, -1, -18, 5, -16, 1, 7, 3, 17, -20, -4, 3, -7,
            -5, -8, 19, -19, -16, 3, 4, 17, 13, 3, 11, -9, 0, -10, -2, 16, 19, -12, -4, 19, 7, 16,
            -19, -9, -17, 6, -16, -3, 11, -14, -15, -10, 13, 11, -14, 18, -8, -9, -4, 5, -4, 17, 6,
            -16, -5, 12, 12, -3, 8, 5, -4, 7, 10, 7, -11, 18, -16, 18, 4, -15, -4, -13, 7, -14,
            -16, -18, -10, 13, -1, -9, 0, -18, -4, -13, 16, 10, -20, 19, 20, 0, -9, -7, 14, 19, -8,
            -18, -1, -17, -11, 13, 12, -15, 0, -18, 6, -13, -17, -3, 18, 2, 12, 12, 4, -14, -11,
            -10, -9, 3, 14, 8, 7, 13, 13, -17, -9, -4, -19, -6, 1, 9, 5, 20, -9, -19, -20, -18, -8,
            7,
        ];
        for len in 0..data.len() {
            let (heap, map) =
                BinaryHeap::<i32, i32>::build_from_iterator(data.iter().map(|&x| (x, x)).take(len));
            for (i, entry) in heap.data.iter().enumerate() {
                assert_eq!(map[&entry.key], i);
            }
            assert_eq!(heap.len(), map.len());
            assert!(is_valid_heap(&heap), "Must be valid heap");
        }
    }

    /// `clear` empties the heap and leaves it usable.
    #[test]
    fn test_clear() {
        let mut heap = BinaryHeap::new();
        for x in 0..5 {
            heap.push(x, x, |_, _| {});
        }
        assert!(!heap.is_empty(), "Heap must be non empty");
        // Go through the heap's own `clear` rather than the backing vector,
        // so the method under test is actually exercised.
        heap.clear();
        assert!(heap.is_empty(), "Heap must be empty");
        assert_eq!(heap.len(), 0);
        assert_eq!(heap.pop(|_, _| {}), None);
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment