Created
July 30, 2021 14:03
-
-
Save davidtedfordholt/3c7e0afdae183bdb0ce705543e561a00 to your computer and use it in GitHub Desktop.
A function for sampling whole groups from data frames, tibbles, and tsibbles, making it easy to create train and test sets from time-series data.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#' Sample whole groups at random from a data frame
#'
#' Selects a random subset of the distinct groups in `.data` and returns every
#' row that belongs to one of those groups. This makes it easy to build
#' train/test splits in which each sampled group is kept intact.
#'
#' @param .data A data frame, tibble, grouped data frame, or tsibble.
#' @param ... Bare names of the columns that define groups. If empty, groups
#'   come from `.data` itself: its grouping variables (grouped data frames)
#'   and/or its key variables (tsibbles). A plain, ungrouped data frame must
#'   name at least one column here.
#' @param n Number of groups to keep. A whole number of 1 or more is a count
#'   of groups; a number strictly between 0 and 1 is the proportion of the
#'   distinct groups to keep, rounded up. Because `n` comes after `...`, it
#'   must always be passed by name (`n = 20`): a positional value would be
#'   captured by `...` and treated as a grouping column.
#'
#' @return All rows of `.data` belonging to the sampled groups. The grouping
#'   variables of `.data` are restored, and tsibble inputs are returned as
#'   tsibbles with the input's key and index.
#' @export
#'
#' @examples
#' # Keep 3 of the distinct (cyl, gear) combinations
#' sample_groups(dplyr::group_by(mtcars, cyl, gear), n = 3)
#'
#' # Keep half (rounded up) of the cylinder classes, naming the column directly
#' sample_groups(mtcars, cyl, n = 0.5)
sample_groups <-
  function(.data, ..., n) {
    UseMethod("sample_groups")
  }
#' @export
sample_groups.data.frame <-
  function(.data, ..., n) {
    # A plain data frame carries no grouping of its own, so the columns that
    # define groups have to be named explicitly in `...`.
    captured <- rlang::exprs(...)
    group_cols <- purrr::map_chr(captured, as.character)
    group_cols <- unname(unlist(group_cols))
    if (length(group_cols) > 0) {
      return(sample_groups_engine(.data, group_cols, n))
    }
    rlang::abort("`sample_groups` requires either a grouped dataframe or columns to define groups")
  }
#' @export
sample_groups.tbl_ts <-
  function(.data, ..., n) {
    # Columns named in `...` take precedence; otherwise sample by the
    # tsibble's key variables.
    requested <- unname(unlist(purrr::map_chr(rlang::exprs(...), as.character)))
    keys <- if (length(requested) == 0) tsibble::key_vars(.data) else requested
    if (length(keys) == 0) {
      rlang::abort("`sample_groups` requires either a grouped dataframe or columns to define groups")
    }
    sampled <- sample_groups_engine(.data, keys, n)
    # Rebuild a tsibble with the input's key and index, then drop any grouping.
    rebuilt <- tsibble::as_tsibble(
      sampled,
      key = tsibble::key_vars(.data),
      index = !!tsibble::index_var(.data)
    )
    dplyr::ungroup(rebuilt)
  }
#' @export
sample_groups.grouped_df <-
  function(.data, ..., n) {
    # Default to the data frame's existing grouping; columns named in `...`
    # override it (and the caller is told so).
    explicit <- unname(unlist(purrr::map_chr(rlang::exprs(...), as.character)))
    use_explicit <- length(explicit) != 0
    keys <- if (use_explicit) explicit else dplyr::group_vars(.data)
    if (use_explicit) {
      rlang::inform("sampling from specified variables, rather than grouped variables")
    }
    sample_groups_engine(.data, keys, n)
  }
#' @export
sample_groups.grouped_ts <-
  function(.data, ..., n) {
    # Sample a grouped tsibble. The sampling unit is either the columns named
    # in `...` or the grouping variables, always extended with the tsibble's
    # key variables. The result is rebuilt as a tsibble with the input's key
    # and index.
    dots <- unname(unlist(purrr::map_chr(rlang::exprs(...), as.character)))
    key_cols <- tsibble::key_vars(.data)
    group_cols <- dplyr::group_vars(.data)
    if (length(dots) != 0) {
      # unique(): a key column may also be named in `...`; duplicated names
      # would make the join inside the engine fail.
      keys <- unique(c(dots, key_cols))
      rlang::inform("sampling from key(s) and specified variables, rather than grouped variables")
    } else {
      # Base-R set difference instead of the undefined `%not_in%` helper.
      if (length(setdiff(key_cols, group_cols)) > 0) {
        rlang::inform("key variables are being added to the grouped variables")
      }
      keys <- unique(c(group_cols, key_cols))
    }
    if (length(keys) == 0) {
      rlang::abort("`sample_groups` requires either a grouped dataframe or columns to define groups")
    }
    sampled <- sample_groups_engine(.data, keys, n)
    tsibble::as_tsibble(
      sampled,
      key = key_cols,
      index = !!tsibble::index_var(.data)
    )
  }
sample_groups_engine <-
  function(.data, keys, n) {
    # Shared worker for all sample_groups() methods.
    #
    # .data: the caller's data (plain, grouped, or tsibble).
    # keys:  character vector of column names defining one sampling unit.
    # n:     count (whole number >= 1) or proportion (0 < n < 1) of the
    #        distinct key combinations to keep.
    # Returns a tibble holding every row of the sampled groups, in the input's
    # row order and column order, regrouped by the input's grouping variables.
    # tsibble structure is not restored here; the tsibble methods do that.

    # Validate n before doing any work. `&&` short-circuits, so a missing,
    # NA, non-numeric, or non-scalar n is rejected instead of erroring inside
    # `if`.
    n_is_valid <- !missing(n) && is.numeric(n) && length(n) == 1L &&
      is.finite(n) && n > 0 && (n < 1 || n == round(n))
    if (!n_is_valid) {
      rlang::abort("`n` must be an integer greater than 0 or a number between 0 and 1.")
    }

    # Sample from a plain, ungrouped tibble so that grouping and tsibble
    # attributes do not affect distinct() or the join.
    tbl <- dplyr::ungroup(dplyr::as_tibble(.data))
    groups <- dplyr::distinct(dplyr::select(tbl, dplyr::all_of(keys)))

    if (n < 1) {
      # A proportion is taken of the distinct groups, not of the key columns.
      # round() first removes float noise (e.g. 100 * 0.07 is
      # 7.000000000000001), which ceiling() would otherwise push up a whole
      # group.
      n <- ceiling(round(nrow(groups) * n, digits = 8))
    }

    # If n exceeds the number of groups, slice_sample() returns all of them.
    sampled_keys <- dplyr::slice_sample(groups, n = n)

    # semi_join() keeps the rows of `tbl` that have a match in the sampled
    # keys, with tbl's columns and row order, and never duplicates rows.
    result <- dplyr::semi_join(tbl, sampled_keys, by = keys)

    # Restore the input's grouping; with no grouping variables this returns
    # an ungrouped tibble.
    dplyr::group_by(result, !!!rlang::syms(dplyr::group_vars(.data)))
  }
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment