Last active: January 4, 2021 19:53
-
-
Save catnipan/07a60ba3420ef496744b9e4c3c60733c to your computer and use it in GitHub Desktop.
recurrence - a Rust macro example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Credit to https://danielkeep.github.io/practical-intro-to-macros.html
/// Expands to the number of comma-separated expressions it is given.
///
/// The result is a sum of unsuffixed integer literals (`0 + 1 + 1 + ...`), so it
/// takes its type from context and can be used to initialise a `const`. The
/// counted expressions are only matched as tokens; they are never evaluated.
macro_rules! count_exprs {
    // Internal rule: each counted expression contributes exactly one.
    (@one $_counted:expr) => (1);
    ($($item:expr),*) => (0 $(+ count_exprs!(@one $item))*);
}
/// Builds an infinite iterator over a sequence defined by a recurrence relation.
///
/// Syntax: `recurrence![seq[ind]: Type = recur_expr, init0, init1, ...]`
///
/// The initial values are yielded first, in order. After that, term `ind` is
/// produced by evaluating `recur_expr` with `ind` bound to the current index
/// (a `usize`) and `seq[i]` reading one of the previous `K` terms, where `K` is
/// the number of initial values (valid indices are `ind - K ..= ind - 1`).
///
/// `Type` must implement `Clone` (each yielded term is cloned out of the
/// sliding window) and `Debug` (the generated iterator derives `Debug`).
///
/// # Panics
///
/// The iterator panics if `recur_expr` indexes `seq` outside the window of the
/// previous `K` terms, including reading `seq[ind]` itself.
///
/// # Examples
///
/// ```ignore
/// let fib = recurrence![a[n]: u64 = a[n-1] + a[n-2], 0, 1];
/// ```
macro_rules! recurrence {
    ( $seq:ident [ $ind:ident ]: $sty:ty = $recur:expr, $($inits:expr),+) => {
        {
            // Number of previous terms kept in the sliding window.
            const MEMORY: usize = count_exprs!($($inits),+);

            #[derive(Debug)]
            struct Recurrence {
                // The last MEMORY terms, oldest first.
                mem: [$sty; MEMORY],
                // Absolute index of the next term to yield.
                pos: usize,
            }

            // Read-only view of the window that maps absolute term indices
            // onto positions inside `slice`.
            struct IndexOffset<'a> {
                slice: &'a [$sty; MEMORY],
                // Absolute index of the term currently being computed.
                offset: usize,
            }

            impl<'a> std::ops::Index<usize> for IndexOffset<'a> {
                type Output = $sty;

                #[inline]
                fn index(&self, index: usize) -> &$sty {
                    // Only terms `offset - MEMORY ..= offset - 1` are retained.
                    // Check up front so a bad index gets a descriptive message
                    // instead of an arithmetic-overflow or bounds panic.
                    assert!(
                        index < self.offset && self.offset - index <= MEMORY,
                        "recurrence index {} out of range: only the {} term(s) before index {} are available",
                        index,
                        MEMORY,
                        self.offset
                    );
                    &self.slice[index + MEMORY - self.offset]
                }
            }

            impl Iterator for Recurrence {
                type Item = $sty;

                #[inline]
                fn next(&mut self) -> Option<$sty> {
                    if self.pos < MEMORY {
                        // Still yielding the initial values.
                        let next_val = self.mem[self.pos].clone();
                        self.pos += 1;
                        return Some(next_val);
                    }

                    let next_val = {
                        let $ind = self.pos;
                        let $seq = IndexOffset { slice: &self.mem, offset: $ind };
                        $recur
                    };

                    // Slide the window: drop the oldest term, append the new one.
                    self.mem.rotate_left(1);
                    self.mem[MEMORY - 1] = next_val.clone();
                    self.pos += 1;
                    Some(next_val)
                }
            }

            Recurrence { mem: [$($inits),+], pos: 0 }
        }
    };
}
fn main() { | |
let fib = recurrence![a[n]: u64 = a[n-1] + a[n-2], 0, 1]; | |
assert_eq!(fib.take(10).collect::<Vec<_>>(), vec![0, 1, 1, 2, 3, 5, 8, 13, 21, 34]); | |
let factorial = recurrence![a[n]: usize = a[n-1] * n, 1]; | |
assert_eq!(factorial.take(10).collect::<Vec<_>>(), vec![1, 1, 2, 6, 24, 120, 720, 5040, 40320, 362880]); | |
let my = recurrence![a[n]: usize = { | |
if n % 2 == 0 { | |
a[n - 1] + 1 | |
} else { | |
a[n - 1] * 2 | |
} | |
}, 3]; | |
assert_eq!(my.take(10).collect::<Vec<_>>(), vec![3, 6, 7, 14, 15, 30, 31, 62, 63, 126]); | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.