Skip to content

Instantly share code, notes, and snippets.

@pohzipohzi
Created August 23, 2024 14:10
Show Gist options
  • Save pohzipohzi/81cf2ae92a00287fccdac1c2d98b7ae1 to your computer and use it in GitHub Desktop.
Save pohzipohzi/81cf2ae92a00287fccdac1c2d98b7ae1 to your computer and use it in GitHub Desktop.
minimal fft implementation
use std::{
f64::consts::PI,
fs::File,
io::Read,
iter::zip,
time::{Duration, Instant},
};
/// Benchmark driver: runs [`g`] `RUNS` times and prints the mean wall-clock
/// time, in seconds, of each of the four phases it measures
/// (forward DFT, forward FFT, inverse DFT, inverse FFT).
fn main() {
    const RUNS: i32 = 10;
    let mut totals = [Duration::ZERO; 4];
    for _ in 0..RUNS {
        // Accumulate this round's per-phase durations into the running totals.
        for (total, sample) in totals.iter_mut().zip(g()) {
            *total += sample;
        }
    }
    let [dft_s, fft_s, idft_s, ifft_s] = totals.map(|d| d.as_secs_f32() / RUNS as f32);
    println!(
        "dft={}s fft={}s idft={}s ifft={}s",
        dft_s, fft_s, idft_s, ifft_s
    );
}
/// Runs one benchmark round on a fresh random input and returns the elapsed
/// time of `[forward dft, forward fft, inverse dft, inverse fft]`.
///
/// The results are cross-checked outside the timed sections. The forward DFT
/// is compared with the forward FFT, and both normalized round-trips are
/// compared with the original input and with each other.
fn g() -> [Duration; 4] {
    let signal = gen_a();

    // Forward transforms: consecutive marks bracket each call.
    let fwd_start = Instant::now();
    let spectrum_dft = dft(&signal, false);
    let after_dft = Instant::now();
    let spectrum_fft = fft(&signal, false);
    let after_fft = Instant::now();

    check(&spectrum_dft, &spectrum_fft);

    // Inverse transforms, each followed by 1/n normalization.
    let inv_start = Instant::now();
    let restored_dft = norm(dft(&spectrum_dft, true));
    let after_idft = Instant::now();
    let restored_fft = norm(fft(&spectrum_fft, true));
    let after_ifft = Instant::now();

    for (lhs, rhs) in [
        (&signal, &restored_dft),
        (&signal, &restored_fft),
        (&restored_dft, &restored_fft),
    ] {
        check(lhs, rhs);
    }

    [
        after_dft - fwd_start,
        after_fft - after_dft,
        after_idft - inv_start,
        after_ifft - after_idft,
    ]
}
/// Generates `FFT_SIZE` (4096) random complex samples.
///
/// Each sample consumes 4 bytes from `/dev/urandom`. The real and imaginary
/// parts are little-endian `i16`s divided by `i16::MAX`, which places them in
/// roughly `[-1, 1]`.
///
/// # Panics
/// Panics with a descriptive message if `/dev/urandom` cannot be opened or
/// read, for example on platforms that do not provide it.
fn gen_a() -> Vec<Complex> {
    const FFT_SIZE: usize = 2usize.pow(12);
    let mut bytes = [0u8; FFT_SIZE * 4];
    File::open("/dev/urandom")
        .and_then(|mut rng| rng.read_exact(&mut bytes))
        .expect("failed to read random bytes from /dev/urandom");
    let scale = i16::MAX as f64;
    bytes
        .chunks_exact(4)
        .map(|c| {
            let re = i16::from_le_bytes([c[0], c[1]]) as f64 / scale;
            let im = i16::from_le_bytes([c[2], c[3]]) as f64 / scale;
            Complex::new(re, im)
        })
        .collect()
}
/// Asserts that two complex sequences have equal length and agree element-wise.
///
/// Two elements agree when both their real and their imaginary parts differ
/// by less than an absolute tolerance of 2^-16.
///
/// # Panics
/// Panics on a length mismatch, or at the first element whose real or
/// imaginary parts differ by at least the tolerance. Because a NaN difference
/// never compares below the tolerance, NaN values also trigger the panic. The
/// message reports the index, both values and the tolerance.
fn check(a: &[Complex], b: &[Complex]) {
    assert_eq!(a.len(), b.len(), "length mismatch");
    let precision = 2.0f64.powi(-16);
    for (i, (c0, c1)) in zip(a, b).enumerate() {
        assert!(
            (c0.re - c1.re).abs() < precision && (c0.im - c1.im).abs() < precision,
            "mismatch at index {i}: ({}, {}) vs ({}, {}), tolerance {precision}",
            c0.re,
            c0.im,
            c1.re,
            c1.im
        );
    }
}
/// Scales every element of `a` by `1 / a.len()` in place and hands the vector
/// back. This is the normalization applied after the unscaled inverse
/// transforms.
fn norm(mut a: Vec<Complex>) -> Vec<Complex> {
    let factor = 1.0 / a.len() as f64;
    for c in a.iter_mut() {
        *c *= factor;
    }
    a
}
/// Naive O(n²) discrete Fourier transform, used as the reference for `fft`.
///
/// Computes `y[j] = Σ_k a[k] · e^{s·2πi·jk/n}`, where `s = +1` for the forward
/// transform and `s = -1` when `inv` is true. The forward direction therefore
/// uses the *positive* exponent, which is the opposite of the sign many
/// numerical libraries use.
///
/// The inverse is not scaled by `1/n`; callers normalize separately.
/// An empty input yields an empty output.
fn dft(a: &[Complex], inv: bool) -> Vec<Complex> {
    let n = a.len();
    let sign = if inv { -1.0 } else { 1.0 };
    (0..n)
        .map(|j| {
            let mut y_j = Complex::default();
            // Track (j * k) mod n exactly in integers, so every angle passed to
            // cos/sin lies in [0, 2π). Computing (2πj/n)·k in floating point
            // instead produces angles up to ~2πn, whose absolute rounding error
            // grows in proportion to n.
            let mut idx = 0usize;
            for &a_k in a {
                let angle = sign * (2.0 * PI * idx as f64 / n as f64);
                y_j += a_k * Complex::new(angle.cos(), angle.sin());
                // idx < n and j < n, so a single subtraction keeps idx < n.
                idx += j;
                if idx >= n {
                    idx -= n;
                }
            }
            y_j
        })
        .collect()
}
/// Recursive radix-2 decimation-in-time (Cooley–Tukey) FFT.
///
/// Uses the same sign convention as the reference `dft` in this file:
/// - the forward transform (`inv == false`) uses twiddles `e^{+2πi·k/n}`;
/// - the inverse uses `e^{-2πi·k/n}`.
///
/// The inverse is not scaled by `1/n`; callers normalize separately.
/// An empty input returns an empty vector.
///
/// # Panics
/// Panics if `a.len()` is neither 0 nor a power of two. The even/odd split
/// would otherwise silently produce incorrect results.
fn fft(a: &[Complex], inv: bool) -> Vec<Complex> {
    let n = a.len();
    assert!(
        n == 0 || n.is_power_of_two(),
        "fft: length must be a power of two, got {n}"
    );
    // Lengths 0 and 1 are their own transform. Handling 0 here also prevents
    // the endless recursion an empty slice would otherwise cause.
    if n <= 1 {
        return a.to_vec();
    }
    let half = n / 2;
    // Split into even- and odd-indexed samples and transform each half.
    let even: Vec<Complex> = a.iter().step_by(2).copied().collect();
    let odd: Vec<Complex> = a.iter().skip(1).step_by(2).copied().collect();
    let y_even = fft(&even, inv);
    let y_odd = fft(&odd, inv);

    let sign = if inv { -1.0 } else { 1.0 };
    let mut out = vec![Complex::default(); n];
    for k in 0..half {
        let angle = sign * (2.0 * k as f64 * PI / n as f64);
        let w = Complex::new(angle.cos(), angle.sin());
        let t = w * y_odd[k];
        // Butterfly: w^{k + n/2} = -w^k, so the upper half reuses the same
        // twiddle with the sign flipped. This halves the cos/sin evaluations.
        out[k] = y_even[k] + t;
        out[k + half] = y_even[k] + t * -1.0;
    }
    out
}
/// Minimal complex number `re + im·i` with `f64` components.
///
/// It provides only the arithmetic the transforms in this file need. `Debug`
/// and `PartialEq` are derived so values can be printed and compared in
/// assertions.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
struct Complex {
    /// Real part.
    re: f64,
    /// Imaginary part.
    im: f64,
}

impl Complex {
    /// Builds the complex number `re + im·i`.
    fn new(re: f64, im: f64) -> Self {
        Self { re, im }
    }
}

/// Complex multiplication: `(a + bi)(c + di) = (ac - bd) + (ad + bc)i`.
impl std::ops::Mul for Complex {
    type Output = Self;
    fn mul(self, rhs: Self) -> Self::Output {
        Self {
            re: self.re * rhs.re - self.im * rhs.im,
            im: self.re * rhs.im + self.im * rhs.re,
        }
    }
}

/// Scales both components by a real factor.
impl std::ops::Mul<f64> for Complex {
    type Output = Self;
    fn mul(self, rhs: f64) -> Self::Output {
        Self {
            re: self.re * rhs,
            im: self.im * rhs,
        }
    }
}

impl std::ops::MulAssign<f64> for Complex {
    fn mul_assign(&mut self, rhs: f64) {
        *self = *self * rhs
    }
}

/// Component-wise addition.
impl std::ops::Add for Complex {
    type Output = Self;
    fn add(self, rhs: Self) -> Self::Output {
        Self {
            re: self.re + rhs.re,
            im: self.im + rhs.im,
        }
    }
}

impl std::ops::AddAssign for Complex {
    fn add_assign(&mut self, rhs: Self) {
        *self = *self + rhs
    }
}

/// Component-wise subtraction.
impl std::ops::Sub for Complex {
    type Output = Self;
    fn sub(self, rhs: Self) -> Self::Output {
        Self {
            re: self.re - rhs.re,
            im: self.im - rhs.im,
        }
    }
}

impl std::ops::SubAssign for Complex {
    fn sub_assign(&mut self, rhs: Self) {
        *self = *self - rhs
    }
}

/// Negates both components.
impl std::ops::Neg for Complex {
    type Output = Self;
    fn neg(self) -> Self::Output {
        Self {
            re: -self.re,
            im: -self.im,
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment