explain_survival {survex} | R Documentation |
A model-agnostic explainer for survival models
Description
Black-box models have vastly different structures. explain_survival()
returns an explainer object that can be further processed to create
prediction explanations and their visualizations. Use this function to manually
create explainers for models not covered by the survex package. For selected
models, the required information can be extracted automatically: call the
explain() function for survival models from the mlr3proba, censored,
randomForestSRC, ranger, and survival packages, and for any other model
with a pec::predictSurvProb() method.
Usage
explain_survival(
  model,
  data = NULL,
  y = NULL,
  predict_function = NULL,
  predict_function_target_column = NULL,
  residual_function = NULL,
  weights = NULL,
  ...,
  label = NULL,
  verbose = TRUE,
  colorize = !isTRUE(getOption("knitr.in.progress")),
  model_info = NULL,
  type = NULL,
  times = NULL,
  times_generation = "survival_quantiles",
  predict_survival_function = NULL,
  predict_cumulative_hazard_function = NULL
)

explain(
  model,
  data = NULL,
  y = NULL,
  predict_function = NULL,
  predict_function_target_column = NULL,
  residual_function = NULL,
  weights = NULL,
  ...,
  label = NULL,
  verbose = TRUE,
  colorize = !isTRUE(getOption("knitr.in.progress")),
  model_info = NULL,
  type = NULL
)

## Default S3 method:
explain(
  model,
  data = NULL,
  y = NULL,
  predict_function = NULL,
  predict_function_target_column = NULL,
  residual_function = NULL,
  weights = NULL,
  ...,
  label = NULL,
  verbose = TRUE,
  colorize = !isTRUE(getOption("knitr.in.progress")),
  model_info = NULL,
  type = NULL
)
Arguments
model |
object - a survival model to be explained |
data |
data.frame - data which will be used to calculate the explanations. If not provided, then it will be extracted from the model if possible. It should not contain the target columns. NOTE: If the target variable is present in the |
y |
|
predict_function |
function taking 2 arguments - |
predict_function_target_column |
unused, left for compatibility with DALEX |
residual_function |
unused, left for compatibility with DALEX |
weights |
unused, left for compatibility with DALEX |
... |
additional arguments, passed to |
label |
character - the name of the model. Used to distinguish models in visualizations with multiple explainers. By default it is extracted from the 'class' attribute of the model if possible. |
verbose |
logical, if TRUE (default) then diagnostic messages will be printed |
colorize |
logical, if TRUE (default) then WARNINGS, ERRORS and NOTES are colorized. Will work only in the R console. By default it is FALSE while knitting and TRUE otherwise. |
model_info |
a named list ( |
type |
type of a model, by default |
times |
numeric, a vector of times at which the survival function and cumulative hazard function should be evaluated for calculations |
times_generation |
either |
predict_survival_function |
function taking 3 arguments |
predict_cumulative_hazard_function |
function taking 3 arguments |
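The 3-argument prediction functions and the times vector can be illustrated without pec. The sketch below is only one possible approach: it builds a survival prediction function from survival::survfit() and picks evaluation times as quantiles of the observed event times. The helper name surv_pred_survfit and the choice of quantiles are hypothetical, and the output layout (one row per observation, one column per time) is an assumption, since the argument descriptions above do not state it.

```r
library(survival)

cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE)

# Hypothetical choice of evaluation times: quantiles of observed event times
# (an illustration only, not the package's times_generation algorithm)
event_times <- veteran$time[veteran$status == 1]
times <- unique(as.numeric(quantile(event_times, probs = seq(0.1, 0.9, by = 0.2))))

# Hypothetical helper with the (model, newdata, times) signature.
# Assumes `times` is sorted in increasing order.
surv_pred_survfit <- function(model, newdata, times) {
  fit <- survfit(model, newdata = newdata)
  # summary() returns times in rows and observations in columns;
  # extend = TRUE keeps times beyond the last observed time
  s <- summary(fit, times = times, extend = TRUE)$surv
  # force a matrix even when newdata has a single row, then transpose
  s <- matrix(s, nrow = length(times), ncol = nrow(newdata))
  t(s)
}

newdata <- veteran[1:3, !(names(veteran) %in% c("time", "status"))]
sp <- surv_pred_survfit(cph, newdata, times)
dim(sp)  # observations in rows, times in columns
```

A function of this shape could then be supplied as predict_survival_function, together with times, in a call to explain_survival().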
Value
A list containing the following elements:
- model - the explained model.
- data - the dataset used for training.
- y - response for observations from data.
- residuals - calculated residuals.
- predict_function - function that may be used for model predictions; shall return a single numerical value for each observation.
- residual_function - function that returns residuals; shall return a single numerical value for each observation.
- class - class/classes of a model.
- label - label of explainer.
- model_info - named list containing basic information about the model, such as the package, package version and type.
- times - a vector of times used by default for evaluating the survival function and cumulative hazard function.
- predict_survival_function - function used for model predictions in the form of a survival function.
- predict_cumulative_hazard_function - function used for model predictions in the form of a cumulative hazard function.
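As a brief illustration of the elements listed above, the sketch below builds an explainer for a Cox model with explain() and reads a few of its fields. It assumes the survex and survival packages are installed. That the stored survival prediction function returns one row per observation is an assumption about the built-in coxph support, not something stated in this page.

```r
library(survival)
library(survex)

cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE)
cph_exp <- explain(cph, verbose = FALSE)

# The explainer is a list, so its elements can be accessed with `$`
names(cph_exp)
cph_exp$label
head(cph_exp$times)

# Call the stored survival prediction function on two observations,
# evaluated at the explainer's default times
surv_mat <- cph_exp$predict_survival_function(
  cph_exp$model, cph_exp$data[1:2, ], cph_exp$times
)
dim(surv_mat)
```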
Examples
library(survival)
library(survex)

# Cox model: data and response can be extracted automatically
cph <- survival::coxph(survival::Surv(time, status) ~ .,
  data = veteran,
  model = TRUE, x = TRUE
)
cph_exp <- explain(cph)

# ranger: data (without the target columns) and y are supplied explicitly
rsf_ranger <- ranger::ranger(survival::Surv(time, status) ~ .,
  data = veteran,
  respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5
)
rsf_ranger_exp <- explain(rsf_ranger,
  data = veteran[, -c(3, 4)],
  y = Surv(veteran$time, veteran$status)
)

# randomForestSRC
rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran)
rsf_src_exp <- explain(rsf_src)

# censored (parsnip) model
library(censored, quietly = TRUE)
bt <- parsnip::boost_tree() %>%
  parsnip::set_engine("mboost") %>%
  parsnip::set_mode("censored regression") %>%
  generics::fit(survival::Surv(time, status) ~ ., data = veteran)
bt_exp <- explain(bt, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status))

###### explain_survival() ######
cph <- coxph(Surv(time, status) ~ ., data = veteran)
veteran_data <- veteran[, -c(3, 4)]  # drop the time and status columns
veteran_y <- Surv(veteran$time, veteran$status)

# User-supplied prediction functions for risk, survival and cumulative hazard
risk_pred <- function(model, newdata) predict(model, newdata, type = "risk")
surv_pred <- function(model, newdata, times) pec::predictSurvProb(model, newdata, times)
chf_pred <- function(model, newdata, times) -log(surv_pred(model, newdata, times))

manual_cph_explainer <- explain_survival(
  model = cph,
  data = veteran_data,
  y = veteran_y,
  predict_function = risk_pred,
  predict_survival_function = surv_pred,
  predict_cumulative_hazard_function = chf_pred,
  label = "manual coxph"
)
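The chf_pred function above relies on the identity H(t) = -log(S(t)) between the cumulative hazard and the survival function. A minimal sketch checking this numerically is given below. It uses only the survival package and assumes the default settings of survfit() for a coxph fit, under which the survival curve is computed as the exponential of the negative cumulative hazard.

```r
library(survival)

cph <- coxph(Surv(time, status) ~ ., data = veteran)

# Predicted curves for the first two observations
fit <- survfit(cph, newdata = veteran[1:2, ])

# Cumulative hazard recovered from the survival curve
chf_from_surv <- -log(fit$surv)

# Largest absolute difference from the cumulative hazard stored in the fit
max_abs_diff <- max(abs(as.numeric(chf_from_surv) - as.numeric(fit$cumhaz)))
max_abs_diff
```

The difference should be at the level of floating-point error. For models whose survival curve is not derived from the cumulative hazard in this way, the two quantities may differ slightly.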