Last active
October 1, 2020 23:18
-
-
Save bakaburg1/893ea82db683b45d71477e8535a2a72e to your computer and use it in GitHub Desktop.
Small helper functions to interact with partykit::ctree().
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#' Extract ctree rules.
#'
#' Extract the tree rules by parsing the printed representation of the tree,
#' optionally formatting them in order to be ready to use for data filtering.
#' The rule of a node is the full path leading to it: the split conditions of
#' all its ancestors followed by its own split condition.
#'
#' @param tree An object produced by \code{partykit::ctree()}.
#' @param rule.as.text Whether to collapse rule conditions into a string
#'   (joined by \code{" & "}) or leave them as a character vector, in which
#'   case \code{rule} is a list column.
#' @param eval.ready Whether to format rules in order to be easily
#'   \code{eval()}-ed for data filtering: set conditions are rewritten as
#'   \code{var \%in\% c("a", "b")}. Turns \code{rule.as.text} automatically on.
#'
#' @return A tibble (data frame) with one row per non-root node and the
#'   columns \code{rule}, \code{id} (the tree node id) and \code{depth} (the
#'   depth of the node; children of the root have depth 1). A tree made of the
#'   root node only yields zero rows.
#'
#' @export
#'
#' @examples
#' \dontrun{
#' tree <- partykit::ctree(Species ~ ., data = iris)
#' get_ctree_rules(tree)
#' get_ctree_rules(tree, eval.ready = TRUE)
#' }
get_ctree_rules <- function(tree, rule.as.text = TRUE, eval.ready = FALSE) {
  # Load (without attaching) the required namespaces, so that calling this
  # function does not alter the caller's search path. Loading partykit also
  # registers the length() and print() methods for party objects used below.
  for (pkg in c("partykit", "stringr", "dplyr")) {
    if (!requireNamespace(pkg, quietly = TRUE)) {
      stop(
        "Package '", pkg, "' is required by get_ctree_rules().",
        call. = FALSE
      )
    }
  }

  # As documented, eval-ready rules are always returned as text.
  if (eval.ready) {
    rule.as.text <- TRUE
  }

  # A tree made of the root node only has no split rule to report.
  if (length(tree) == 1) {
    return(
      dplyr::tibble(rule = character(0), id = numeric(0), depth = numeric(0))
    )
  }

  # The node listing sits between the "[id] root" line and the last blank
  # line of the printed tree (the one preceding the node-count footer).
  out <- utils::capture.output(tree)
  root <- grep("^\\[[0-9]+\\] root$", out)[1]
  last_blank <- max(c(which(out == ""), 0L))
  if (is.na(root) || last_blank <= root + 1) {
    stop(
      "Cannot locate the node listing in the printed output of `tree`.",
      call. = FALSE
    )
  }
  node_lines <- out[(root + 1):(last_blank - 1)]

  # A node line looks like "|   |   [4] Petal.Width <= 1.7: <summary>": one
  # "|" per depth level, the node id in brackets, the split condition and,
  # for terminal nodes, a colon followed by the node summary.
  indent <- stringr::str_extract(node_lines, "^(\\|\\s+)*")
  condition <- stringr::str_remove(node_lines, "^(\\|\\s+)*\\[\\d+\\]")
  condition <- stringr::str_squish(stringr::str_remove(condition, ":.*"))

  rules <- dplyr::tibble(
    rule = condition,
    id = as.numeric(stringr::str_match(node_lines, "\\[(\\d+)\\]")[, 2]),
    # Only the leading markers define the depth; a "|" inside the condition
    # text must not be counted.
    depth = as.numeric(stringr::str_count(indent, stringr::fixed("|")))
  )

  if (eval.ready) {
    # Rewrite "var in a, b" as 'var %in% c("a", "b")'. Every level is quoted
    # and the call is always closed, whatever character the rule ends with.
    set_rule <- stringr::str_match(rules$rule, "^(.*?) in (.*)$")
    is_set <- !is.na(set_rule[, 1])
    if (any(is_set)) {
      levels_quoted <- vapply(
        stringr::str_split(set_rule[is_set, 3], stringr::fixed(", ")),
        function(lvls) paste0('"', lvls, '"', collapse = ", "),
        character(1)
      )
      rules$rule[is_set] <- paste0(
        set_rule[is_set, 2], " %in% c(", levels_quoted, ")"
      )
    }
  }

  # Nodes are numbered in pre-order, so the ancestor of a node at a given
  # depth is the node with the highest id not above its own at that depth.
  paths <- lapply(seq_along(rules$id), function(i) {
    candidates <- rules$id <= rules$id[i] & rules$depth <= rules$depth[i]
    lineage <- tapply(rules$id[candidates], rules$depth[candidates], max)
    rules$rule[rules$id %in% lineage]
  })

  rules$rule <- if (rule.as.text) {
    vapply(paths, paste, character(1), collapse = " & ")
  } else {
    paths
  }

  rules
}
#' Simplify ctree rules.
#'
#' Remove redundant components of a rule keeping only the shortest set
#' definition (e.g.: if many conditions in a rule represent nested sets, only
#' those necessary to define the innermost set are kept). For every
#' variable/operator pair only the last condition of the rule is retained.
#' The conditions are also rearranged alphabetically for easier comparison.
#'
#' @param rules A character vector of rules joined by the & symbol.
#'
#' @return A character vector with the simplified rules. Empty or missing
#'   rules are dropped (through \code{na.omit()}), so the result may be shorter
#'   than \code{rules}. Elements are named after the input rules unless
#'   \code{rules} already has names.
#'
#' @export
#'
#' @examples
#' simplify_rules("a > 1 & b <= 5 & a > 3")
#' simplify_rules('g %in% c("a", "b") & z > 2 & g %in% c("a")')
simplify_rules <- function(rules) {
  # A condition is identified by its "variable operator" prefix. Longer
  # operators come first so that "<=" is not read as "<".
  key_pattern <- "^.*? (<=|>=|==|!=|<|>|%in%|in)(?= )"

  simplify_one <- function(rule) {
    rule <- as.character(rule)
    if (is.na(rule) || !nzchar(rule)) {
      return(NA_character_)
    }

    parts <- strsplit(rule, " & ", fixed = TRUE)[[1]]

    # Conditions without a recognisable operator are keyed by their whole
    # text, hence kept unless exactly repeated.
    keys <- parts
    hit <- regexpr(key_pattern, parts, perl = TRUE)
    keys[hit > 0] <- regmatches(parts, hit)

    # Keys are compared for equality (not as substrings), so that e.g.
    # "x >" never captures a condition on "max >". The last condition of
    # each key is the innermost one along a tree path.
    innermost <- !duplicated(keys, fromLast = TRUE)

    paste(sort(parts[innermost]), collapse = " & ")
  }

  # vapply() guarantees a character result, also for zero-length input.
  simplified <- vapply(rules, simplify_one, character(1))

  stats::na.omit(simplified)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment