crossval {crossval}    R Documentation
Generic Function for Cross Validation
Description
crossval performs K-fold cross validation with B repetitions. If Y is a factor, balanced sampling is used, i.e. in each fold each category is represented in approximately the same proportion as in the full data.
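The balanced fold assignment described above can be sketched in base R. This is a hypothetical helper, make_folds(), written for illustration only and not taken from the crossval package. It assumes that "balanced" means stratified assignment: observations in each category are shuffled and dealt round-robin into the K folds. Calling it once per repetition gives the B independent fold partitions.

```r
# Hypothetical helper (not part of the crossval package): assign each of
# n observations to one of K folds. If y is a factor, folds are assigned
# separately within each category, so every fold receives a roughly
# proportional share of each class (stratified, i.e. "balanced", sampling).
make_folds <- function(y, K) {
  n <- length(y)
  folds <- integer(n)
  if (is.factor(y)) {
    for (lev in levels(y)) {
      idx <- which(y == lev)
      # shuffle indices within the category (sample.int avoids the
      # sample() pitfall for length-one vectors), then deal them round-robin
      idx <- idx[sample.int(length(idx))]
      folds[idx] <- rep_len(seq_len(K), length(idx))
    }
  } else {
    # no categories: a random permutation of (nearly) equal-sized fold labels
    folds[sample.int(n)] <- rep_len(seq_len(K), n)
  }
  folds
}

# Usage: B = 3 repetitions of 5-fold assignment for an unbalanced factor
set.seed(1)
y <- factor(rep(c("a", "b"), times = c(10, 5)))
fold_list <- lapply(seq_len(3), function(b) make_folds(y, K = 5))
table(y, fold_list[[1]])  # each fold gets 2 "a" and 1 "b"
```

With ten "a" and five "b" observations, every fold receives two "a" and one "b", which mirrors the 2:1 ratio in the full data.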
Usage
crossval(predfun, X, Y, K=10, B=20, verbose=TRUE, ...)
Arguments
predfun: Prediction function (see Details).
X: Matrix (or data frame, as in the regression example) of predictors; columns correspond to variables.
Y: Univariate response variable.
K: Number of folds.
B: Number of repetitions.
verbose: If TRUE (the default), the progress of the cross validation is reported.
...: Optional arguments passed on to predfun.
Details
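The argument predfun must be a function of the form

predfun(Xtrain, Ytrain, Xtest, Ytest, ...)

It is called once for each fold of each repetition. It should fit a model on Xtrain and Ytrain, predict the response for Xtest, and return a statistic comparing these predictions with Ytest (for example a prediction error, or the counts of a confusion matrix). The arguments are matched by position, so predfun may use other names (the examples below use train.x, train.y, test.x, test.y). Any further arguments given to crossval via ... are passed on to predfun.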
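The repeated K-fold scheme that drives predfun can be sketched in base R as follows. cv_sketch() is a hypothetical function written for illustration; it is not the crossval implementation. For brevity it uses plain random folds without the stratification described above. It shows the calling convention: predfun receives the training and test parts by position, any extra arguments travel through ..., and one statistic is collected per (repetition, fold) pair.

```r
# Hypothetical sketch (not the crossval implementation): repeated K-fold
# cross validation calling predfun(Xtrain, Ytrain, Xtest, Ytest, ...)
# once per fold and repetition, collecting the returned statistics.
cv_sketch <- function(predfun, X, Y, K = 10, B = 20, ...) {
  n <- length(Y)
  stats <- list()
  for (b in seq_len(B)) {
    # random fold labels 1..K, as equal in size as possible (no stratification)
    folds <- rep_len(seq_len(K), n)[sample.int(n)]
    for (k in seq_len(K)) {
      test <- folds == k
      # drop = FALSE keeps single-column matrices and data frames intact
      stats[[length(stats) + 1]] <- predfun(X[!test, , drop = FALSE], Y[!test],
                                             X[test, , drop = FALSE], Y[test], ...)
    }
  }
  # one row per (repetition, fold) pair
  do.call(rbind, stats)
}

# Usage: squared error of predicting the training mean (a trivial baseline)
mean_predfun <- function(Xtrain, Ytrain, Xtest, Ytest) {
  mean((Ytest - mean(Ytrain))^2)
}
set.seed(1)
X <- matrix(rnorm(40), ncol = 2)
Y <- rnorm(20)
res <- cv_sketch(mean_predfun, X, Y, K = 5, B = 4)
dim(res)  # 20 1: 4 repetitions x 5 folds
```

Because the arguments are positional, mean_predfun could just as well name them train.x, train.y, test.x, test.y, as the examples below do.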
Value
crossval returns a list with three entries:
stat.cv: the statistic returned by predfun for each cross validation run.
stat: the statistic returned by predfun averaged over all cross validation runs.
stat.se: the corresponding standard error.
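How stat and stat.se relate to stat.cv can be illustrated with a small base-R sketch. summarize_cv() is hypothetical, and its standard error formula (standard deviation across runs divided by the square root of the number of runs) is an assumption; crossval's exact computation may differ. Note that this naive formula treats the runs as independent, although runs within and across repetitions share training data.

```r
# Hypothetical summary step (not taken from crossval): average the per-run
# statistics column-wise and compute a naive standard error,
# sd / sqrt(number of runs), assuming runs are treated as independent.
summarize_cv <- function(stat.cv) {
  stat.cv <- as.matrix(stat.cv)   # one row per run, one column per statistic
  n.runs <- nrow(stat.cv)
  list(stat.cv = stat.cv,
       stat    = colMeans(stat.cv),
       stat.se = apply(stat.cv, 2, sd) / sqrt(n.runs))
}

# Usage with four made-up per-run MSE values
s <- summarize_cv(cbind(mse = c(1, 2, 3, 4)))
s$stat     # mse = 2.5
s$stat.se  # sd(1:4) / 2, about 0.645
```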
Author(s)
Korbinian Strimmer (https://strimmerlab.github.io).
Examples
# load "crossval" package
library("crossval")
# classification examples
# set up lda prediction function
predfun.lda = function(train.x, train.y, test.x, test.y, negative)
{
require("MASS") # for lda function
lda.fit = lda(train.x, grouping=train.y)
ynew = predict(lda.fit, test.x)$class
# count TP, FP etc.
out = confusionMatrix(test.y, ynew, negative=negative)
return( out )
}
# Student's Sleep Data
data(sleep)
X = as.matrix(sleep[,1, drop=FALSE]) # increase in hours of sleep
Y = sleep[,2] # drug given
plot(X ~ Y)
levels(Y) # "1" "2"
dim(X) # 20 1
set.seed(12345)
cv.out = crossval(predfun.lda, X, Y, K=5, B=20, negative="1")
cv.out$stat
diagnosticErrors(cv.out$stat)
# linear regression example
data("attitude")
y = attitude[,1] # rating variable
x = attitude[,-1] # data frame with the remaining variables
is.factor(y) # FALSE
summary( lm(y ~ . , data=x) )
# set up lm prediction function
predfun.lm = function(train.x, train.y, test.x, test.y)
{
lm.fit = lm(train.y ~ . , data=train.x)
ynew = predict(lm.fit, test.x )
# compute squared error risk (MSE)
out = mean( (ynew - test.y)^2 )
return( out )
}
# prediction MSE using all variables
set.seed(12345)
cv.out = crossval(predfun.lm, x, y, K=5, B=20)
c(cv.out$stat, cv.out$stat.se)
# prediction MSE using only two variables (complaints and learning)
cv.out = crossval(predfun.lm, x[,c(1,3)], y, K=5, B=20)
c(cv.out$stat, cv.out$stat.se)
# for more examples (e.g. using cross validation in a regression or classification context)
# see the R packages "sda", "care", or "binda".
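The classification example summarizes cross-validated confusion-matrix counts with diagnosticErrors(). For readers who want to see what such summaries mean, here is a hypothetical base-R counterpart: count_confusion() and diagnostic_rates() are illustrative names, not crossval functions, and the set of rates computed (accuracy, sensitivity, specificity, PPV, NPV) is an assumption rather than a description of diagnosticErrors() output. As in the example, the level named by negative is treated as the negative class.

```r
# Hypothetical base-R counterparts (not the crossval functions) of the
# confusion-matrix counting and diagnostic summaries used above.
count_confusion <- function(truth, pred, negative) {
  truth.pos <- truth != negative   # factor vs. character compares level labels
  pred.pos  <- pred != negative
  c(FP = sum(!truth.pos & pred.pos), TP = sum(truth.pos & pred.pos),
    TN = sum(!truth.pos & !pred.pos), FN = sum(truth.pos & !pred.pos))
}

diagnostic_rates <- function(cm) {
  with(as.list(cm), c(
    acc  = (TP + TN) / (TP + TN + FP + FN),  # accuracy
    sens = TP / (TP + FN),                   # sensitivity (true positive rate)
    spec = TN / (TN + FP),                   # specificity (true negative rate)
    ppv  = TP / (TP + FP),                   # positive predictive value
    npv  = TN / (TN + FN)                    # negative predictive value
  ))
}

# Usage on five toy predictions, with level "1" as the negative class
truth <- factor(c("1", "1", "2", "2", "2"))
pred  <- factor(c("1", "2", "2", "2", "1"))
cm <- count_confusion(truth, pred, negative = "1")  # FP 1, TP 2, TN 1, FN 1
diagnostic_rates(cm)  # acc 0.6, sens 2/3, spec 0.5, ppv 2/3, npv 0.5
```

The rates can equally be computed from averaged counts, which is how the example applies diagnosticErrors() to cv.out$stat.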