Skip to content

Instantly share code, notes, and snippets.

@sudo-ben
Last active August 27, 2019 12:31
Show Gist options
  • Save sudo-ben/041e4772813644927aea201a6fa1c0b0 to your computer and use it in GitHub Desktop.
Save sudo-ben/041e4772813644927aea201a6fa1c0b0 to your computer and use it in GitHub Desktop.
Naive bayes classifier Rust
/// Author Ben McDonald. Adapted from python scikit-learn
///
/// Naive bayes is a machine learning classifier. It uses bayes formula to classify data
/// based on how well data fits into normal distributions modeled from training data
use rulinalg::matrix::BaseMatrix;
use rulinalg::matrix::BaseMatrixMut;
use rulinalg::matrix::{Axes, Matrix};
use std::collections::HashMap;
use std::error::Error;
use std::f64::consts::PI;
/// Trained Gaussian naive Bayes model, produced by `build_model` and
/// consumed by `fit_model`.
///
/// Layout note: the matrices are constructed with `num_features` rows and
/// `num_categories` columns, but their backing data is stored
/// category-by-category (each category contributes `num_features`
/// consecutive values). All downstream uses are element-wise or operate on
/// the flat data, so this stays internally consistent.
pub struct GaussianBayesModel {
    // All distinct class labels seen in training, in the same order as the
    // per-category blocks of the matrices' backing data.
    pub categories: Vec<usize>,
    // Per (category, feature) mean of the fitted normal distributions.
    pub means_matrix: Matrix<f64>,
    // Per (category, feature) variance of the fitted normal distributions.
    pub variances_matrix: Matrix<f64>,
}
/// Error returned by `build_model` when the training set contains no samples.
// NOTE(review): the `Display` derive with `#[display(fmt = ...)]` presumably
// comes from the `derive_more` crate — confirm against the crate manifest.
#[derive(Debug, Display)]
#[display(fmt = "EmptyInputError: Training data is empty")]
pub struct EmptyInputError;
impl Error for EmptyInputError {}
/// Error returned by `fit_model` when the test vector's length does not match
/// the number of features the model was trained with.
// NOTE(review): the `Display` derive with `#[display(fmt = ...)]` presumably
// comes from the `derive_more` crate — confirm against the crate manifest.
#[derive(Debug, Display)]
#[display(
    // Fixed typo in user-facing message: "suppied" -> "supplied".
    fmt = "MismatchedFeaturesError: Test data length does not match number of features in supplied model."
)]
pub struct MismatchedFeaturesError;
impl Error for MismatchedFeaturesError {}
/// Summarise the dataset per class by fitting one normal distribution per
/// (class, feature) pair.
///
/// `training_x_y` - training samples; each tuple is (feature observations,
/// class label). Every observation vector must have the same length.
///
/// Example features (X): visit_time, clicks, pages, keypress
/// Example label (y): return_visits
///
/// # Errors
/// Returns `EmptyInputError` when `training_x_y` is empty, and propagates the
/// error from `Matrix::variance` when a class has too few samples for a
/// variance estimate (e.g. a single sample — see `rulinalg` docs to confirm
/// the exact failure condition).
pub fn build_model(
    training_x_y: Vec<(Vec<f64>, usize)>,
) -> Result<GaussianBayesModel, Box<dyn Error + 'static>> {
    // Feature count is taken from the first sample; all rows are assumed to
    // have the same length (a mismatch makes Matrix::new panic below).
    let num_features: usize = training_x_y
        .first()
        .ok_or(EmptyInputError {})?
        .0
        .len();

    // Group all training rows by class label, concatenating each class's
    // feature vectors into one flat buffer. Consuming the input avoids
    // cloning every observation vector.
    let mut categories_map: HashMap<usize, Vec<f64>> = HashMap::new();
    for (features, category) in training_x_y {
        categories_map
            .entry(category)
            .or_default()
            .extend(features);
    }
    let num_categories = categories_map.len();

    // Iterate classes in sorted label order so the model layout is
    // deterministic (HashMap iteration order is not).
    let mut grouped: Vec<(usize, Vec<f64>)> = categories_map.into_iter().collect();
    grouped.sort_by_key(|&(category, _)| category);

    // Build the model buffers: one block of `num_features` means/variances
    // per class, in `categories` order.
    let mut categories: Vec<usize> = Vec::with_capacity(num_categories);
    let mut means: Vec<f64> = Vec::with_capacity(num_categories * num_features);
    let mut variances: Vec<f64> = Vec::with_capacity(num_categories * num_features);
    for (category, category_grouped) in grouped {
        categories.push(category);
        let num_samples_in_category: usize = category_grouped.len() / num_features;
        // One matrix row per sample, one column per feature.
        let cat_matrix = Matrix::new(num_samples_in_category, num_features, category_grouped);
        // Per-feature mean and variance across this class's samples.
        let mean = cat_matrix.mean(Axes::Row);
        // Matrix::variance can fail (e.g. a class with a single sample).
        let variance = cat_matrix.variance(Axes::Row)?;
        means.extend(mean.into_vec());
        variances.extend(variance.into_vec());
    }

    // The backing data is category-major (num_features values per class);
    // only the flat layout matters because every downstream use is
    // element-wise (see GaussianBayesModel docs).
    let means_matrix = Matrix::new(num_features, num_categories, means);
    let variances_matrix = Matrix::new(num_features, num_categories, variances);
    Ok(GaussianBayesModel {
        categories,
        means_matrix,
        variances_matrix,
    })
}
/// Classify one observation with a trained model.
///
/// `test_x` - observation; length must equal the model's feature count.
/// `model` - model produced by `build_model`.
///
/// Returns the category whose per-feature normal distributions give `test_x`
/// the highest (unnormalised) likelihood; the prior is implicitly uniform.
///
/// # Errors
/// Returns `MismatchedFeaturesError` when `test_x.len()` differs from the
/// number of features in the model.
pub fn fit_model(
    test_x: Vec<f64>,
    model: &GaussianBayesModel,
) -> Result<usize, Box<dyn Error + 'static>> {
    let num_categories: usize = model.means_matrix.cols();
    let num_features: usize = model.means_matrix.rows();
    if test_x.len() != num_features {
        return Err((MismatchedFeaturesError {}).into());
    }
    // Tile the test vector once per category so all likelihoods can be
    // computed with element-wise matrix ops in one go.
    let test_x_repeated: Vec<f64> = test_x
        .iter()
        .cycle()
        .take(test_x.len() * num_categories)
        .copied()
        .collect();
    let mut test_x_matrix: Matrix<f64> = Matrix::new(num_features, num_categories, test_x_repeated);
    // Gaussian pdf: (1 / (sqrt(2*pi) * sigma)) * exp(-(x - mu)^2 / (2 * sigma^2))
    // TODO(improvement): replace the implicit uniform prior with
    // (samples in class) / (total samples).
    //
    // -(x - mu)^2 for every (feature, category) cell.
    let neg_squared_diff: Matrix<f64> =
        (test_x_matrix.as_mut_slice() - model.means_matrix.clone()).apply(&|x| -(x * x));
    // exp(-(x - mu)^2 / (2 * sigma^2))
    let exponent = neg_squared_diff
        .elediv(&model.variances_matrix.clone().apply(&|x| x * 2.0))
        .apply(&|x| x.exp());
    let sqrt_2pi: f64 = (PI * 2.0).sqrt();
    // 1 / (sqrt(2*pi) * sigma), where sigma = sqrt(variance).
    let normaliser = model
        .variances_matrix
        .clone()
        .apply(&|x| 1.0 / (sqrt_2pi * x.sqrt()));
    let unreduced_probabilities = exponent.elemul(&normaliser);
    // Likelihood of the test vector under each category = product of its
    // per-feature densities (naive independence assumption). Backing data is
    // category-major: num_features consecutive values per category.
    let unreduced_data = unreduced_probabilities.data();
    let probabilities_each_category: Vec<f64> = (0..num_categories)
        .map(|n| {
            let i_start = n * num_features;
            unreduced_data[i_start..(i_start + num_features)].iter().product()
        })
        .collect();
    // Argmax over the per-category likelihoods. unwrap is safe: a built
    // model always has at least one category.
    let best_fit_index = probabilities_each_category
        .into_iter()
        .enumerate()
        .max_by(|x, y| x.1.partial_cmp(&y.1).expect("Tried to compare a NaN"))
        .unwrap()
        .0;
    Ok(model.categories[best_fit_index])
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[should_panic]
    fn test_panic_on_empty_input() {
        // Removed stray double semicolon from the original.
        build_model(vec![]).unwrap();
    }

    #[test]
    #[should_panic]
    fn test_panic_on_one_input() {
        // A single sample leaves nothing to estimate a variance from.
        build_model(vec![(vec![11.0, 2.0], 1)]).unwrap();
    }

    #[test]
    #[should_panic]
    fn test_panic_on_mismatched_features() {
        // Rows of different lengths make the matrix construction fail.
        build_model(vec![(vec![11.0, 2.0], 1), (vec![10.0], 1)]).unwrap();
    }

    #[test]
    #[should_panic]
    fn test_panic_no_variance() {
        // Category 2 has only one sample, so its variance cannot be computed.
        build_model(vec![
            (vec![11.0, 2.0], 1),
            (vec![10.0, 1.0], 1),
            (vec![11.0, 2.0], 2),
        ])
        .unwrap();
    }

    #[test]
    fn test_model() {
        let x_input = vec![
            (vec![-1.0, -1.0], 1),
            (vec![-2.0, -1.0], 1),
            (vec![-3.0, -2.0], 1),
            (vec![1.0, 1.0], 2),
            (vec![2.0, 1.0], 2),
            (vec![3.0, 2.0], 2),
        ];
        let model: GaussianBayesModel = build_model(x_input).unwrap();
        assert_eq!(fit_model(vec![-0.8, -1.0], &model).unwrap(), 1);
    }

    #[test]
    fn test_model_other_class() {
        // Was an exact duplicate of test_model (named test_model3); now
        // covers the second class. The two classes are mirror images, so
        // the mirrored query must classify as 2.
        let x_input = vec![
            (vec![-1.0, -1.0], 1),
            (vec![-2.0, -1.0], 1),
            (vec![-3.0, -2.0], 1),
            (vec![1.0, 1.0], 2),
            (vec![2.0, 1.0], 2),
            (vec![3.0, 2.0], 2),
        ];
        let model: GaussianBayesModel = build_model(x_input).unwrap();
        assert_eq!(fit_model(vec![0.8, 1.0], &model).unwrap(), 2);
    }

    #[test]
    fn test_model_three_categories() {
        // Three categories with two features: exercises the case where the
        // category and feature counts differ.
        let x_input = vec![
            (vec![1.0, 20.0], 1),
            (vec![20.0, 210.0], 0),
            (vec![3.0, 22.0], 1),
            (vec![40.0, 220.0], 0),
            (vec![6.0, 10.0], 2),
            (vec![7.0, 11.0], 2),
            (vec![7.0, 11.0], 2),
        ];
        let model: GaussianBayesModel = build_model(x_input).unwrap();
        assert_eq!(fit_model(vec![1.0, 20.0], &model).unwrap(), 1);
        assert_eq!(fit_model(vec![20.0, 210.0], &model).unwrap(), 0);
        assert_eq!(fit_model(vec![25.0, 225.0], &model).unwrap(), 0);
        assert_eq!(fit_model(vec![30.0, 215.0], &model).unwrap(), 0);
        assert_eq!(fit_model(vec![3.0, 22.0], &model).unwrap(), 1);
        assert_eq!(fit_model(vec![6.0, 10.0], &model).unwrap(), 2);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment