predict_profile {survex}    R Documentation

Instance Level Profile as Ceteris Paribus for Survival Models

Description

This function calculates Ceteris Paribus profiles for a specific observation, optionally taking the time dimension into account.
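A Ceteris Paribus profile answers the question "how would this observation's prediction change if only one variable changed and everything else stayed fixed?" The base-R sketch below illustrates that idea with the survival package alone and is not survex's implementation. It assumes a small Cox model on karno and age and a hypothetical 5-point grid for karno. Because the model output is a survival function, each grid value yields a whole curve, which is the time dimension mentioned above.

```r
library(survival)

# Illustrative Cox model; the covariate choice (karno + age) is an assumption
cph_small <- coxph(Surv(time, status) ~ karno + age, data = veteran)

# The observation to explain; only 'karno' will be changed
obs <- veteran[2, ]

# Hypothetical grid: 5 evenly spaced karno values over the observed range
grid <- seq(min(veteran$karno), max(veteran$karno), length.out = 5)

# Copy the observation once per grid value and overwrite 'karno' only
cp_data <- obs[rep(1, length(grid)), ]
cp_data$karno <- grid

# One predicted survival curve per modified copy of the observation
sf <- survfit(cph_small, newdata = cp_data)

# Survival probabilities at a few time points:
# rows = time points, columns = grid values of karno
times <- c(50, 100, 200)
cp_profile <- summary(sf, times = times)$surv
cp_profile
```

Reading a row of cp_profile shows how survival at one time point changes with karno; reading a column shows the survival curve for one grid value.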

Usage

predict_profile(
  explainer,
  new_observation,
  variables = NULL,
  categorical_variables = NULL,
  ...,
  type = "ceteris_paribus",
  output_type = "survival",
  variable_splits_type = "uniform",
  center = FALSE
)

## S3 method for class 'surv_explainer'
predict_profile(
  explainer,
  new_observation,
  variables = NULL,
  categorical_variables = NULL,
  ...,
  type = "ceteris_paribus",
  output_type = "survival",
  variable_splits_type = "uniform",
  center = FALSE
)

Arguments

explainer

an explainer object: a model preprocessed by the explain() function

new_observation

a new observation for which the prediction needs to be explained

variables

a character vector containing names of variables to be explained

categorical_variables

a character vector of names of additional variables which should be treated as categorical (factors are automatically treated as categorical variables). If it contains variable names not present in the variables argument, they will be added at the end.
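The Examples section passes trt in categorical_variables because the veteran data store it as a numeric code (1 = standard, 2 = test) rather than as a factor. A minimal base-R sketch of why this matters, assuming for illustration that a numerical variable gets an evenly spaced grid over its observed range (the grid size of 5 is arbitrary and not taken from survex):

```r
library(survival)

# 'celltype' is a factor, so it is treated as categorical automatically
is.factor(veteran$celltype)

# 'trt' is stored as a number, not a factor
is.factor(veteran$trt)

# Treated as numerical, an evenly spaced grid invents codes that do not exist
numeric_grid_trt <- seq(min(veteran$trt), max(veteran$trt), length.out = 5)

# Treated as categorical, only the observed values are used
categorical_grid_trt <- sort(unique(veteran$trt))

numeric_grid_trt
categorical_grid_trt
```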

...

additional parameters passed to DALEX::predict_profile if output_type == "risk"

type

character, only "ceteris_paribus" is implemented

output_type

either "survival", "chf" or "risk": the type of survival model output that should be considered for explanations. If "survival", the explanations are based on the survival function. If "chf", the explanations are based on the cumulative hazard function. Otherwise, the scalar risk predictions are used by the DALEX::predict_profile function.
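To make the three output types concrete, the sketch below computes each of them for one observation directly with the survival package; it is an illustration, not how survex obtains them internally, and the karno + age model is an assumption. "survival" and "chf" are functions of time, while "risk" is a single number, which is why only the first two produce time-dependent profiles. For a Cox fit, survfit() by default computes the survival curve as exp(-H(t)), so the two functional outputs carry the same information; other models may estimate them separately.

```r
library(survival)

# Illustrative Cox model; the covariates are an assumption
cph_small <- coxph(Surv(time, status) ~ karno + age, data = veteran)
obs <- veteran[2, ]

sf <- survfit(cph_small, newdata = obs)

# "survival": S(t | x) at each time in sf$time
surv_values <- sf$surv
# "chf": cumulative hazard H(t | x) at the same times
chf_values <- sf$cumhaz
# "risk": one scalar per observation, here exp(linear predictor)
risk_value <- predict(cph_small, newdata = obs, type = "risk")

# With survfit's default for coxph fits, S(t | x) = exp(-H(t | x))
max(abs(surv_values - exp(-chf_values)))
```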

variable_splits_type

character, decides how variable grids should be calculated. Use "quantiles" for a grid of percentiles or "uniform" (default) to get a uniform grid of points.
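A base-R sketch of the two grid types for karno, assuming a grid over the observed range with a hypothetical 5 points (the number of points and the quantile method survex uses are not documented here):

```r
library(survival)

karno <- veteran$karno
n_points <- 5  # hypothetical grid size

# "uniform": evenly spaced values between the observed minimum and maximum
uniform_grid <- seq(min(karno), max(karno), length.out = n_points)

# "quantiles": values at evenly spaced percentiles of the observed data
quantile_grid <- unname(quantile(karno, probs = seq(0, 1, length.out = n_points)))

uniform_grid
quantile_grid
```

Both grids share the same endpoints, but the quantile grid places more points where the data are dense, so fewer profile values are computed in sparsely observed regions.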

center

logical, should the profiles be centered around the average prediction?
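The exact centering rule is not spelled out above. One plausible reading, used here only as an assumption and not as a description of survex internals, is that at each time point the profile values are shifted by their mean across the grid, so profiles can be compared by shape rather than by level. The matrix is made-up illustrative data, not model output.

```r
# Made-up profile values (not model output):
# rows = time points, columns = grid values of one variable
profile <- matrix(c(0.9, 0.8, 0.7,
                    0.6, 0.5, 0.4),
                  nrow = 2, byrow = TRUE)

# Assumed centering: subtract each row's (time point's) mean over the grid.
# A length-nrow vector is recycled down the columns, so element [i, j]
# has rowMeans(profile)[i] subtracted from it.
centered <- profile - rowMeans(profile)
centered
```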

Value

An object of class c("predict_profile_survival", "surv_ceteris_paribus"). It is a list with the final result in the result element.

Examples


library(survival)
library(survex)

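# model = TRUE, x = TRUE, y = TRUE store the model frame, design matrix and
# response in the fit so that explain() can extract them
cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE)
# Random survival forest (requires the randomForestSRC package)
rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran)

cph_exp <- explain(cph)
rsf_src_exp <- explain(rsf_src)

# Columns 3 and 4 (time, status) are dropped from the observation;
# the numeric trt code is treated as categorical
cph_predict_profile <- predict_profile(cph_exp, veteran[2, -c(3, 4)],
    variables = c("trt", "celltype", "karno", "age"),
    categorical_variables = "trt"
)
plot(cph_predict_profile, facet_ncol = 2)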


# Profile of a single numerical variable for the random survival forest
rsf_predict_profile <- predict_profile(rsf_src_exp, veteran[5, -c(3, 4)], variables = "karno")
plot(rsf_predict_profile, numerical_plot_type = "contours")
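The examples above use the default grid and no centering. The call below is a sketch that exercises the documented variable_splits_type and center arguments on the same Cox model; the choice of observation and variable is arbitrary.

```r
library(survival)
library(survex)

cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE)
cph_exp <- explain(cph)

# Percentile-based grid for karno, with centered survival profiles
cph_profile_q <- predict_profile(
    cph_exp,
    veteran[2, -c(3, 4)],
    variables = "karno",
    variable_splits_type = "quantiles",
    center = TRUE
)
plot(cph_profile_q)
```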



[Package survex version 1.2.0 Index]