LearnerGlmnet {mllrnrs} | R Documentation
R6 Class to construct a Glmnet learner
Description
The LearnerGlmnet class is the interface to the glmnet R package for use with the mlexperiments package.
Details
Optimization metric: Can be used with
Super class
mlexperiments::MLLearnerBase -> LearnerGlmnet
Methods
Public methods
Inherited methods
Method new()
Create a new LearnerGlmnet object.
Usage
LearnerGlmnet$new(metric_optimization_higher_better)
Arguments
metric_optimization_higher_better
A logical. Defines the direction of the optimization metric used throughout the hyperparameter optimization: TRUE if higher metric values are better, FALSE if lower values are better.
Returns
A new LearnerGlmnet R6 object.
Examples
LearnerGlmnet$new(metric_optimization_higher_better = FALSE)
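The role of the direction flag can be illustrated with a minimal base R sketch. The helper `select_best()` and the error values below are hypothetical and only illustrate the idea. They are not part of mllrnrs or mlexperiments and do not show how those packages implement the selection internally.

```r
# Hypothetical helper (not part of mllrnrs): return the index of the best
# score, given the direction of the optimization metric.
select_best <- function(scores, higher_better) {
  stopifnot(is.logical(higher_better), length(higher_better) == 1L)
  if (higher_better) which.max(scores) else which.min(scores)
}

# Made-up metric values for three candidate hyperparameter settings
errors <- c(setting_a = 0.31, setting_b = 0.24, setting_c = 0.27)

# For an error-like metric, lower is better (higher_better = FALSE)
best_if_lower <- names(select_best(errors, higher_better = FALSE))

# The same values read as a score-like metric (higher_better = TRUE)
best_if_higher <- names(select_best(errors, higher_better = TRUE))

print(best_if_lower)
print(best_if_higher)
```

With an error-like metric such as a misclassification rate, FALSE would be the natural choice, which matches the value used in the example above.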
Method clone()
The objects of this class are cloneable with this method.
Usage
LearnerGlmnet$clone(deep = FALSE)
Arguments
deep
Whether to make a deep clone.
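The difference between a shallow and a deep clone can be sketched with plain R6 classes. The classes `Inner` and `Outer` are hypothetical stand-ins, not part of mllrnrs; the sketch only assumes that the R6 package is installed and shows general R6 cloning behaviour.

```r
library(R6)

# Hypothetical classes for illustration only
Inner <- R6Class("Inner", public = list(value = 1))
Outer <- R6Class(
  "Outer",
  public = list(
    inner = NULL,
    initialize = function() {
      # Field holding another R6 object (reference semantics)
      self$inner <- Inner$new()
    }
  )
)

original <- Outer$new()
shallow <- original$clone()           # nested R6 object is shared
deep <- original$clone(deep = TRUE)   # nested R6 object is copied

# Modify the nested object through the original
original$inner$value <- 2

print(shallow$inner$value)  # follows the original
print(deep$inner$value)     # keeps its own copy
```

A deep clone is the safer choice when the copy should not be affected by later changes to R6 objects stored in the fields of the original.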
See Also
glmnet::glmnet(), glmnet::cv.glmnet()
Examples
# binary classification
library(mlbench)
data("PimaIndiansDiabetes2")
dataset <- PimaIndiansDiabetes2 |>
  data.table::as.data.table() |>
  na.omit()

seed <- 123
feature_cols <- colnames(dataset)[1:8]

train_x <- model.matrix(
  ~ -1 + .,
  dataset[, .SD, .SDcols = feature_cols]
)
train_y <- as.integer(dataset[, get("diabetes")]) - 1L

fold_list <- splitTools::create_folds(
  y = train_y,
  k = 3,
  type = "stratified",
  seed = seed
)

glmnet_cv <- mlexperiments::MLCrossValidation$new(
  learner = mllrnrs::LearnerGlmnet$new(
    metric_optimization_higher_better = FALSE
  ),
  fold_list = fold_list,
  ncores = 2,
  seed = 123
)
glmnet_cv$learner_args <- list(
  alpha = 1,
  lambda = 0.1,
  family = "binomial",
  type.measure = "class",
  standardize = TRUE
)
glmnet_cv$predict_args <- list(type = "response")
glmnet_cv$performance_metric_args <- list(positive = "1")
glmnet_cv$performance_metric <- mlexperiments::metric("auc")

# set data
glmnet_cv$set_data(
  x = train_x,
  y = train_y
)

glmnet_cv$execute()
## ------------------------------------------------
## Method `LearnerGlmnet$new`
## ------------------------------------------------
LearnerGlmnet$new(metric_optimization_higher_better = FALSE)
[Package mllrnrs version 0.0.4 Index]