Skip to content

Instantly share code, notes, and snippets.

@ClaytonJY
Last active October 5, 2017 19:47
Show Gist options
  • Save ClaytonJY/04848a34a598103df2348ef6f91a17dc to your computer and use it in GitHub Desktop.
Save ClaytonJY/04848a34a598103df2348ef6f91a17dc to your computer and use it in GitHub Desktop.
"grouped" vfolding in rsample
library(dplyr)
library(purrr)
library(rsample)
# suppose we want to keep cylinder-groups together:
# v-fold the distinct `cyl` values rather than the individual rows
initial_fold <- vfold_cv(distinct(mtcars, cyl), v = 3)

# take a look: the `cyl` value(s) held out by each fold
map(initial_fold$splits, assessment)
# take an existing rsplit object and expand given the original it was subsampled from
#
# `split` is an rsplit whose data is a table of group keys (for example the
# result of `distinct(orig, cyl)`); every column of that table is part of the
# grouping key. A row of `orig` goes to the analysis side when its key appears
# in `analysis(split)`, and to the assessment side when its key appears in
# `assessment(split)`, so all rows sharing a key stay on the same side.
#
# split:   an rsplit built on a table of group keys
# orig:    the full data frame the keys were taken from; must contain every
#          key column
# returns: a new rsplit on `orig` whose analysis/assessment row indices are
#          shuffled
expand_split <- function(split, orig) {
  vars <- colnames(split$data)
  missing_vars <- setdiff(vars, colnames(orig))
  if (length(missing_vars) > 0) {
    stop("`orig` is missing grouping column(s): ",
         paste(missing_vars, collapse = ", "), call. = FALSE)
  }

  # one comparable key per row: the raw column for a single key column (exact
  # %in% matching, as before), otherwise the key columns pasted together so
  # multi-column groups can be matched. unname() stops a column called e.g.
  # "sep" from being taken as a paste() argument.
  row_keys <- function(df) {
    cols <- unname(as.list(df[vars]))
    if (length(cols) == 1L) {
      cols[[1L]]
    } else {
      do.call(paste, c(cols, sep = "\x1f"))
    }
  }

  orig_keys <- row_keys(orig)
  rows_in <- which(orig_keys %in% row_keys(analysis(split)))
  rows_out <- which(orig_keys %in% row_keys(assessment(split)))

  # can't forget to shuffle -- but shuffle by position: sample(x) on a single
  # number k would instead return a permutation of 1:k
  shuffle <- function(x) x[sample.int(length(x))]

  # NOTE(review): rsplit() is an unexported rsample internal; newer rsample
  # versions export make_splits() and offer group_vfold_cv() -- confirm
  # against the installed version before relying on this.
  rsample:::rsplit(orig, shuffle(rows_in), shuffle(rows_out))
}
# now apply that to each split, turning each group-level split into a
# row-level split of mtcars
expanded_fold <- initial_fold %>%
  mutate(splits = map(splits, function(s) expand_split(s, orig = mtcars)))

# take a look: the assessment rows of each fold
map(expanded_fold$splits, assessment)

# other side: the analysis rows of each fold
map(expanded_fold$splits, analysis)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment