Skip to content

Instantly share code, notes, and snippets.

@galbraun
Created August 9, 2020 13:06
Show Gist options
  • Save galbraun/b58b2937f130452ae81ea8d7d401bc0d to your computer and use it in GitHub Desktop.
Save galbraun/b58b2937f130452ae81ea8d7d401bc0d to your computer and use it in GitHub Desktop.
Functions that extract, for every tree in an XGBoost forest and every leaf in that tree, the sequence of nodes on the path from the root to that leaf.
def _root_to_leaf_route(df, stack, routes):
    """Record the root-to-leaf path of every leaf below the node on top of ``stack``.

    Args:
        df: Node table of a single tree. Its ``Node``, ``Feature``, ``Yes``
            and ``No`` columns are read; for every non-leaf row ``Yes`` and
            ``No`` must hold child ids of the form ``"<tree>-<node>"``.
        stack: Path from the root down to the node to explore; its last
            element is that node's number. The list is used as scratch space
            and the explored node is popped from it before returning.
        routes: Dict filled in place. Maps each leaf's node number to the
            list of node numbers leading to it (root first, leaf last).

    Raises:
        IndexError: If ``stack`` is empty or its last element is not a node
            present in ``df``.
    """
    # Index the table once so that visiting a node is a dict lookup instead
    # of a boolean-mask scan over the whole frame for every node.
    nodes = {}
    for node_id, feature, yes_id, no_id in zip(
            df.Node.values, df.Feature.values, df.Yes.values, df.No.values):
        # setdefault keeps the first row when a node number is repeated.
        nodes.setdefault(int(node_id), (feature, yes_id, no_id))
    _collect_leaf_routes(nodes, stack, routes)


def _collect_leaf_routes(nodes, stack, routes):
    """Depth-first walk over the ``nodes`` index built by ``_root_to_leaf_route``."""
    node_id = int(stack[-1])
    if node_id not in nodes:
        raise IndexError('node {} not found in tree dataframe'.format(node_id))
    feature, yes_id, no_id = nodes[node_id]
    if feature == 'Leaf':
        # Plain int key, consistent with the plain ints stored in the path.
        routes[node_id] = list(stack)
    else:
        # Child ids look like "<tree>-<node>"; only the node part is needed.
        for child_id in (yes_id, no_id):
            stack.append(int(child_id.split('-')[1]))
            _collect_leaf_routes(nodes, stack, routes)
    # Every call removes its own node, restoring the caller's path.
    stack.pop()
def extract_root_to_leaf_routes_for_forest(xgb):
    """Return the root-to-leaf routes of every tree in a fitted model.

    Args:
        xgb: Object exposing ``get_booster()``, whose result provides
            ``trees_to_dataframe()`` (e.g. a fitted XGBoost sklearn wrapper).

    Returns:
        Dict mapping each tree index to a dict that maps a leaf's node number
        to the list of node numbers on the path from the root to that leaf.
    """
    # Parse the booster once; re-parsing the whole forest for every single
    # tree made the original cost grow quadratically with the tree count.
    forest_df = xgb.get_booster().trees_to_dataframe()
    routes_forest = {}
    # groupby sorts its keys, so trees are visited in ascending index order.
    for tree_index, tree_df in forest_df.groupby('Tree'):
        routes = {}
        # Every traversal starts at node 0, the root of the tree.
        _root_to_leaf_route(tree_df.set_index('ID'), [0], routes)
        routes_forest[int(tree_index)] = routes
    return routes_forest
def extract_root_to_leaf_routes_for_tree(xgb, tree_index):
    """Return the root-to-leaf routes of a single tree of a fitted model.

    Args:
        xgb: Object exposing ``get_booster()``, whose result provides
            ``trees_to_dataframe()``.
        tree_index: Value of the ``Tree`` column selecting the tree.

    Returns:
        Dict filled by ``_root_to_leaf_route``, keyed by leaf node number.
    """
    all_nodes = xgb.get_booster().trees_to_dataframe()
    belongs_to_tree = all_nodes.Tree == tree_index
    tree_nodes = all_nodes.loc[belongs_to_tree].set_index('ID')

    leaf_routes = {}
    # The traversal starts from node 0 and fills leaf_routes in place.
    _root_to_leaf_route(tree_nodes, [0], leaf_routes)
    return leaf_routes
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment