cross_validate {origami} | R Documentation |
Main Cross-Validation Function
Description
Applies cv_fun to the folds using future_lapply and combines the results across folds using combine_results.
Usage
cross_validate(
cv_fun,
folds,
...,
use_future = TRUE,
.combine = TRUE,
.combine_control = list(),
.old_results = NULL
)
Arguments
cv_fun |
A function that takes a 'fold' as its first argument and
returns a list of results from that fold. NOTE: the use of an argument
named 'X' is specifically disallowed in any input function for compliance
with the functions future_lapply and lapply, whose own first argument is
named X (an X passed through ... would collide with it). |
folds |
A list of folds to loop over, generated using make_folds. |
... |
Other arguments passed to cv_fun. |
use_future |
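A logical indicating whether to run the main cross-validation loop with future_lapply (TRUE, the default) or with lapply (FALSE). |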
.combine |
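A logical indicating whether combine_results should be called to combine the results across folds (default TRUE). |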
.combine_control |
A list of arguments passed to combine_results (default: an empty list). |
.old_results |
A list containing the output of a previous call to cross_validate, or NULL (the default). |
Value
A list of results, combined across folds.
Examples
###############################################################################
# Example 1: a naive (sequential) use of the cross_validate function.
###############################################################################
data(mtcars)
# resubstitution MSE: the model is fit and evaluated on the very same rows
r <- lm(mpg ~ ., data = mtcars)
mean(residuals(r)^2)
# function to calculate cross-validated squared error
#
# Arguments:
#   fold     - the current fold; not referenced directly here.
#              NOTE(review): training()/validation() are expected to find it
#              in this function's frame -- confirm against origami docs.
#   data     - a data.frame containing every variable in `reg_form`.
#   reg_form - a regression formula given as a string, e.g. "mpg ~ .".
# Returns a list with:
#   coef - one-row data.frame of the coefficients fit on the training set.
#   SE   - squared prediction errors, one per validation observation.
cv_lm <- function(fold, data, reg_form) {
  reg_formula <- as.formula(reg_form)
  # all.vars() parses the formula instead of splitting the string on spaces,
  # so "mpg ~ ." and "mpg~." both yield "mpg"; assumes the left-hand side is
  # a bare column name (no transformation such as log(mpg))
  out_var <- all.vars(reg_formula)[1]
  # split up data into training and validation sets
  train_data <- training(data)
  valid_data <- validation(data)
  # fit linear model on training set and predict on validation set
  mod <- lm(reg_formula, data = train_data)
  preds <- predict(mod, newdata = valid_data)
  # capture results to be returned as output
  list(
    coef = data.frame(t(coef(mod))),
    SE = (preds - valid_data[[out_var]])^2
  )
}
# replicate the resubstitution estimate with a single resubstitution fold
resub <- make_folds(mtcars, fold_fun = folds_resubstitution)[[1]]
resub_results <- cv_lm(resub, mtcars, "mpg ~ .")
mean(resub_results[["SE"]])
# cross-validated estimate using make_folds() with its default settings
folds <- make_folds(mtcars)
cv_results <- cross_validate(
  cv_fun = cv_lm,
  folds = folds,
  data = mtcars,
  reg_form = "mpg ~ ."
)
mean(cv_results[["SE"]])
###############################################################################
# Example 2: cross_validate with parallelization via the future package.
###############################################################################
data(mtcars)
# NOTE(review): data.table is not referenced directly in this example.
suppressMessages(library(data.table))
library(future)
# seed the RNG immediately before drawing the bootstrap folds
set.seed(1)
# draw many bootstrap folds so there is enough work to parallelize
folds <- make_folds(mtcars, fold_fun = folds_bootstrap, V = 1000)
# function to calculate cross-validated squared error for linear regression
#
# Uses only base R: the previous version called str_split() without
# attaching stringr, which fails in this example.
#
# Arguments:
#   fold     - the current fold; not referenced directly here.
#              NOTE(review): training()/validation() are expected to find it
#              in this function's frame -- confirm against origami docs.
#   data     - a data.frame containing every variable in `reg_form`.
#   reg_form - a regression formula given as a string, e.g. "mpg ~ .".
# Returns a list with:
#   coef - one-row data.frame of the coefficients fit on the training set.
#   SE   - squared prediction errors, one per validation observation.
cv_lm <- function(fold, data, reg_form) {
  reg_formula <- as.formula(reg_form)
  # the outcome is the first variable of the parsed formula; assumes the
  # left-hand side is a bare column name (no transformation such as log(mpg))
  out_var <- all.vars(reg_formula)[1]
  # split up data into training and validation sets
  train_data <- training(data)
  valid_data <- validation(data)
  # fit linear model on training set and predict on validation set
  mod <- lm(reg_formula, data = train_data)
  preds <- predict(mod, newdata = valid_data)
  # capture results to be returned as output
  list(
    coef = data.frame(t(coef(mod))),
    SE = (preds - valid_data[[out_var]])^2
  )
}
# time the sequential run; plan() returns the previously active plan, which
# is kept so it can be restored and the example leaves no global side effect
old_plan <- plan(sequential)
time_seq <- system.time({
  results_seq <- cross_validate(
    cv_fun = cv_lm, folds = folds, data = mtcars,
    reg_form = "mpg ~ ."
  )
})
# time the forked-parallel run; note that multicore is not supported on
# Windows or in RStudio, where future falls back to sequential processing
plan(multicore)
time_mc <- system.time({
  results_mc <- cross_validate(
    cv_fun = cv_lm, folds = folds, data = mtcars,
    reg_form = "mpg ~ ."
  )
})
# restore whatever plan was active before this example ran
plan(old_plan)
# with more than one core, the parallel run should not be markedly slower
if (availableCores() > 1) {
  time_mc["elapsed"] < 1.2 * time_seq["elapsed"]
}