predictions {mlexperiments}    R Documentation

predictions

Description

Apply an R6 object of class "MLCrossValidation" to new data to compute predictions.

Usage

predictions(object, newdata, na.rm = FALSE, ncores = -1L, ...)

Arguments

object

An R6 object of class "MLCrossValidation" for which the predictions should be computed.

newdata

The new data for which predictions should be made using the model.

na.rm

A logical. Whether missing values should be removed before computing the mean and standard deviation of the predictions across the folds for each observation in newdata (default: FALSE).

ncores

An integer to specify the number of cores used for parallelization (default: -1L).

...

A list. Further arguments required to compute the predictions.

Value

The function returns a data.table of class "mlexPredictions" with one row for each observation in newdata. Its columns contain the predictions from each fold, along with the mean and standard deviation across all folds.
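
The per-observation summary described above can be sketched in base R. This is only an illustration of the idea, not the package's implementation: the helper summarise_folds(), the column names, and the numbers are all hypothetical, and the real return value is a data.table rather than a data.frame.

```r
# Hypothetical per-fold predictions for 4 observations and 3 folds.
# Column names and values are illustrative only.
fold_preds <- cbind(
  fold1 = c(0.20, 0.80, 0.55, NA),
  fold2 = c(0.25, 0.70, 0.45, 0.90),
  fold3 = c(0.30, 0.75, 0.50, 0.80)
)

# Hypothetical helper: append the row-wise mean and standard deviation
# across folds, forwarding na.rm to both summary functions.
summarise_folds <- function(preds, na.rm = FALSE) {
  data.frame(
    preds,
    mean = apply(preds, 1, mean, na.rm = na.rm),
    sd = apply(preds, 1, sd, na.rm = na.rm)
  )
}

res_keep <- summarise_folds(fold_preds, na.rm = FALSE)
res_drop <- summarise_folds(fold_preds, na.rm = TRUE)

# The fourth observation has a missing prediction in one fold:
# its summaries are NA unless na.rm = TRUE.
res_keep[4, c("mean", "sd")]
res_drop[4, c("mean", "sd")]
```

Under these assumptions, na.rm = FALSE propagates a missing fold prediction into the mean and standard deviation of that observation, while na.rm = TRUE summarises the remaining folds only.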

Examples

# simulate six numeric features and a binary target
dataset <- do.call(
  cbind,
  c(
    sapply(
      paste0("col", 1:6),
      function(x) rnorm(n = 500),
      USE.NAMES = TRUE,
      simplify = FALSE
    ),
    list(target = sample(0:1, 500, TRUE))
  )
)

fold_list <- splitTools::create_folds(
  y = dataset[, 7],
  k = 3,
  type = "stratified",
  seed = 123
)

glm_optimization <- mlexperiments::MLCrossValidation$new(
  learner = mlexperiments::LearnerGlm$new(),
  fold_list = fold_list,
  seed = 123
)

glm_optimization$learner_args <- list(family = binomial(link = "logit"))
glm_optimization$predict_args <- list(type = "response")
glm_optimization$performance_metric_args <- list(positive = "1")
glm_optimization$performance_metric <- mlexperiments::metric("auc")
glm_optimization$return_models <- TRUE

# set data
glm_optimization$set_data(
  x = data.matrix(dataset[, -7]),
  y = dataset[, 7]
)

cv_results <- glm_optimization$execute()

# predictions
preds <- mlexperiments::predictions(
  object = glm_optimization,
  newdata = data.matrix(dataset[, -7]),
  na.rm = FALSE,
  ncores = 2L,
  type = "response"
)
head(preds)


[Package mlexperiments version 0.0.3 Index]