Created
October 26, 2020 13:26
-
-
Save sXakil/73898b40fc8ed48b85a7a313599d6aea to your computer and use it in GitHub Desktop.
Naive Bayes Classifier Implementation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Frequency table for the toy fruit data set.  The first column is the class
# label; the remaining columns are feature counts followed by the row total.
testData = "Type Long NotLong Sweet NotSweet Yellow NotYellow Total\nBanana 400 100 350 150 450 50 500\nOrange 0 300 150 150 300 0 300\nOther 100 100 150 50 50 150 200\nTotal 500 500 650 350 800 200 1000"

# Round-trip the table through data.txt so the parsing below works on a file.
# Context managers guarantee the handles are closed even if an error occurs.
with open("data.txt", "w", encoding="utf-8") as file:
    file.write(testData)

with open("data.txt", "r", encoding="utf-8") as file:
    lines = file.readlines()

# Lower-cased column names, in the same order as the columns after the label.
types = ['long', 'notlong', 'sweet', 'notsweet', 'yellow', 'notyellow', 'total']

# dataset[label][column] -> integer count; labels are lower-cased.
dataset = {}
for line in lines[1:]:  # lines[0] is the header row
    fields = line.split()  # tolerant of repeated spaces and the trailing newline
    if not fields:
        continue  # ignore blank lines instead of creating an empty-label row
    label, counts = fields[0].lower(), fields[1:]
    if len(counts) < len(types):
        raise ValueError(
            f"row '{fields[0]}' has {len(counts)} values, expected {len(types)}"
        )
    dataset[label] = {column: int(count) for column, count in zip(types, counts)}
def calcP(obj, idx, div=None, table=None):
    """Return the count ``table[obj][idx]`` divided by ``div``.

    Args:
        obj: Row label to look up (for example a fruit name or ``'total'``).
        idx: Column name within that row.
        div: Denominator.  When omitted -- or zero -- the row's own
            ``'total'`` count is used instead.
        table: Nested ``{row: {column: count}}`` mapping to read from.
            Defaults to the module-level ``dataset``.

    Returns:
        The ratio as a float.

    Raises:
        KeyError: If ``obj`` or ``idx`` is not present in the table.
        ZeroDivisionError: If the denominator that ends up being used is zero.
    """
    if table is None:
        table = dataset
    row = table[obj]
    # A zero denominator is deliberately treated like a missing one, so the
    # row total is used as the fallback in both cases.
    if div is None or div == 0:
        div = row['total']
    return row[idx] / div
def naiveBayes(pAB, pA, pB):
    """Combine three probabilities as ``pAB * pB / pA``, capped at 1.0.

    With ``pAB`` read as P(A|B), this is Bayes' theorem and the result is
    P(B|A).

    Args:
        pAB: Conditional probability of A given B.
        pA: Probability of A (the divisor).
        pB: Probability of B.

    Returns:
        The computed value, or 1.0 when the computed value is not <= 1.0.

    Raises:
        ZeroDivisionError: If ``pA`` is zero.
    """
    nB = (pAB * pB) / pA
    # Cap at 1.0 -- presumably to absorb floating-point overshoot.
    if nB <= 1.0:
        return nB
    return 1.0
# Number of samples in the whole table (the 'total' row, 'total' column).
# Named explicitly instead of relying on a loop variable left over from an
# earlier loop.
grandTotal = dataset['total']['total']

# P(feature): fraction of all samples that have each feature.
pTypes = {feature: calcP('total', feature) for feature in types}

# P(fruit): fraction of all samples in each row (the 'total' row yields 1.0).
pFruits = {fruit: calcP(fruit, 'total', grandTotal) for fruit in dataset}

# P(fruit | feature): row count divided by the feature's column total.
pFruitType = {
    fruit: {
        feature: calcP(fruit, feature, dataset['total'][feature])
        for feature in types
    }
    for fruit in dataset
}

# P(feature | fruit), obtained from the tables above via Bayes' theorem.
pTypeFruit = {
    feature: {
        fruit: naiveBayes(pFruitType[fruit][feature], pFruits[fruit], pTypes[feature])
        for fruit in dataset
    }
    for feature in types
}

# Features observed on the unknown fruit.
classifyBy = ['sweet', 'notlong', 'yellow']

# Naive Bayes score per class: P(fruit) * product of P(feature | fruit).
# The aggregate 'total' row is not a class, so it is never a candidate.
pFruitProbability = {}
for fruit in dataset:
    if fruit == 'total':
        continue
    score = pFruits[fruit]  # class prior
    for cls in classifyBy:
        score *= pTypeFruit[cls][fruit]
    pFruitProbability[fruit] = score

# Pick the class with the highest score.
decision = max(pFruitProbability, key=pFruitProbability.get)
decession = decision  # original (misspelled) name kept as an alias
print(f"The class of an unknown fruit that is {', '.join(classifyBy)} is '{decision}'")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment