LearnerSurvXgboostAft {mlsurvlrnrs}    R Documentation
R6 Class to construct an xgboost survival learner for accelerated failure time models
Description
The LearnerSurvXgboostAft class is the interface to accelerated failure time (AFT) models with the xgboost R package for use with the mlexperiments package.
Details
Optimization metric: needs to be specified with the learner parameter eval_metric.
Can be used with mlexperiments::MLTuneParameters, mlexperiments::MLCrossValidation and mlexperiments::MLNestedCV.
Also see the official xgboost documentation on AFT models: https://xgboost.readthedocs.io/en/stable/tutorials/aft_survival_analysis.html
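For right-censored data, the xgboost AFT tutorial linked above represents each label as an interval: an observed event at time t becomes [t, t], and an observation right-censored at time t becomes [t, +Inf). The sketch below shows that labeling convention in base R. The helper aft_bounds() is hypothetical and not part of mlsurvlrnrs; the learner is assumed to derive equivalent bounds internally from the survival::Surv() outcome passed to set_data().

```r
# Hypothetical helper (not part of mlsurvlrnrs): map right-censored
# (time, status) pairs to the interval labels used by xgboost's AFT objective.
aft_bounds <- function(time, status) {
  stopifnot(length(time) == length(status))
  data.frame(
    # the lower bound is always the observed follow-up time
    label_lower_bound = time,
    # event observed (status == 1): interval [time, time];
    # right-censored (status == 0): interval [time, Inf)
    label_upper_bound = ifelse(status == 1, time, Inf)
  )
}

bounds <- aft_bounds(time = c(5, 10, 3), status = c(1, 0, 1))
print(bounds)
```

Only the censored second observation gets an infinite upper bound; the two observed events collapse to single points.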
Super classes
mlexperiments::MLLearnerBase -> mllrnrs::LearnerXgboost -> LearnerSurvXgboostAft
Methods
Public methods
LearnerSurvXgboostAft$new()
LearnerSurvXgboostAft$clone()
Inherited methods
Method new()
Create a new LearnerSurvXgboostAft object.
Usage
LearnerSurvXgboostAft$new(metric_optimization_higher_better)
Arguments
metric_optimization_higher_better
A logical. Defines the direction of the optimization metric used throughout the hyperparameter optimization: TRUE if higher values of the metric are better, FALSE if lower values are better.
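As a rough illustration of what this flag controls, the sketch below picks the best of several candidate scores depending on the direction. The function select_best() and the scores are hypothetical, and the selection logic inside mlexperiments may differ. For the aft-nloglik metric used in the example below, lower values are better, hence metric_optimization_higher_better = FALSE.

```r
# Hypothetical sketch: choose the best candidate according to the
# direction given by `metric_optimization_higher_better`.
select_best <- function(scores, metric_optimization_higher_better) {
  if (metric_optimization_higher_better) {
    which.max(scores)  # e.g. a concordance index: higher is better
  } else {
    which.min(scores)  # e.g. a negative log-likelihood: lower is better
  }
}

# made-up validation scores for three hyperparameter settings
nloglik_scores <- c(3.1, 2.7, 2.9)
select_best(nloglik_scores, metric_optimization_higher_better = FALSE)
```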
Returns
A new LearnerSurvXgboostAft R6 object.
Examples
LearnerSurvXgboostAft$new(metric_optimization_higher_better = FALSE)
Method clone()
The objects of this class are cloneable with this method.
Usage
LearnerSurvXgboostAft$clone(deep = FALSE)
Arguments
deep
Whether to make a deep clone.
See Also
xgboost::xgb.train(), xgboost::xgb.cv()
Examples
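# execution time >2.5 sec
# survival analysis
library(mlsurvlrnrs)  # provides LearnerSurvXgboostAft and c_index

# complete cases of the colon cancer data; etype == 2 marks death as the event
dataset <- survival::colon |>
  data.table::as.data.table() |>
  na.omit()
dataset <- dataset[get("etype") == 2, ]

seed <- 123
surv_cols <- c("status", "time", "rx")
# all columns from the 3rd ("rx") to the second-to-last ("time")
feature_cols <- colnames(dataset)[3:(ncol(dataset) - 1)]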
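# hyperparameter grid: 3 x 3 x 2 x 2 x 2 = 72 combinations; keep strings as
# character so that "objective" and "eval_metric" are not turned into factors
param_list_xgboost <- expand.grid(
  objective = "survival:aft",
  eval_metric = "aft-nloglik",
  subsample = seq(0.6, 1, .2),
  colsample_bytree = seq(0.6, 1, .2),
  min_child_weight = seq(1, 5, 4),
  learning_rate = c(0.1, 0.2),
  max_depth = seq(1, 5, 4),
  stringsAsFactors = FALSE
)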
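ncores <- 2L
# strata derived from status, time and treatment arm via k-means clustering,
# used below to build stratified folds
split_vector <- splitTools::multi_strata(
  df = dataset[, .SD, .SDcols = surv_cols],
  strategy = "kmeans",
  k = 4
)
# design matrix without intercept; "status" and "time" are excluded from the
# features, and factor columns such as "rx" are dummy-encoded
train_x <- model.matrix(
  ~ -1 + .,
  dataset[, .SD, .SDcols = setdiff(feature_cols, surv_cols[1:2])]
)
# right-censored survival outcome
train_y <- survival::Surv(
  event = (dataset[, get("status")] |>
    as.character() |>
    as.integer()),
  time = dataset[, get("time")],
  type = "right"
)
# three stratified cross-validation folds
fold_list <- splitTools::create_folds(
  y = split_vector,
  k = 3,
  type = "stratified",
  seed = seed
)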
# 3-fold cross-validation of the AFT learner
surv_xgboost_aft_optimizer <- mlexperiments::MLCrossValidation$new(
  learner = LearnerSurvXgboostAft$new(
    metric_optimization_higher_better = FALSE
  ),
  fold_list = fold_list,
  ncores = ncores,
  seed = seed
)
# use the first row of the grid plus a fixed number of boosting rounds
surv_xgboost_aft_optimizer$learner_args <- c(
  as.list(
    data.table::data.table(param_list_xgboost[1, ], stringsAsFactors = FALSE)
  ),
  nrounds = 45L
)
# evaluate each fold with the concordance index
surv_xgboost_aft_optimizer$performance_metric <- c_index

# set data
surv_xgboost_aft_optimizer$set_data(
  x = train_x,
  y = train_y
)

surv_xgboost_aft_optimizer$execute()
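The example assigns c_index as the performance metric. To make that quantity concrete, the sketch below computes Harrell's concordance index in base R, assuming the predictions are survival times (a larger prediction means longer expected survival, which is the scale xgboost's survival:aft objective predicts on). It is an illustrative implementation under that assumption, not the code behind c_index, which may differ in estimator and tie handling; harrell_c() is a hypothetical name.

```r
# Illustrative sketch of Harrell's C-index for right-censored data.
# Assumes `pred` is a predicted survival time: larger = longer survival.
harrell_c <- function(obs_time, status, pred) {
  concordant <- 0
  comparable <- 0
  for (i in seq_along(obs_time)) {
    # only an observed event can anchor a comparable pair
    if (status[i] != 1) next
    for (j in seq_along(obs_time)) {
      # j was still event-free after i's event, so the pair is comparable
      if (obs_time[j] > obs_time[i]) {
        comparable <- comparable + 1
        if (pred[i] < pred[j]) {
          concordant <- concordant + 1    # shorter time, shorter prediction
        } else if (pred[i] == pred[j]) {
          concordant <- concordant + 0.5  # tied predictions count half
        }
      }
    }
  }
  concordant / comparable
}

obs_time <- c(2, 4, 6, 8)
status <- c(1, 1, 0, 1)
harrell_c(obs_time, status, pred = c(1, 2, 3, 4))  # 1: perfectly ordered
harrell_c(obs_time, status, pred = c(4, 3, 2, 1))  # 0: fully reversed
```

A value of 0.5 corresponds to predictions that carry no ranking information, so higher is better for this metric, unlike the aft-nloglik optimization metric.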