cv_misvm {mildsvm}    R Documentation
Fit MI-SVM model to the data using cross-validation
Description
Cross-validation wrapper on the misvm() function to fit the MI-SVM model over a variety of specified cost parameters. The optimal cost parameter is the one whose cross-fit models achieve the highest AUC. See ?misvm for more details on the fitting function.
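This page does not spell out how the AUC of the cross-fit models is computed. For reference, the following minimal base-R sketch implements the standard rank-based (Mann-Whitney) AUC. `auc_rank()` is a hypothetical helper, not a mildsvm function, and the sketch makes no claim about whether cv_misvm scores at the bag or the instance level.

```r
# Hypothetical helper (not part of mildsvm): rank-based AUC, i.e. the
# probability that a randomly chosen positive scores above a randomly
# chosen negative, with ties counted as 1/2.
auc_rank <- function(scores, labels) {
  labels <- as.logical(labels)          # expects 0/1 or logical labels
  n_pos <- sum(labels)
  n_neg <- sum(!labels)
  # AUC is undefined unless both classes are present
  if (n_pos == 0 || n_neg == 0) return(NA_real_)
  r <- rank(scores)                     # average ranks handle tied scores
  (sum(r[labels]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

auc_rank(c(0.9, 0.8, 0.3, 0.1), c(1, 0, 1, 0))  # returns 0.75
```

Returning NA when a validation set holds only one class mirrors the need, noted under Value, to exclude such folds.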
Usage
## Default S3 method:
cv_misvm(
  x,
  y,
  bags,
  cost_seq,
  n_fold,
  fold_id,
  method = c("heuristic", "mip", "qp-heuristic"),
  weights = TRUE,
  control = list(kernel = "linear", sigma = 1,
                 nystrom_args = list(m = nrow(x), r = nrow(x), sampling = "random"),
                 max_step = 500, type = "C-classification", scale = TRUE,
                 verbose = FALSE, time_limit = 60, start = FALSE),
  ...
)
## S3 method for class 'formula'
cv_misvm(formula, data, cost_seq, n_fold, fold_id, ...)
## S3 method for class 'mi_df'
cv_misvm(x, ...)
Arguments
x: A data.frame, matrix, or similar object of covariates, where each row represents a sample.
y: A numeric, character, or factor vector of bag labels for each instance. Must satisfy length(y) == nrow(x).
bags: A vector specifying the bag to which each instance belongs. Can be a string, numeric, or factor.
cost_seq: A sequence of cost arguments passed to misvm().
n_fold: The number of folds (default 5). If this is specified, fold_id need not be specified.
fold_id: The fold id for each instance. Care must be taken to ensure that the ids respect the bag structure (all instances of a bag in the same fold) to avoid information leakage. If n_fold is specified, fold_id will be computed automatically.
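method: The algorithm to use in fitting (default "heuristic"); one of "heuristic", "mip", or "qp-heuristic". The "mip" method requires the gurobi package (see Examples).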
weights: A named vector, or TRUE (the default).
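control: A list of additional parameters passed to the method that control computation. Its components and their defaults are:
- kernel (default "linear")
- sigma (default 1)
- nystrom_args (default list(m = nrow(x), r = nrow(x), sampling = "random"))
- max_step (default 500)
- type (default "C-classification")
- scale (default TRUE)
- verbose (default FALSE)
- time_limit (default 60)
- start (default FALSE)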
...: Arguments passed to or from other methods.
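formula: A formula with specification mi(y, bags) ~ x, which uses mi() to create the bag-instance structure. This is an alternative to the x, y, and bags arguments and requires the data argument (see Examples).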
data: If formula is provided, a data.frame or similar object from which the formula elements are extracted.
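Because fold_id must respect the bag structure, one way to build such ids by hand is to assign folds to whole bags and then expand them to instances. The base-R sketch below does this. `make_bag_folds()` is a hypothetical helper, and it does not claim to reproduce how cv_misvm derives fold ids from n_fold.

```r
# Hypothetical helper (not part of mildsvm): assign folds at the bag level so
# that every instance of a bag lands in the same fold, then expand to instances.
make_bag_folds <- function(bags, n_fold) {
  bag_ids <- unique(bags)
  # deal the bags round-robin into n_fold folds, in a random order
  bag_fold <- rep_len(seq_len(n_fold), length(bag_ids))[sample.int(length(bag_ids))]
  names(bag_fold) <- as.character(bag_ids)
  # look up each instance's fold through its bag
  unname(bag_fold[as.character(bags)])
}

set.seed(1)
bags <- rep(c("a", "b", "c", "d", "e", "f"), times = c(3, 2, 4, 1, 2, 3))
fold_id <- make_bag_folds(bags, n_fold = 3)
# each bag maps to exactly one fold: all counts are 1
tapply(fold_id, bags, function(f) length(unique(f)))
```

The result has one entry per instance, which is the shape fold_id expects. Assigning whole bags keeps instances of one bag from appearing in both the training and validation sets.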
Value
An object of class cv_misvm. The object contains the following components:
- misvm_fit: A fit object of class misvm trained on the full data with the cross-validated choice of cost parameter. See misvm() for details.
- cost_seq: The input sequence of cost arguments.
- cost_aucs: Estimated AUC for the models trained for each cost_seq parameter. These are the average over the fold models for that cost, excluding any folds that don't have both levels of y in the validation set.
- best_cost: The optimal choice of cost parameter, chosen as the one with the maximum AUC. If there are ties, the smallest cost with the maximum AUC is picked.
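The rule behind cost_aucs and best_cost can be sketched in base R. The sketch assumes per-fold validation AUCs are already available as a matrix, with NA marking folds whose validation set lacks one of the two levels of y. `select_cost()` and the AUC values are hypothetical, for illustration only, and this is not mildsvm's internal code.

```r
# Hypothetical sketch (not mildsvm's code) of the documented selection rule.
# fold_auc: matrix with rows = costs (same order as cost_seq), cols = folds;
# NA marks a fold excluded because its validation set lacked a level of y.
# Assumes every cost has at least one usable fold.
select_cost <- function(cost_seq, fold_auc) {
  cost_aucs <- rowMeans(fold_auc, na.rm = TRUE)  # average over usable folds
  best <- max(cost_aucs)
  # ties (exact equality here) are broken by the smallest cost
  best_cost <- min(cost_seq[cost_aucs == best])
  list(cost_aucs = cost_aucs, best_cost = best_cost)
}

cost_seq <- 2^seq(-5, 7, length.out = 3)   # 0.03125, 2, 128
fold_auc <- rbind(c(0.50, 0.75, NA),
                  c(0.75, 1.00, NA),
                  c(1.00, 0.75, NA))
res <- select_cost(cost_seq, fold_auc)
res$best_cost  # 2: costs 2 and 128 tie at 0.875, and the smaller cost wins
```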
Methods (by class)
- default: Method for data.frame-like objects.
- formula: Method for passing a formula.
- mi_df: Method for mi_df objects, automatically handling bag names, labels, and all covariates.
Author(s)
Sean Kent, Yifei Liu
See Also
misvm() for fitting without cross-validation.
Examples
library(mildsvm)
set.seed(8)
mil_data <- generate_mild_df(nbag = 20,
                             positive_prob = 0.15,
                             dist = rep("mvnormal", 3),
                             mean = list(rep(1, 10), rep(2, 10)),
                             sd_of_mean = rep(0.1, 3))
df <- build_instance_feature(mil_data, seq(0.05, 0.95, length.out = 10))
cost_seq <- 2^seq(-5, 7, length.out = 3)

# Heuristic method
mdl1 <- cv_misvm(x = df[, 4:123], y = df$bag_label,
                 bags = df$bag_name, cost_seq = cost_seq,
                 n_fold = 3, method = "heuristic")

# Formula interface
mdl2 <- cv_misvm(mi(bag_label, bag_name) ~ X1_mean + X2_mean + X3_mean, data = df,
                 cost_seq = cost_seq, n_fold = 3)

if (require(gurobi)) {
  # solve using the MIP method (requires the gurobi package)
  mdl3 <- cv_misvm(x = df[, 4:123], y = df$bag_label,
                   bags = df$bag_name, cost_seq = cost_seq,
                   n_fold = 3, method = "mip")
}

predict(mdl1, new_data = df, type = "raw", layer = "bag")

# summarize predictions at the bag layer
suppressWarnings(library(dplyr))
df %>%
  bind_cols(predict(mdl2, df, type = "class")) %>%
  bind_cols(predict(mdl2, df, type = "raw")) %>%
  distinct(bag_name, bag_label, .pred_class, .pred)