Last active
August 27, 2019 12:31
-
-
Save sudo-ben/041e4772813644927aea201a6fa1c0b0 to your computer and use it in GitHub Desktop.
Naive Bayes classifier in Rust
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// Author Ben McDonald. Adapted from python scikit-learn | |
/// | |
/// Naive bayes is a machine learning classifier. It uses bayes formula to classify data | |
/// based on how well data fits into normal distributions modeled from training data | |
use rulinalg::matrix::BaseMatrix; | |
use rulinalg::matrix::BaseMatrixMut; | |
use rulinalg::matrix::{Axes, Matrix}; | |
use std::collections::HashMap; | |
use std::error::Error; | |
use std::f64::consts::PI; | |
// Trained model for the Gaussian naive Bayes classifier.
// Produced by `build_model`, consumed by `fit_model`.
pub struct GaussianBayesModel {
    // All distinct class labels seen in training. The n-th label corresponds
    // to the n-th per-category block of the matrices' underlying data
    // (fit_model reads the data back in the same category-major order).
    pub categories: Vec<usize>,
    // Per-category, per-feature means of the training data.
    pub means_matrix: Matrix<f64>,
    // Per-category, per-feature variances of the training data.
    pub variances_matrix: Matrix<f64>,
}
/// Error returned by `build_model` when the training data is empty.
#[derive(Debug)]
pub struct EmptyInputError;

// Manual Display impl instead of the non-std `Display` derive; the message is
// identical to the previous `#[display(fmt = ...)]` string.
impl std::fmt::Display for EmptyInputError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "EmptyInputError: Training data is empty")
    }
}

impl Error for EmptyInputError {}
/// Error returned by `fit_model` when the test vector's length does not match
/// the number of features the model was trained with.
#[derive(Debug)]
pub struct MismatchedFeaturesError;

// Manual Display impl instead of the non-std `Display` derive.
// Also fixes the "suppied" typo in the original message.
impl std::fmt::Display for MismatchedFeaturesError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "MismatchedFeaturesError: Test data length does not match number of features in supplied model."
        )
    }
}

impl Error for MismatchedFeaturesError {}
/// | |
/// Summarise the dataset per class using normal distributions | |
/// training_x_y - array of tuple trainging data. len = number of samples, first tuple elements are observations | |
// Example | |
// X - visit_time, clicks, pages, keypress | |
// y - return_visits | |
/// and 2nd is the class | |
pub fn build_model( | |
training_x_y: Vec<(Vec<f64>, usize)>, | |
) -> Result<GaussianBayesModel, Box<dyn Error + 'static>> { | |
// take for x row and count intut numbers | |
let num_features: usize = training_x_y | |
.first() | |
.ok_or_else(|| EmptyInputError {})? | |
.0 | |
.len(); | |
// Collect all training data rows into their outputs | |
let mut categories_map: HashMap<usize, Vec<f64>> = HashMap::new(); | |
for x in training_x_y.iter() { | |
categories_map | |
.entry(x.1) | |
.or_insert(Vec::new()) | |
.extend(x.0.clone()); | |
} | |
// Count num of unique outputs | |
let num_categories = categories_map.len(); | |
// Start building model | |
let mut categories: Vec<usize> = Vec::with_capacity(num_categories); | |
let mut means: Vec<f64> = Vec::with_capacity(num_categories * num_features); | |
let mut variances: Vec<f64> = Vec::with_capacity(num_categories * num_features); | |
for (category, category_grouped) in categories_map.iter() { | |
categories.push(category.clone()); | |
let num_samples_in_category: usize = category_grouped.len() / num_features; | |
// Transform vectors into matrices | |
let cat_matrix = Matrix::new( | |
num_samples_in_category, | |
num_features, | |
category_grouped.clone(), | |
); | |
// Calc mean and variance using vectorized arrays | |
let mean = cat_matrix.mean(Axes::Row); | |
// cat_matrix.variance can fail if no variance in training feature | |
let variance = cat_matrix.variance(Axes::Row)?; | |
means.extend(mean.into_vec()); | |
variances.extend(variance.into_vec()); | |
} | |
// Transform vectors back into matrices | |
let means_matrix = Matrix::new(num_features, num_categories, means); | |
let variances_matrix = Matrix::new(num_features, num_categories, variances); | |
Ok(GaussianBayesModel { | |
categories, | |
means_matrix, | |
variances_matrix, | |
}) | |
} | |
/// | |
/// Make a prediction | |
/// test_x - array where len = num_features | |
/// model - model trained with fn build_model | |
pub fn fit_model( | |
test_x: Vec<f64>, | |
model: &GaussianBayesModel, | |
) -> Result<usize, Box<dyn Error + 'static>> { | |
let num_categories: usize = model.means_matrix.cols(); | |
let num_features: usize = model.means_matrix.rows(); | |
if test_x.len() != num_features { | |
return Err((MismatchedFeaturesError {}).into()); | |
} | |
// Collect a row of test data for each category | |
let test_x_repeated: Vec<f64> = test_x | |
.iter() | |
.cycle() | |
.take(test_x.len() * num_categories) | |
.cloned() | |
.collect::<Vec<f64>>(); | |
let mut test_x_matrix: Matrix<f64> = Matrix::new(num_features, num_categories, test_x_repeated); | |
// posterior = prior is 1/number of classes | |
// improvement (prior for a given class) = (number of samples in the class) / (total number of samples) | |
let mean_diff_sqrt: Matrix<f64> = | |
(test_x_matrix.as_mut_slice() - model.means_matrix.clone()).apply(&|x| -(x * x)); | |
let exponent = mean_diff_sqrt | |
.elediv(&model.variances_matrix.clone().apply(&|x| x * 2.0)) | |
.apply(&|x| x.exp()); | |
let sqrt_2pi: f64 = ((PI * 2.0) as f64).sqrt(); | |
let inner_divide = model | |
.variances_matrix | |
.clone() | |
.apply(&|x| 1.0 / (sqrt_2pi * x.sqrt())); | |
let unreduced_probabilies = exponent.elemul(&inner_divide); | |
// Vector of the probabilities of the test data fitting a catagory. | |
let unreduced_probabilies_data = unreduced_probabilies.data(); | |
let probabilities_each_catagory: Vec<f64> = (0..num_categories) | |
.map(|n| { | |
let i_start = n * num_features; | |
let feat_slice = &unreduced_probabilies_data[i_start..(i_start + num_features)]; | |
feat_slice.iter().product() | |
}) | |
.collect(); | |
// Find the highest probability. | |
// argmax returns (index_of_max_value, max_value). Take first element, the index | |
let best_fit_index = probabilities_each_catagory | |
.into_iter() | |
.enumerate() | |
.max_by(|x, y| x.1.partial_cmp(&y.1).expect("Tried to compare a NaN")) | |
.unwrap() | |
.0; | |
Ok(model.categories[best_fit_index]) | |
} | |
#[cfg(test)]
mod tests {
    use super::*;

    // build_model returns EmptyInputError on an empty training set.
    #[test]
    #[should_panic]
    fn test_panic_on_empty_input() {
        // Fixed stray double semicolon from the original.
        build_model(vec![]).unwrap();
    }

    // A single sample gives variance() nothing to compute, so unwrap panics.
    #[test]
    #[should_panic]
    fn test_panic_on_one_input() {
        build_model(vec![(vec![11.0, 2.0], 1)]).unwrap();
    }

    // Rows with differing feature counts panic (inside Matrix::new).
    #[test]
    #[should_panic]
    fn test_panic_on_mismatched_features() {
        build_model(vec![(vec![11.0, 2.0], 1), (vec![10.0], 1)]).unwrap();
    }

    // Class 2 has only one sample, so its variance cannot be computed.
    #[test]
    #[should_panic]
    fn test_panic_no_variance() {
        build_model(vec![
            (vec![11.0, 2.0], 1),
            (vec![10.0, 1.0], 1),
            (vec![11.0, 2.0], 2),
        ])
        .unwrap();
    }

    // NOTE(review): test_model3 and test_model are identical; kept both to
    // preserve the original test surface, but one could be removed.
    #[test]
    fn test_model3() {
        let x_input = vec![
            (vec![-1.0, -1.0], 1),
            (vec![-2.0, -1.0], 1),
            (vec![-3.0, -2.0], 1),
            (vec![1.0, 1.0], 2),
            (vec![2.0, 1.0], 2),
            (vec![3.0, 2.0], 2),
        ];
        let model: GaussianBayesModel = build_model(x_input).unwrap();
        assert_eq!(fit_model(vec![-0.8, -1.0], &model).unwrap(), 1);
    }

    // Two well-separated classes; a point near class 1's cluster classifies
    // as class 1.
    #[test]
    fn test_model() {
        let x_input = vec![
            (vec![-1.0, -1.0], 1),
            (vec![-2.0, -1.0], 1),
            (vec![-3.0, -2.0], 1),
            (vec![1.0, 1.0], 2),
            (vec![2.0, 1.0], 2),
            (vec![3.0, 2.0], 2),
        ];
        let model: GaussianBayesModel = build_model(x_input).unwrap();
        assert_eq!(fit_model(vec![-0.8, -1.0], &model).unwrap(), 1);
    }

    // Three classes; checks both training points and unseen points.
    #[test]
    fn test_model2() {
        let x_input = vec![
            (vec![1.0, 20.0], 1),
            (vec![20.0, 210.0], 0),
            (vec![3.0, 22.0], 1),
            (vec![40.0, 220.0], 0),
            (vec![6.0, 10.0], 2),
            (vec![7.0, 11.0], 2),
            (vec![7.0, 11.0], 2),
        ];
        // NOTE(review): reference means/variances kept from the original
        // gist for comparison — verify against an independent implementation.
        // means_matrix
        // ([[100.00000821, 25.00000821],
        //   [  1.00000821,  1.00000821],
        //   [  0.22223043,  0.22223043]])
        // variances_matrix
        // ([[ 30.        , 215.        ],
        //   [  2.        ,  21.        ],
        //   [  6.66666667,  10.66666667]])
        let model: GaussianBayesModel = build_model(x_input).unwrap();
        assert_eq!(fit_model(vec![1.0, 20.0], &model).unwrap(), 1);
        assert_eq!(fit_model(vec![20.0, 210.0], &model).unwrap(), 0);
        assert_eq!(fit_model(vec![25.0, 225.0], &model).unwrap(), 0);
        assert_eq!(fit_model(vec![30.0, 215.0], &model).unwrap(), 0);
        assert_eq!(fit_model(vec![3.0, 22.0], &model).unwrap(), 1);
        assert_eq!(fit_model(vec![6.0, 10.0], &model).unwrap(), 2);
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment