Last active
May 13, 2018 16:17
-
-
Save mzmmoazam/1836ebbe07c469ef1e1a10639013429d to your computer and use it in GitHub Desktop.
A simple naive Bayes implementation.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import csv, random | |
class NaiveBayes(object):
    """Naive Bayes classifier trained and evaluated on a single CSV file.

    Every remaining column is treated as a categorical attribute: two values
    count as "the same" only if they compare equal, so continuous columns are
    not binned.
    """

    # Column (after ``drop_columns`` have been removed) holding the class label.
    label_index = 6
    # Add-alpha smoothing strength for likelihoods; 0.0 means raw frequencies,
    # 1.0 is classic Laplace smoothing.
    alpha = 0.0
    # Raw CSV columns deleted before float conversion (presumably the
    # non-numeric ones, since every other cell must parse as float).
    drop_columns = (29, 28, 17, 0)

    def __init__(self, filename, split_ratio, label_index=6, alpha=0.0):
        """Load ``filename``, split it randomly and collect the class labels.

        :param filename: path to a headerless CSV file (every kept cell must
            parse as a float)
        :param split_ratio: fraction of rows (0..1) placed in the training
            set; the remaining rows form the test set
        :param label_index: index of the label column after dropping columns
        :param alpha: additive smoothing used by the classifier's likelihoods
        """
        self.label_index = label_index
        self.alpha = alpha
        self.train, self.test = self._splitDataset(self._loadCsv(filename), split_ratio)
        # Distinct values of the label COLUMN (the original took a data row).
        self.labels = {row[self.label_index] for row in self.train}

    def _loadCsv(self, filename):
        """Read ``filename`` and return its rows as lists of floats.

        Columns listed in ``drop_columns`` are removed first; blank lines are
        skipped.

        :param filename: filename passed from the constructor of the class
        :return: the dataset as a list of rows, each a list of floats
        :raises ValueError: if a kept cell is not numeric
        """
        # Delete from the highest index down so earlier deletions don't shift
        # the positions of later ones.
        drop = sorted(self.drop_columns, reverse=True)
        dataset = []
        with open(filename, "r", newline="") as handle:
            for row in csv.reader(handle):
                if not row:
                    continue  # e.g. a trailing empty line
                for column in drop:
                    del row[column]
                dataset.append([float(cell) for cell in row])
        return dataset

    def _splitDataset(self, dataset, splitRatio):
        """Randomly partition ``dataset`` into training and test rows.

        :param dataset: the rows returned by ``_loadCsv``
        :param splitRatio: fraction of rows (0..1) that go to the training set
        :return: ``[train_rows, test_rows]``; ``dataset`` itself is not modified
        :raises ValueError: if ``splitRatio`` is outside [0, 1]
        """
        if not 0.0 <= splitRatio <= 1.0:
            raise ValueError("splitRatio must be between 0 and 1, got %r" % (splitRatio,))
        train_size = int(len(dataset) * splitRatio)
        train_set = []
        remaining = list(dataset)
        while len(train_set) < train_size:
            # Draw without replacement; whatever is left becomes the test set.
            train_set.append(remaining.pop(random.randrange(len(remaining))))
        return [train_set, remaining]

    def getAccuracy(self):
        """Return the percentage of test rows whose label was predicted correctly.

        ``naive_bayes`` must have been called first.

        :return: accuracy in percent (0.0 - 100.0)
        :raises ValueError: if there is not exactly one prediction per test row
        :raises ZeroDivisionError: if the test set is empty
        """
        predictions = self.new_predictions
        if len(predictions) != len(self.test):
            raise ValueError(
                "expected %d predictions, got %d; call naive_bayes() first"
                % (len(self.test), len(predictions))
            )
        label_index = self.label_index
        correct = sum(
            1 for row, predicted in zip(self.test, predictions)
            if row[label_index] == predicted
        )
        return (correct / float(len(self.test))) * 100.0

    def naive_bayes(self):
        """Predict a label for every test row and store them in ``new_predictions``.

        Each label is scored as P(label) * prod_i P(attribute_i | label). The
        true label column is excluded from the evidence. Ties (including the
        all-zero case) go to the smallest label, so results are reproducible.

        :return: None
        """
        label_index = self.label_index
        labels = sorted(self.labels)
        # Priors don't depend on the test row; compute them once.
        priors = {label: self.prob_class(label) for label in labels}
        self.new_predictions = []
        for row in self.test:
            best_label, best_score = '', -1.0
            for label in labels:
                score = priors[label]
                for index, value in enumerate(row):
                    if index == label_index:
                        continue  # never use the answer as evidence
                    score *= self._likelihood(value, index, label)
                    if score == 0.0:
                        break  # product can't recover from zero
                if best_score < score:
                    best_score, best_label = score, label
            self.new_predictions.append(best_label)

    def _likelihood(self, attribute, attribute_index, label):
        """Estimate P(attribute | label) from the training set.

        With ``alpha > 0``, add-alpha smoothing is applied. The denominator
        uses the number of distinct training values in that column, so a value
        never seen with ``label`` no longer zeroes the whole product.
        """
        label_index = self.label_index
        in_class = [row for row in self.train if row[label_index] == label]
        hits = sum(1 for row in in_class if row[attribute_index] == attribute)
        if self.alpha > 0:
            n_values = len({row[attribute_index] for row in self.train})
            return (hits + self.alpha) / (len(in_class) + self.alpha * n_values)
        return hits / len(in_class) if in_class else 0.0

    def prob(self, attribute, attribute_index, label):
        """Return the JOINT frequency P(attribute, label) over the training set.

        This is not the conditional P(attribute | label) naive Bayes needs;
        the classifier uses ``_likelihood`` for that.

        :raises ZeroDivisionError: if the training set is empty
        """
        label_index = self.label_index
        hits = sum(
            1 for row in self.train
            if row[attribute_index] == attribute and row[label_index] == label
        )
        return hits / len(self.train)

    def prob_class(self, label):
        """Return the prior P(label): the share of training rows with ``label``.

        :raises ZeroDivisionError: if the training set is empty
        """
        label_index = self.label_index
        hits = sum(1 for row in self.train if row[label_index] == label)
        return hits / len(self.train)
if __name__ == '__main__':
    import sys

    # Optional first CLI argument: path to the CSV file; falls back to the
    # original hard-coded 'flag.data.csv' when none is given.
    data_file = sys.argv[1] if len(sys.argv) > 1 else 'flag.data.csv'
    clf = NaiveBayes(filename=data_file, split_ratio=0.67)
    clf.naive_bayes()
    print(clf.getAccuracy())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment