Last active
March 8, 2017 12:38
-
-
Save christian-rauch/7ce87819347acfd9168ef08822a04559 to your computer and use it in GitHub Desktop.
Prune (cut) scikit-learn's RandomForestClassifier at a given depth
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# This inserts leaf nodes at a given depth of the tree, i.e. no decision path extends beyond this depth.
# It does not actually remove the nodes from the arrays 'children_left' and 'children_right',
# i.e. the split nodes below the cut stay in the tree but are no longer reached by any decision path.
from sklearn.ensemble import RandomForestClassifier | |
from sklearn.tree.tree import Tree | |
def prune_node(tree, node_id, parent_depth, prune_depth):
    """Turn every split node deeper than ``prune_depth`` below ``node_id`` into a leaf.

    Depths are counted with the root at depth 1: ``node_id`` sits at
    ``parent_depth + 1`` and each child one level deeper. A split node whose
    depth exceeds ``prune_depth`` is rewritten in place into a leaf (both
    children set to -1, feature and threshold set to -2). Its descendants stay
    in the arrays but can no longer be reached from the root. Nodes that are
    already leaves are left untouched.

    The subtree is walked with an explicit stack instead of recursion, so trees
    deeper than Python's recursion limit can be pruned as well.

    Args:
        tree: tree structure exposing writable, node-indexed ``children_left``,
            ``children_right``, ``feature`` and ``threshold`` arrays, e.g. the
            ``tree_`` attribute of a fitted scikit-learn decision tree.
        node_id: id of the node where the traversal starts.
        parent_depth: depth of the parent of ``node_id`` (0 for the root).
        prune_depth: deepest level at which split nodes are kept.
    """
    no_child = -1  # child id marking a leaf
    undefined = -2  # feature/threshold value marking a leaf

    # Look the arrays up once instead of on every access. For a scikit-learn
    # Tree these attributes are views onto the node storage (that is why
    # assigning into them changes the tree), so writing through these local
    # names still modifies the tree in place.
    children_left = tree.children_left
    children_right = tree.children_right
    feature = tree.feature
    threshold = tree.threshold

    stack = [(node_id, parent_depth + 1)]
    while stack:
        node, depth = stack.pop()
        if children_left[node] == children_right[node]:
            # Leaf: both child ids are -1. A split node always has two
            # distinct children.
            continue
        if depth > prune_depth:
            # Cut here: the split node becomes a leaf.
            children_left[node] = no_child
            children_right[node] = no_child
            feature[node] = undefined
            threshold[node] = undefined
        else:
            # Keep this split and visit both children one level deeper
            # (left child popped first, as in a pre-order walk).
            stack.append((children_right[node], depth + 1))
            stack.append((children_left[node], depth + 1))
# Prune a tree, i.e. insert leaf nodes at depth prune_depth+1 (the root is at depth 1).
# The 'value' (class histogram) stored at each cut node is left unchanged, so the new leaf predicts from that histogram.
def prune_tree(tree, prune_depth):
    """Cut a single tree so that no decision path goes deeper than ``prune_depth``.

    The walk starts at node 0 (the root) with a parent depth of 0, so the root
    counts as depth 1. Split nodes deeper than ``prune_depth`` are rewritten
    into leaves directly in ``tree``; nothing is returned.

    Args:
        tree: tree structure to prune, e.g. ``estimator.tree_`` of a fitted
            decision tree.
        prune_depth: deepest level at which split nodes are kept.
    """
    root_node_id = 0
    prune_node(tree, root_node_id, parent_depth=0, prune_depth=prune_depth)
# Cut each tree of the forest after 'prune_depth'.
# This modifies the provided forest in place.
def prune_forest(forest, prune_depth):
    """Prune every tree of a fitted forest in place at ``prune_depth``.

    Args:
        forest: fitted ensemble whose ``estimators_`` hold the individual
            trees, e.g. a ``RandomForestClassifier``.
        prune_depth: deepest level at which split nodes are kept.
    """
    # The forest object is not replaced; each estimator's ``tree_`` is edited.
    for tree in (est.tree_ for est in forest.estimators_):
        prune_tree(tree, prune_depth)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment