explain_survival {survex}    R Documentation

A model-agnostic explainer for survival models

Description

Black-box models have vastly different structures. explain_survival() returns an explainer object that can be further processed to create prediction explanations and their visualizations. This function is used to manually create explainers for models not covered by the survex package. For selected models, the required information can be extracted automatically. To do this, call the explain() function instead. This works for survival models from the mlr3proba, censored, randomForestSRC, ranger and survival packages, and for any other model with a pec::predictSurvProb() method.

Usage

explain_survival(
  model,
  data = NULL,
  y = NULL,
  predict_function = NULL,
  predict_function_target_column = NULL,
  residual_function = NULL,
  weights = NULL,
  ...,
  label = NULL,
  verbose = TRUE,
  colorize = !isTRUE(getOption("knitr.in.progress")),
  model_info = NULL,
  type = NULL,
  times = NULL,
  times_generation = "survival_quantiles",
  predict_survival_function = NULL,
  predict_cumulative_hazard_function = NULL
)

explain(
  model,
  data = NULL,
  y = NULL,
  predict_function = NULL,
  predict_function_target_column = NULL,
  residual_function = NULL,
  weights = NULL,
  ...,
  label = NULL,
  verbose = TRUE,
  colorize = !isTRUE(getOption("knitr.in.progress")),
  model_info = NULL,
  type = NULL
)

## Default S3 method:
explain(
  model,
  data = NULL,
  y = NULL,
  predict_function = NULL,
  predict_function_target_column = NULL,
  residual_function = NULL,
  weights = NULL,
  ...,
  label = NULL,
  verbose = TRUE,
  colorize = !isTRUE(getOption("knitr.in.progress")),
  model_info = NULL,
  type = NULL
)

Arguments

model

object - a survival model to be explained

data

data.frame - data used to calculate the explanations. If not provided, it is extracted from the model if possible. It should not contain the target columns. NOTE: if the target variables are present in data, some functionality breaks.
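Selecting the explanatory columns by name, rather than by position, makes it harder to accidentally keep a target column in data. A minimal sketch, assuming a hypothetical toy data frame df whose target columns are named time and status (the values are arbitrary and only illustrate the pattern):

```r
# Hypothetical toy data; in practice this would be the data the model was fitted on
df <- data.frame(
  time   = c(10, 20, 30, 40),
  status = c(1, 0, 1, 1),
  age    = c(50, 61, 47, 70),
  karno  = c(80, 60, 90, 40)
)

# Keep every column except the target columns, selected by name
target_cols <- c("time", "status")
explain_data <- df[, setdiff(names(df), target_cols), drop = FALSE]

# The matching outcome is built separately (requires the survival package):
# y <- survival::Surv(df$time, df$status)
```

explain_data can then be passed as data and the Surv object as y.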

y

survival::Surv object containing event/censoring times and statuses corresponding to data

predict_function

function taking 2 arguments, model and newdata, and returning a single number (a risk score) for each observation. Observations with a higher score are expected to experience the event sooner.
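To illustrate this contract, the sketch below wraps a hypothetical model object, a plain list holding proportional-hazards-style coefficients beta, and returns one risk score exp(x'beta) per row of newdata. The object structure and the names toy_model and toy_risk are assumptions made for the example, not part of survex.

```r
# Hypothetical model: named coefficients of a proportional-hazards-style score
toy_model <- list(beta = c(age = 0.02, karno = -0.03))

# predict_function contract: (model, newdata) -> one numeric risk score per row
toy_risk <- function(model, newdata) {
  # use only the covariates the model knows about, in the coefficient order
  x <- as.matrix(newdata[, names(model$beta), drop = FALSE])
  as.vector(exp(x %*% model$beta))
}

newdata <- data.frame(age = c(60, 70), karno = c(90, 40))
scores <- toy_risk(toy_model, newdata)
```

A larger linear predictor means a larger hazard, so the second observation gets the higher score and is expected to experience the event sooner, matching the convention described above.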

predict_function_target_column

unused, left for compatibility with DALEX

residual_function

unused, left for compatibility with DALEX

weights

unused, left for compatibility with DALEX

...

additional arguments, passed to DALEX::explain()

label

character - the name of the model, used to distinguish models in visualizations with multiple explainers. By default it is extracted from the 'class' attribute of the model if possible.

verbose

logical, if TRUE (default) then diagnostic messages will be printed

colorize

logical, if TRUE then WARNINGS, ERRORS and NOTES are colorized. This works only in the R console. By default it is FALSE while knitting and TRUE otherwise.

model_info

a named list (package, version, type) containing information about the model. If NULL, survex will try to determine this information on its own.

type

type of a model, by default "survival"

times

numeric, a vector of times at which the survival function and cumulative hazard function should be evaluated for calculations. If NULL, the times are generated as specified by times_generation.

times_generation

either "survival_quantiles", "uniform" or "quantiles". Sets how the vector of times is generated from the times provided in the y parameter. If "survival_quantiles", the vector contains the unique time points among 50 uniformly distributed survival quantiles based on the Kaplan-Meier estimator, plus an additional time point equal to the median survival time (if it can be determined). If "uniform", the vector contains 50 equally spaced time points between the minimum and maximum observed times. If "quantiles", the vector contains the unique time points among 50 time points between the 0th and 98th percentiles of the observed times. Ignored if times is not NULL.
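The "uniform" and "quantiles" strategies can be approximated in base R as below. This is an illustrative sketch, not survex's internal code: it assumes that stats::quantile() with its default type is an acceptable stand-in, and the hypothetical vector obs_times stands in for the times stored in y. The Kaplan-Meier based "survival_quantiles" strategy is not reproduced here.

```r
# Hypothetical observed event/censoring times (the time component of a Surv object)
obs_times <- c(3, 8, 8, 12, 20, 25, 31, 44, 52, 60)

# "uniform": 50 equally spaced points between the minimum and maximum observed time
uniform_times <- seq(min(obs_times), max(obs_times), length.out = 50)

# "quantiles": unique values among 50 points spanning the 0th to 98th
# percentiles of the observed times (tied observations can create duplicates)
quantile_times <- unique(quantile(obs_times,
                                  probs = seq(0, 0.98, length.out = 50),
                                  names = FALSE))
```

Because tied times produce repeated quantiles, the "quantiles" vector can contain fewer than 50 points, which is why the description above speaks of unique time points.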

predict_survival_function

function taking 3 arguments, model, newdata and times, and returning a matrix in which each row is the survival function of one observation from newdata, evaluated at times

predict_cumulative_hazard_function

function taking 3 arguments, model, newdata and times, and returning a matrix in which each row is the cumulative hazard function of one observation from newdata, evaluated at times
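To make the expected matrix layout concrete (one row per observation in newdata, one column per element of times), the sketch below uses a hypothetical exponential model stored as a plain list. The model structure and the names toy_surv and toy_chf are assumptions for illustration only; for this model S(t) = exp(-rate * t) and H(t) = rate * t.

```r
# Hypothetical exponential model: hazard rate_i = base_rate * exp(beta * x_i)
toy_exp_model <- list(base_rate = 0.01, beta = 0.5)

toy_rates <- function(model, newdata) model$base_rate * exp(model$beta * newdata$x)

# Survival function S_i(t) = exp(-rate_i * t); outer() puts observations in rows
toy_surv <- function(model, newdata, times) {
  exp(-outer(toy_rates(model, newdata), times))
}

# Cumulative hazard H_i(t) = rate_i * t, which equals -log(S_i(t))
toy_chf <- function(model, newdata, times) {
  outer(toy_rates(model, newdata), times)
}

new_obs <- data.frame(x = c(0, 1, 2))
eval_times <- c(10, 50, 100)
S <- toy_surv(toy_exp_model, new_obs, eval_times)
H <- toy_chf(toy_exp_model, new_obs, eval_times)
```

The identity H(t) = -log(S(t)) used here is the same relationship the chf_pred helper in the Examples relies on.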

Value

It is a list containing the following elements:

Examples


library(survival)
library(survex)

cph <- survival::coxph(survival::Surv(time, status) ~ .,
    data = veteran,
    model = TRUE, x = TRUE
)
cph_exp <- explain(cph)

rsf_ranger <- ranger::ranger(survival::Surv(time, status) ~ .,
    data = veteran,
    respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5
)
rsf_ranger_exp <- explain(rsf_ranger,
    data = veteran[, -c(3, 4)],
    y = Surv(veteran$time, veteran$status)
)

rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran)
rsf_src_exp <- explain(rsf_src)

library(censored, quietly = TRUE)

bt <- parsnip::boost_tree() %>%
    parsnip::set_engine("mboost") %>%
    parsnip::set_mode("censored regression") %>%
    generics::fit(survival::Surv(time, status) ~ ., data = veteran)
bt_exp <- explain(bt, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status))

###### explain_survival() ######

cph <- coxph(Surv(time, status) ~ ., data = veteran)

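# explanatory variables only (columns 3 and 4 of veteran are time and status)
veteran_data <- veteran[, -c(3, 4)]
veteran_y <- Surv(veteran$time, veteran$status)
# risk score: higher values mean the event is expected sooner
risk_pred <- function(model, newdata) predict(model, newdata, type = "risk")
# survival probabilities at the requested times, one row per observation
surv_pred <- function(model, newdata, times) pec::predictSurvProb(model, newdata, times)
# cumulative hazard derived from the survival function as H(t) = -log(S(t))
chf_pred <- function(model, newdata, times) -log(surv_pred(model, newdata, times))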

manual_cph_explainer <- explain_survival(
    model = cph,
    data = veteran_data,
    y = veteran_y,
    predict_function = risk_pred,
    predict_survival_function = surv_pred,
    predict_cumulative_hazard_function = chf_pred,
    label = "manual coxph"
)



[Package survex version 1.2.0 Index]