Last active
August 29, 2015 14:22
-
-
Save bobbyali/004913a0456ef5db8c23 to your computer and use it in GitHub Desktop.
Continuous Naive Bayes Classifier in Java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
The raw data in the CSV file contains continuous values representing the following:
1st column: number of hours since the last nap
2nd column: number of hours since the last meal
3rd column: number of toys present
4th column: outcome — 1 means meltdown, 0 means no meltdown
TestClassifier.java runs the app: it loads the CSV file and
passes its contents to NaiveBayesContinuous.java.
NaiveBayesContinuous then computes all the probabilities required
to make a prediction.
For more on what's going on, see my blog post at
http://www.hacker-dad.com/how-to-predict-a-meltdown/
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
5,8,8,1
6,2,7,1
8,4,5,1
7,7,4,1
4,3,7,1
9,9,0,1
5,3,3,0
3,7,2,0
2,5,5,0
5,3,3,0
2,5,7,0
7,8,1,0
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package naive_bayes; | |
import java.util.ArrayList; | |
import java.util.List; | |
/**
 * Gaussian naive Bayes classifier that predicts whether a meltdown (outcome 1) or no
 * meltdown (outcome 0) is more likely, given three continuous features: hours since the
 * last nap, hours since the last meal, and number of toys present.
 *
 * <p>Within each outcome class, every feature is modelled as an independent normal
 * distribution. Its mean and sample variance are estimated from the training rows. Class
 * priors are the class frequencies in the training data.
 *
 * <p>NOTE(review): a class with fewer than two training rows, or a feature that is constant
 * within a class, yields a NaN or zero variance. That in turn gives NaN or infinite densities
 * in {@link #calcPosterior}.
 */
public class NaiveBayesContinuous {

    // position of each field within a training row
    private static final int NAP_COL = 0;
    private static final int EAT_COL = 1;
    private static final int TOY_COL = 2;
    private static final int OUTCOME_COL = 3;

    // lists containing training data, one entry per row (always the same length)
    private final List<Float> training_nap = new ArrayList<>();
    private final List<Float> training_eat = new ArrayList<>();
    private final List<Float> training_toy = new ArrayList<>();
    private final List<Boolean> training_outcome = new ArrayList<>();

    // prior probabilities: fraction of training rows with / without a meltdown
    public float p_meltdown, p_noMeltdown;

    // per-class pdf parameters (mean and sample variance of each feature)
    public float mean_nap_meltdown, var_nap_meltdown;
    public float mean_eat_meltdown, var_eat_meltdown;
    public float mean_toy_meltdown, var_toy_meltdown;
    public float mean_nap_noMeltdown, var_nap_noMeltdown;
    public float mean_eat_noMeltdown, var_eat_noMeltdown;
    public float mean_toy_noMeltdown, var_toy_noMeltdown;

    // pdf parameters over all training rows regardless of outcome. These are kept for callers
    // that read them; calcPosterior does not use them, because the evidence p(data) is
    // derived from the two class joints instead.
    public float mean_nap, var_nap;
    public float mean_eat, var_eat;
    public float mean_toy, var_toy;

    // posterior probabilities p(outcome | data), set by calcPosterior
    public float p_meltdown_data, p_noMeltdown_data;

    /**
     * Trains the classifier. It parses the rows, then computes the class priors and the
     * Gaussian parameters of every feature.
     *
     * @param data training rows of the form {@code nap,eat,toy,outcome}. Fields may be
     *     separated by commas and/or whitespace. Feature values may have several digits or a
     *     fractional part. The outcome must be {@code 1} (meltdown) or {@code 0} (no meltdown).
     *     Blank rows are skipped.
     * @throws IllegalArgumentException if a row is malformed or there are no non-blank rows
     */
    public NaiveBayesContinuous(String[] data) {
        for (String line : data) {
            // tolerate blank lines, e.g. a trailing empty line in the CSV file
            if (line.trim().isEmpty()) {
                continue;
            }
            addTrainingRow(line);
        }
        if (training_outcome.isEmpty()) {
            throw new IllegalArgumentException("no training rows supplied");
        }
        calcPriorProbabilities();
        calcPdfParameters();
    }

    /**
     * Parses one training row and appends its values to the training lists.
     *
     * @throws IllegalArgumentException if the row has fewer than four fields, a non-numeric
     *     feature, or an outcome other than 1/0
     */
    private void addTrainingRow(String line) {
        String[] fields = line.trim().split("[,\\s]+");
        if (fields.length < 4) {
            throw new IllegalArgumentException(
                    "expected 4 fields (nap, eat, toy, outcome) in row: \"" + line + "\"");
        }
        float nap;
        float eat;
        float toy;
        try {
            nap = Float.parseFloat(fields[NAP_COL]);
            eat = Float.parseFloat(fields[EAT_COL]);
            toy = Float.parseFloat(fields[TOY_COL]);
        } catch (NumberFormatException e) {
            throw new IllegalArgumentException("non-numeric feature in row: \"" + line + "\"", e);
        }
        boolean outcome = parseOutcome(fields[OUTCOME_COL], line);
        // append only once the whole row has parsed, so the four lists never diverge in length
        training_nap.add(nap);
        training_eat.add(eat);
        training_toy.add(toy);
        training_outcome.add(outcome);
    }

    /** Maps the outcome field to a boolean: "1" means meltdown, "0" means no meltdown. */
    private static boolean parseOutcome(String token, String line) {
        if (token.equals("1")) {
            return true;
        }
        if (token.equals("0")) {
            return false;
        }
        throw new IllegalArgumentException("outcome must be 1 or 0 in row: \"" + line + "\"");
    }

    /** Sets p_meltdown / p_noMeltdown to the class frequencies in the training data. */
    private void calcPriorProbabilities() {
        int numMeltdown = 0;
        for (boolean meltdown : training_outcome) {
            if (meltdown) {
                numMeltdown++;
            }
        }
        float total = training_outcome.size();
        this.p_meltdown = numMeltdown / total;
        this.p_noMeltdown = (training_outcome.size() - numMeltdown) / total;
    }

    /**
     * Splits the features by outcome and stores the mean and sample variance of each one.
     * It does this per class and over the whole training set.
     */
    private void calcPdfParameters() {
        List<Float> nap_meltdown = new ArrayList<>();
        List<Float> eat_meltdown = new ArrayList<>();
        List<Float> toy_meltdown = new ArrayList<>();
        List<Float> nap_noMeltdown = new ArrayList<>();
        List<Float> eat_noMeltdown = new ArrayList<>();
        List<Float> toy_noMeltdown = new ArrayList<>();

        for (int i = 0; i < training_outcome.size(); i++) {
            if (training_outcome.get(i)) {
                nap_meltdown.add(training_nap.get(i));
                eat_meltdown.add(training_eat.get(i));
                toy_meltdown.add(training_toy.get(i));
            } else {
                nap_noMeltdown.add(training_nap.get(i));
                eat_noMeltdown.add(training_eat.get(i));
                toy_noMeltdown.add(training_toy.get(i));
            }
        }

        this.mean_nap_meltdown = calcMean(nap_meltdown);
        this.mean_eat_meltdown = calcMean(eat_meltdown);
        this.mean_toy_meltdown = calcMean(toy_meltdown);
        this.mean_nap_noMeltdown = calcMean(nap_noMeltdown);
        this.mean_eat_noMeltdown = calcMean(eat_noMeltdown);
        this.mean_toy_noMeltdown = calcMean(toy_noMeltdown);

        this.var_nap_meltdown = calcVariance(nap_meltdown, this.mean_nap_meltdown);
        this.var_eat_meltdown = calcVariance(eat_meltdown, this.mean_eat_meltdown);
        this.var_toy_meltdown = calcVariance(toy_meltdown, this.mean_toy_meltdown);
        this.var_nap_noMeltdown = calcVariance(nap_noMeltdown, this.mean_nap_noMeltdown);
        this.var_eat_noMeltdown = calcVariance(eat_noMeltdown, this.mean_eat_noMeltdown);
        this.var_toy_noMeltdown = calcVariance(toy_noMeltdown, this.mean_toy_noMeltdown);

        this.mean_nap = calcMean(training_nap);
        this.mean_eat = calcMean(training_eat);
        this.mean_toy = calcMean(training_toy);
        this.var_nap = calcVariance(training_nap, this.mean_nap);
        this.var_eat = calcVariance(training_eat, this.mean_eat);
        this.var_toy = calcVariance(training_toy, this.mean_toy);
    }

    /**
     * Computes p(meltdown | data) and p(no meltdown | data) for one observation. It stores
     * them in p_meltdown_data / p_noMeltdown_data, then prints them.
     *
     * <p>By Bayes' rule, p(class | data) = p(data | class) p(class) / p(data). Under the
     * naive Bayes model, the evidence is p(data) = sum over classes of p(data | class) p(class).
     * This means the two posteriors always sum to 1.
     *
     * @param nap hours since the last nap
     * @param eat hours since the last meal
     * @param toy number of toys present
     */
    public void calcPosterior(float nap, float eat, float toy) {
        // log of p(data | class) * p(class). Each class uses its OWN prior.
        double logJointMeltdown = Math.log(this.p_meltdown)
                + logGaussianDensity(nap, this.mean_nap_meltdown, this.var_nap_meltdown)
                + logGaussianDensity(eat, this.mean_eat_meltdown, this.var_eat_meltdown)
                + logGaussianDensity(toy, this.mean_toy_meltdown, this.var_toy_meltdown);
        double logJointNoMeltdown = Math.log(this.p_noMeltdown)
                + logGaussianDensity(nap, this.mean_nap_noMeltdown, this.var_nap_noMeltdown)
                + logGaussianDensity(eat, this.mean_eat_noMeltdown, this.var_eat_noMeltdown)
                + logGaussianDensity(toy, this.mean_toy_noMeltdown, this.var_toy_noMeltdown);

        // log p(data) via log-sum-exp. The plain product of densities can underflow a float
        // to 0 for far-away observations, which would make the division 0/0.
        double max = Math.max(logJointMeltdown, logJointNoMeltdown);
        double logEvidence = max + Math.log(
                Math.exp(logJointMeltdown - max) + Math.exp(logJointNoMeltdown - max));

        this.p_meltdown_data = (float) Math.exp(logJointMeltdown - logEvidence);
        this.p_noMeltdown_data = (float) Math.exp(logJointNoMeltdown - logEvidence);

        printPosteriors();
    }

    /** Prints the stored posteriors, their sum, and which outcome is more likely. */
    public void printPosteriors() {
        System.out.println("Posteriors:");
        System.out.println("p(Breakdown|Data) = " + this.p_meltdown_data);
        System.out.println("p(No Breakdown|Data) = " + this.p_noMeltdown_data);
        System.out.println("Sum of posteriors = " + (this.p_meltdown_data + this.p_noMeltdown_data));
        if (this.p_meltdown_data > this.p_noMeltdown_data) {
            System.out.println("Breakdown is more likely.");
        } else if (this.p_meltdown_data < this.p_noMeltdown_data) {
            System.out.println("No Breakdown is more likely.");
        } else {
            System.out.println("Equal chance of breakdown vs no breakdown.");
        }
        System.out.println(" ");
    }

    /**
     * Returns the natural log of the normal density N(v; mean, variance). This is
     * {@code -0.5 * ln(2 * pi * variance) - (v - mean)^2 / (2 * variance)}.
     */
    private static double logGaussianDensity(float v, float mean, float variance) {
        double diff = v - mean;
        return -0.5 * Math.log(2 * Math.PI * variance) - (diff * diff) / (2 * variance);
    }

    /** Arithmetic mean of the values; NaN for an empty list. */
    private float calcMean(List<Float> data) {
        float total = 0;
        for (float value : data) {
            total += value;
        }
        return total / data.size();
    }

    /**
     * Sample (n - 1 denominator) variance of the values around the given mean. The result is
     * NaN for a list with one element and negative-zero-divided for an empty list.
     */
    private float calcVariance(List<Float> data, float mean) {
        float ssds = 0; // sum of squared differences
        for (float value : data) {
            ssds += Math.pow(value - mean, 2);
        }
        return ssds / (data.size() - 1);
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package naive_bayes; | |
import java.io.IOException; | |
import java.nio.charset.Charset; | |
import java.nio.file.Files; | |
import java.nio.file.Paths; | |
import java.util.List; | |
public class TestClassifier { | |
public static void main(String[] args) { | |
String fileName = "./continuous.csv"; | |
String[] data; | |
try { | |
List<String> lines = Files.readAllLines(Paths.get(fileName), Charset.defaultCharset()); | |
data = lines.toArray(new String[0]); | |
NaiveBayesContinuous dataProcessor = new NaiveBayesContinuous(data); | |
dataProcessor.calcPosterior(6, 5, 3); | |
} catch (IOException e) { | |
e.printStackTrace(); | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment