Last active
November 7, 2017 13:56
-
-
Save hoenirvili/081383c0b18bc9c652b82f42130aa683 to your computer and use it in GitHub Desktop.
id3 algorithm
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
usoara | mirositoare | arepete | neteda | comestibila | |
---|---|---|---|---|---|
1 | 0 | 0 | 0 | 1 | |
1 | 0 | 1 | 0 | 1 | |
0 | 1 | 0 | 1 | 1 | |
0 | 0 | 0 | 1 | 0 | |
1 | 1 | 1 | 0 | 0 | |
1 | 0 | 1 | 1 | 0 | |
1 | 0 | 0 | 1 | 0 | |
0 | 1 | 0 | 0 | 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import pandas as pd | |
import sklearn | |
from sklearn.datasets import load_iris | |
from sklearn import tree | |
import numpy as np | |
import graphviz | |
def retrieve_target_names(dataset):
    """
    Return (target, target_names) for the mushroom dataset.

    target is the 'comestibila' column in the ORIGINAL row order, so it
    stays aligned with the feature rows passed to the classifier;
    target_names maps class 0 -> 'necomestibila', 1 -> 'comestibila'.

    BUG FIX: the original called .sort_values() on the target column,
    which reorders the labels independently of the feature matrix and
    silently corrupts the training data fed to classifier.fit().
    """
    target = dataset['comestibila'].values
    return (target, np.array(['necomestibila', 'comestibila']))
def main():
    """Train a decision tree on data.csv and render it with graphviz."""
    # the last column is the decision; the first four are features
    frame = pd.read_csv('data.csv')
    feature_names = frame.columns.tolist()[:4]
    target, target_names = retrieve_target_names(frame)
    samples = frame[feature_names].values
    # 'entropy' == information-gain splitting, i.e. the id3 criterion
    classifier = tree.DecisionTreeClassifier(criterion='entropy')
    classifier.fit(samples, target)
    dot = tree.export_graphviz(classifier, out_file=None,
                               feature_names=feature_names,
                               class_names=target_names,
                               filled=True, rounded=True,
                               special_characters=True)
    graphviz.Source(dot).render("mushrooms")

if __name__ == '__main__':
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import csv | |
import sys | |
import math | |
import copy | |
def entropy(partition):
    """
    Shannon entropy of a partition of class counts.

    partition : list of per-class counts, e.g. [2, 5]
    """
    total = sum(partition)
    # zero counts contribute nothing (lim p->0 of p*log2(p) == 0)
    return -sum((count / total) * math.log2(count / total)
                for count in partition if count != 0)


def ig(partitions):
    """
    Information gain of splitting the instances into the given
    partitions.

    partitions: list of per-vertice class counts, e.g. [[1, 2], [2, 5]]
    """
    # BUG FIX: the original rebuilt the root counts with
    # root[column] += partitions[row][column], indexing rows by the
    # class position — correct only when the number of partitions
    # happens to equal the number of classes. Summing each class
    # column across all partitions is always correct.
    root = [sum(class_counts) for class_counts in zip(*partitions)]
    root_entropy = entropy(root)
    # total number of instances reaching this decision stump
    instances = sum(root)
    avg_entropy = sum((sum(p) / instances) * entropy(p) for p in partitions)
    return root_entropy - avg_entropy
def dmap(attributes, data):
    """
    Map every attribute name to its column of values.

    attributes: ['usoara', 'mirositoare', 'arepete', 'neteda', 'comestibila']
    data: row-major instances, e.g. [['1', '0', '0', '0', '1'], ...]
    returns: {'usoara': ['1', ...], ..., 'comestibila': ['1', ...]}
    """
    column_count = len(data[0])
    columns = [[row[j] for row in data] for j in range(column_count)]
    return dict(zip(attributes, columns))
def decision_stamps(dmap):
    """
    Build the decision stump counts for every non-decision attribute.

    The LAST key of dmap is taken as the decision attribute. For every
    other attribute, and for each of its distinct values (vertices),
    count how many instances fall into each decision class:

        {'usoara': {'0': {'0': 2, '1': 1},
                    '1': {'0': 3, '1': 2}}, ...}
    """
    decision_key = [*dmap][-1]
    decisions = dmap[decision_key]
    unique_decisions = list(set(decisions))
    ds = {}
    for attribute, column in dmap.items():
        if attribute == decision_key:
            continue
        stamp = {}
        for vertice in list(set(column)):
            # decisions of the instances that carry this vertice value
            matching = [decisions[i] for i, v in enumerate(column) if v == vertice]
            stamp[vertice] = {d: matching.count(d) for d in unique_decisions}
        ds[attribute] = stamp
    return ds
def all_partitions(decision_stamps):
    """
    Strip the labels off the decision stumps, keeping only the raw
    class-count partitions per attribute:

        {'usoara': {'0': {'0': 2, '1': 1}, '1': {'0': 3, '1': 2}}, ...}
        -> {'usoara': [[2, 1], [3, 2]], ...}
    """
    return {
        attribute: [[*counts.values()] for counts in stamps.values()]
        for attribute, stamps in decision_stamps.items()
    }
def best_attribute(partitions):
    """
    Return the attribute whose partitions yield the maximum
    information gain.

    partitions: {'usoara': [[2, 3], [1, 2]], ..., 'neteda': [[1, 3], [2, 2]]}
    returns: the winning attribute name (ties keep the first one seen)

    BUG FIX: the original returned the FIRST attribute unconditionally
    whenever exactly two candidates were left, never comparing their
    gains; it also returned '' when every gain was 0 because the
    maximum was initialised to 0 with a strict '>' comparison. The
    gain is now always computed for every candidate.
    """
    best_name = None
    best_gain = None
    for name, parts in partitions.items():
        gain = ig(parts)
        if best_gain is None or gain > best_gain:
            best_gain = gain
            best_name = name
    return best_name
def filter_data(data, attributes, attribute, attribute_value):
    """
    Return a deep copy of the rows of *data* whose value for
    *attribute* equals *attribute_value*.

    The deep copy keeps the returned subset independent of the
    caller's data. (The original carried two dead locals,
    `subset = None` and `s = subset`, and a manual append loop;
    both replaced by a single comprehension.)
    """
    index = attributes.index(attribute)
    return [row for row in copy.deepcopy(data) if row[index] == attribute_value]
def remove_column(data, col):
    """
    Return a deep copy of *data* with column *col* removed from every
    row; the caller's rows are left untouched.
    """
    copied = copy.deepcopy(data)
    for row in copied:
        del row[col]
    return copied
def pick_best_attribute(attributes, data):
    """
    Select the attribute with the maximum information gain.

    Returns (attribute, decision_stamps) where decision_stamps is the
    full stump mapping so the caller can index it by the attribute.

    BUG FIX: in the two-attribute shortcut the original passed the
    attribute NAME (a string) to dmap(), so dict(zip(...)) iterated the
    name's characters and produced stamps keyed by single letters,
    making the caller's ds[attribute] lookup raise KeyError. The full
    attribute list is passed now.
    """
    # with two attributes only one candidate is left beside the
    # decision column, so it is the best by definition
    if len(attributes) == 2:
        ds = decision_stamps(dmap(attributes, data))
        return (attributes[0], ds)
    # map attribute names to columns, build the stumps, keep only the
    # raw partitions and pick the attribute with the highest gain
    dmapp = dmap(attributes, data)
    ds = decision_stamps(dmapp)
    parts = all_partitions(ds)
    attribute = best_attribute(parts)
    return (attribute, ds)
def pick_subset(data, attributes, attribute, vertice):
    """
    Build the (data, attributes) pair for the branch of *attribute*
    labelled *vertice*: keep only the instances carrying that vertice
    value and drop the attribute itself (both its column in the data
    and its name in the attribute list).
    """
    # instances that carry the vertice value for this attribute
    subset = filter_data(data, attributes, attribute, vertice)
    # work on a copy so the caller's attribute list survives
    remaining = copy.copy(attributes)
    column = remaining.index(attribute)
    remaining.remove(attribute)
    # the attribute is fixed along this branch; its column is noise now
    subset = remove_column(subset, column)
    return (subset, remaining)
class Node:
    """
    A single decision node of the Id3 tree.

    Holds the attribute name, the attribute's decision stamp
    (per-vertice class counts) and, per vertice, either a decision
    label (str), a child Node, or None while still unclassified.
    """

    def __init__(self, attribute=None, stamp=None):
        # Node() with no arguments is an empty placeholder
        if stamp is None and attribute is None:
            self.attribute = None
            self.stamp = None
            self.neighbours = None
            return
        self.attribute = attribute
        self.stamp = stamp
        self.neighbours = {}
        # for every vertice try to reach a decision right away; store
        # None when the partition does not classify the instances
        # perfectly, so the vertice can later receive a child node
        for vertice, counts in self.stamp.items():
            self.neighbours[vertice] = self._decision(counts)

    def __repr__(self):
        return self.__str__()

    @property
    def vertices(self):
        """All vertice (attribute value) labels of this node."""
        return [*self.neighbours]

    def __str__(self):
        message = ''
        message += '[NB attribute = {}, '.format(self.attribute)
        message += 'stamp = {}, '.format(self.stamp)
        for key, value in self.neighbours.items():
            message += 'vertice:{} => decision|node {} '.format(key, value)
        message += 'NE]'
        return message

    def _decision(self, s):
        """
        Return the decision label when exactly one class has a
        positive count in *s*, else None (the vertice still needs a
        child node to classify its instances).
        """
        aparitions = 0
        dec = None
        values = [*s.values()]
        for v in values:
            if v > 0:
                dec = v
                aparitions += 1
        if aparitions > 1:
            # more than one class present: no perfect classification
            return None
        keys = [*s.keys()]
        return keys[values.index(dec)]

    def empty(self):
        """True for the placeholder produced by Node()."""
        return (self.stamp is None and
                self.attribute is None and
                self.neighbours is None)

    def push(self, node):
        """
        Adopt *node* wholesale when this node is empty, otherwise
        attach it to the first still-unclassified vertice.

        BUG FIX: the original iterated self.neighbours.values() and
        rebound the LOOP VARIABLE (`neighbour = node`), which never
        modified the dict — the push was a silent no-op. We now assign
        through the dict. (Attaching to the first open vertice matches
        the caller's singular "the current not classified neighbour"
        intent — TODO confirm against callers.)
        """
        if self.empty():
            self.stamp = node.stamp
            self.attribute = node.attribute
            self.neighbours = node.neighbours
            return
        for vertice, neighbour in self.neighbours.items():
            if neighbour is None:
                self.neighbours[vertice] = node
                return

    def push_neighbours(self, vertice, node):
        """Attach *node* under *vertice*; error if already decided."""
        if self.empty():
            self.stamp = node.stamp
            self.attribute = node.attribute
            self.neighbours = node.neighbours
            return
        if self.neighbours[vertice] is not None:
            raise ValueError(
                'An already decision was made for vertice {} decision {}'.
                format(vertice, self.neighbours[vertice])
            )
        self.neighbours[vertice] = node

    def neighbour(self, vertice):
        """The decision label / child node stored under *vertice*."""
        return self.neighbours[vertice]

    def node_enighbours_are_classified(self):
        """
        True when every vertice already holds a decision or a child
        node; False while any vertice is still None (or the node is
        empty). (Name typo kept for caller compatibility.)
        """
        if self.empty():
            return False
        for v in self.neighbours.values():
            if v is None:
                return False
        return True
class Tree(object):
    """
    A general purpose tree holding Id3 nodes.

    Tracks `root` (the first pushed node) and `current` (the most
    recently pushed node that still has unclassified vertices).
    """
    # starting node
    root = None
    # maintain the current node
    current = None

    def empty(self):
        """True while no node has been pushed yet."""
        return self.root is None and self.current is None

    def classify(self, attributes, instance):
        """
        Walk from the root, following the instance's value for each
        node's attribute, until a decision label is reached. Returns
        the label, or None if the walk falls off the tree.
        """
        node = self.root
        decision = None
        while node is not None:
            idx = attributes.index(node.attribute)
            vertice = instance[idx]
            value = node.neighbour(vertice)
            # anything that is not a Node is a decision label: take it
            if not isinstance(value, Node):
                decision = value
                break
            # a child node: keep descending
            node = value
        return decision

    def push(self, new_node):
        """
        Push *new_node*, maintaining root and current: the first node
        becomes both; later nodes are delegated to the current node.
        """
        if self.empty():
            self.root = new_node
            self.current = new_node
            return
        self.current.push(new_node)

    def push_neighbours(self, vertice, new_node):
        """
        Attach *new_node* under *vertice* of the current node and
        descend into it while it still has open vertices.

        BUG FIX: the original guarded the descent with
        `if isinstance(value):` — `value` was undefined and isinstance
        needs two arguments, so the line raised at runtime. The bogus
        guard is removed.
        """
        if self.empty():
            raise ValueError("Can't push neighbours to a empty tree")
        self.current.push_neighbours(vertice, new_node)
        if not new_node.node_enighbours_are_classified():
            self.current = new_node
def id3(data, attributes):
    """
    Recursively build an Id3 decision tree over *data* (row-major
    instances whose LAST column is the decision) and return a Tree.

    BUG FIXES vs the original:
    - it returned `tree.root` (a Node) on the base case but `tree`
      otherwise, so callers crashed or mis-handled the result;
    - the recursion pushed the returned Tree OBJECT as a neighbour,
      which classify() would later mistake for a decision label —
      we now push the subtree's root Node;
    - it looped over ALL vertices, including already-decided ones,
      which made push_neighbours raise ValueError.
    """
    tree = Tree()
    attribute, ds = pick_best_attribute(attributes, data)
    node = Node(attribute, ds[attribute])
    tree.push(node)
    # every vertice already classified: nothing left to expand
    if node.node_enighbours_are_classified():
        return tree
    for vertice in node.vertices:
        # skip vertices whose decision was already reached
        if node.neighbour(vertice) is not None:
            continue
        # NOTE(review): an empty subset here would crash dmap();
        # presumably the training data never produces one — confirm.
        subset = pick_subset(data, attributes, attribute, vertice)
        subtree = id3(*subset)
        tree.push_neighbours(vertice, subtree.root)
    return tree
def make_decisions(test_data, attributes, tree):
    """Classify every row of *test_data* with *tree* and collect the labels."""
    return [tree.classify(attributes, row) for row in test_data]
def main():
    """Train an Id3 tree from the csv file given on the command line
    and print the tree plus the decisions for a few test instances."""
    # the csv file with the training instances is the first argument
    if len(sys.argv) < 2 or sys.argv[1] is None:
        raise ValueError("Please specify csv data file")
    name = sys.argv[1]
    with open(name, mode='r') as file:
        rows = [row for row in csv.reader(file)]
    # first row holds the attribute names, the rest are instances
    attributes, data = rows[0], rows[1:]
    tree = id3(data, attributes)
    print(tree.root)
    test_data = [['0', '1', '1', '1'],
                 ['0', '1', '0', '1'],
                 ['1', '1', '0', '0']]
    decisions = make_decisions(test_data, attributes, tree)
    print()
    for attr in attributes:
        print('{} '.format(attr), end='')
    print()
    for i, row in enumerate(test_data):
        for value in row:
            print('{} '.format(value), end='')
        print('Decision: {} '.format(decisions[i]))

if __name__ == '__main__':
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# run the pure-python id3 implementation on the data set
all:
	./id3.py data.csv

# render the sklearn decision tree with graphviz
graph:
	./graph.py data.csv
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment