sirus.cv {sirus} | R Documentation |
Estimate p0.
Description
Estimate the optimal hyperparameter p0 used to select rules in sirus.fit, using cross-validation (Benard et al. 2021a, 2021b).
Usage
sirus.cv(
data,
y,
type = "auto",
nfold = 10,
ncv = 10,
num.rule.max = 25,
q = 10,
discrete.limit = 10,
num.trees.step = 1000,
alpha = 0.05,
mtry = NULL,
max.depth = 2,
num.trees = NULL,
num.threads = NULL,
replace = TRUE,
sample.fraction = NULL,
verbose = TRUE,
seed = NULL
)
Arguments
data |
Input data frame: each row is an observation vector and each column is an input variable, either numeric or factor. |
y |
Numeric response variable. For classification, y takes only 0 and 1 values. |
type |
'reg' for regression, 'classif' for classification and 'auto' for automatic detection (classification if y takes only 0 and 1 values). |
nfold |
Number of folds in the cross-validation. Default is 10. |
ncv |
Number of repetitions of the cross-validation. Default is 10 for a robust estimation of p0. |
num.rule.max |
Maximum number of rules of SIRUS model in the cross-validation grid. Default is 25. |
q |
Number of quantiles used for node splitting in the forest construction. Default and recommended value is 10. |
discrete.limit |
Maximum number of distinct values for a variable to be considered discrete. Variables with more distinct values are treated as continuous. |
num.trees.step |
Number of trees grown between two evaluations of the stopping criterion. Ignored if num.trees is provided. |
alpha |
Parameter of the stopping criterion for the number of trees: stability has to reach 1 - alpha before the forest stops growing. Ignored if num.trees is provided. |
mtry |
Number of variables randomly drawn as split candidates at each node. Default is the number of variables divided by 3. |
max.depth |
Maximal tree depth. Default and recommended value is 2. |
num.trees |
Number of trees grown in the forest. If NULL (recommended), the number of trees is automatically set using a stability stopping criterion. |
num.threads |
Number of threads used to grow the forest. Default is the number of available CPUs. |
replace |
Boolean. If true (default), sample with replacement. |
sample.fraction |
Fraction of observations to sample. Default is 1 for sampling with replacement and 0.632 for sampling without replacement. |
verbose |
Boolean. If true, information messages are printed. |
seed |
Random seed. Default is NULL, which generates the seed from R. Set to 0 to ignore the R seed. |
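For intuition about q and discrete.limit, the base-R sketch below mimics the preprocessing they describe: a variable with at most discrete.limit distinct values is kept as discrete, otherwise it is binned at its empirical quantiles, so that splits can only fall on a small set of cut points. The helper discretize_variable is hypothetical and not exported by sirus; the quantile type and the bin-index encoding are assumptions, not the package's exact implementation.

```r
## Hypothetical helper (not exported by sirus): a sketch of how the q and
## discrete.limit arguments shape one input variable before tree growing.
discretize_variable <- function(x, q = 10, discrete.limit = 10) {
  # At most discrete.limit distinct values: the variable is treated as
  # discrete and its values are kept unchanged.
  if (length(unique(x)) <= discrete.limit) {
    return(x)
  }
  # Otherwise the variable is continuous: the q - 1 inner empirical
  # quantiles are the only candidate split points (R's default quantile
  # type is an assumption here).
  cuts <- unique(quantile(x, probs = seq_len(q - 1) / q, names = FALSE))
  # Replace each value by the index of its quantile bin (0 to q - 1).
  findInterval(x, cuts)
}

## Sepal.Length has more than 10 distinct values, so it is binned.
table(discretize_variable(iris$Sepal.Length, q = 10, discrete.limit = 10))
```

Restricting splits to a few quantiles is what lets the same rules reappear across trees, which the rule-selection threshold p0 relies on.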
Details
For a robust estimation of p0, it is recommended to run multiple cross-validations (typically ncv = 10).
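The nfold and ncv arguments combine as repeated k-fold cross-validation: each of the ncv repetitions reshuffles the observations into nfold folds, and each fold is held out once. A minimal base-R sketch of the fold assignment, assuming balanced random folds (the helper make_repeated_folds is hypothetical and not part of sirus, which may assign folds differently):

```r
## Hypothetical sketch, not the sirus internals: fold assignments for ncv
## repetitions of nfold-fold cross-validation.
make_repeated_folds <- function(n, nfold = 10, ncv = 10, seed = NULL) {
  if (!is.null(seed)) set.seed(seed)
  # One column per repetition; entry i is the fold of observation i.
  # Each repetition randomly permutes a balanced assignment of 1..nfold.
  sapply(seq_len(ncv), function(r) sample(rep_len(seq_len(nfold), n)))
}

folds <- make_repeated_folds(n = nrow(iris), nfold = 3, ncv = 2, seed = 1)
dim(folds)          # 150 2
table(folds[, 1])   # 50 observations in each of the 3 folds
```

Averaging over several reshuffled repetitions reduces the dependence of the estimated p0 on one particular split of the data.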
Two optimal values of p0 are provided: p0.pred (Benard et al. 2021a) and p0.stab (Benard et al. 2021b). p0.pred minimizes the error, while p0.stab finds a tradeoff between error and stability.
Error is 1-AUC for classification and the unexplained variance for regression.
Stability is the average proportion of rules shared by two SIRUS models fit on two distinct folds of the cross-validation.
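The two criteria can be written down in a few lines of base R. The sketch below is illustrative only: the helper names are hypothetical, AUC is computed with the Mann-Whitney rank formula, and stability uses the Dice-Sorensen index 2|A ∩ B| / (|A| + |B|) as our reading of "proportion of rules shared", with rules encoded as character strings purely for the example.

```r
## Hypothetical helpers illustrating the two criteria; not sirus functions.

# Regression error: proportion of the response variance left unexplained.
unexplained_variance <- function(y, pred) {
  mean((y - pred)^2) / mean((y - mean(y))^2)
}

# Classification error: 1 - AUC, with AUC from the Mann-Whitney rank
# statistic (y coded 0/1, score = predicted probability of class 1).
one_minus_auc <- function(y, score) {
  r <- rank(score)   # average ranks handle tied scores
  n1 <- sum(y == 1)
  n0 <- sum(y == 0)
  auc <- (sum(r[y == 1]) - n1 * (n1 + 1) / 2) / (n1 * n0)
  1 - auc
}

# Stability between two rule sets: Dice-Sorensen index, i.e. the number of
# shared rules divided by the average rule-set size.
rule_overlap <- function(rules.a, rules.b) {
  length(intersect(rules.a, rules.b)) /
    mean(c(length(rules.a), length(rules.b)))
}

one_minus_auc(c(0, 0, 1, 1), c(0.1, 0.4, 0.35, 0.8))   # 0.25
rule_overlap(c("a", "b"), c("b", "c"))                  # 0.5
```

In the cross-validation, such a stability value would be averaged over pairs of models fit on distinct folds, and the error averaged over held-out folds.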
Value
Optimal values of p0, returned as a list with the elements
p0.pred |
Optimal p0 value, minimizing the cross-validation error (Benard et al. 2021a). |
p0.stab |
Optimal p0 value, trading off error and stability (Benard et al. 2021b). |
error.grid.p0 |
Table with the full cross-validation results for a fine grid of p0 values. |
type |
'reg' for regression, 'classif' for classification. |
References
Benard, C., Biau, G., Da Veiga, S. & Scornet, E. (2021a). SIRUS: Stable and Interpretable RUle Set for Classification. Electronic Journal of Statistics, 15:427-505. doi:10.1214/20-EJS1792.
Benard, C., Biau, G., Da Veiga, S. & Scornet, E. (2021b). Interpretable Random Forests via Rule Extraction. Proceedings of The 24th International Conference on Artificial Intelligence and Statistics, PMLR 130:937-945. http://proceedings.mlr.press/v130/benard21a.
Examples
## load SIRUS
require(sirus)
## prepare data
data <- iris
## binary response: 1 for setosa, 0 for the other species
y <- rep(0, nrow(data))
y[data$Species == 'setosa'] <- 1
data$Species <- NULL
## run cv
cv.grid <- sirus.cv(data, y, nfold = 3, ncv = 2, num.trees = 100)