predict_profile {survex}    R Documentation
Instance Level Profile as Ceteris Paribus for Survival Models
Description
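This function calculates Ceteris Paribus profiles for a single observation. Each selected variable is varied over a grid of values while all other variables are held fixed, and the model prediction is recorded for every grid value. With output_type = "survival", the time dimension is taken into account: the prediction at each grid value is a survival function over time rather than a single number.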
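The idea can be illustrated by computing a ceteris paribus profile by hand for a Cox model, using only the survival package. The observation is copied once per grid value of one variable, only that variable is changed, and a survival curve is predicted for each copy. The helper cp_profile_survival(), the five-point grid and the evaluation times are illustrative assumptions; this is a minimal sketch of the concept, not how survex implements predict_profile() internally.

```r
library(survival)

# Cox model on the veteran lung cancer data shipped with 'survival'
cph <- coxph(Surv(time, status) ~ ., data = veteran)

# Observation to explain (row 2), without the outcome columns time and status
obs <- veteran[2, -c(3, 4)]

# Hypothetical helper: ceteris paribus profile of one numeric variable
cp_profile_survival <- function(model, observation, variable, grid, times) {
  # Replicate the observation once per grid value and change only `variable`
  newdata <- observation[rep(1, length(grid)), , drop = FALSE]
  newdata[[variable]] <- grid
  # One predicted survival curve per row of newdata
  fit <- survfit(model, newdata = newdata)
  # Survival probabilities at the requested times (rows = times, cols = grid)
  s <- summary(fit, times = times, extend = TRUE)$surv
  s <- matrix(s, nrow = length(times))
  dimnames(s) <- list(time = times, value = grid)
  s
}

karno_grid <- seq(min(veteran$karno), max(veteran$karno), length.out = 5)
times <- c(30, 90, 180)
prof <- cp_profile_survival(cph, obs, "karno", karno_grid, times)
print(round(prof, 3))
```

Each column of prof is the survival function of the modified observation for one karno value, so reading across a row shows how the predicted survival at that time changes with karno alone.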
Usage
predict_profile(
explainer,
new_observation,
variables = NULL,
categorical_variables = NULL,
...,
type = "ceteris_paribus",
output_type = "survival",
variable_splits_type = "uniform",
center = FALSE
)
## S3 method for class 'surv_explainer'
predict_profile(
explainer,
new_observation,
variables = NULL,
categorical_variables = NULL,
...,
type = "ceteris_paribus",
output_type = "survival",
variable_splits_type = "uniform",
center = FALSE
)
Arguments
explainer: an explainer object, i.e. a model preprocessed by the explain() function
new_observation: a new observation for which the prediction needs to be explained
variables: a character vector containing names of variables to be explained
categorical_variables: a character vector of names of additional variables which should be treated as categorical (factors are automatically treated as categorical variables). If it contains variable names not present in the
...: additional parameters passed to
type: character, only "ceteris_paribus" is available
output_type: either "survival" or "risk", the type of model output used for the explanation
variable_splits_type: character, decides how variable grids should be calculated. Use "uniform" (default) for an evenly spaced grid of points or "quantiles" for a grid of quantiles
center: logical, should profiles be centered around the average prediction?
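The two grid options and centering can be sketched with base R and the survival package. A "uniform" grid spaces points evenly between the observed minimum and maximum, while a "quantiles" grid places them at empirical quantiles, so points are denser where the data are denser. For centering, the sketch assumes one plausible reading of "centered around the average prediction": at each time point, the average predicted survival over the data is subtracted from the profile. The grid size n_points, the evaluation times and this centering rule are illustrative assumptions, not survex's exact internals.

```r
library(survival)

x <- veteran$karno
n_points <- 6  # grid size chosen for illustration only

# "uniform": evenly spaced points between the observed minimum and maximum
uniform_grid <- seq(min(x), max(x), length.out = n_points)

# "quantiles": empirical quantiles; duplicates are dropped
quantile_grid <- unique(unname(quantile(x, probs = seq(0, 1, length.out = n_points))))

# Centering (assumed reading): subtract the average predicted survival over
# the data, computed separately at each time point
cph <- coxph(Surv(time, status) ~ ., data = veteran)
times <- c(30, 90, 180)
covariates <- veteran[, -c(3, 4)]  # drop the outcome columns time and status

avg_surv <- rowMeans(summary(survfit(cph, newdata = covariates), times = times)$surv)

# Uncentered profile of karno for observation 2 on the uniform grid
nd <- covariates[rep(2, n_points), , drop = FALSE]
nd$karno <- uniform_grid
prof <- summary(survfit(cph, newdata = nd), times = times)$surv  # times x grid

# Row i (time point i) is shifted by the average survival at that time
centered <- prof - avg_surv
print(round(centered, 3))
```

Because prof has one row per time point and avg_surv has one entry per time point, R's column-major recycling subtracts the matching time's average from every grid value.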
Value
An object of class c("predict_profile_survival", "surv_ceteris_paribus"). It is a list with the final result stored in its result element.
Examples
library(survival)
library(survex)
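# Cox model; model = TRUE, x = TRUE, y = TRUE let explain() recover the data and outcome
cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE)
# Random survival forest (requires the randomForestSRC package)
rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran)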
cph_exp <- explain(cph)
rsf_src_exp <- explain(rsf_src)
# Profiles for observation 2; columns 3 and 4 (time, status) are the outcome and are dropped
cph_predict_profile <- predict_profile(cph_exp, veteran[2, -c(3, 4)],
variables = c("trt", "celltype", "karno", "age"),
categorical_variables = "trt"
)
plot(cph_predict_profile, facet_ncol = 2)
# Profile of karno for observation 5 under the random survival forest
rsf_predict_profile <- predict_profile(rsf_src_exp, veteran[5, -c(3, 4)], variables = "karno")
plot(rsf_predict_profile, numerical_plot_type = "contours")
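The examples above use the default uniform grid and uncentered profiles. As a further sketch, the call below requests a quantile-based grid and centered profiles for karno, then inspects the result element described under Value. It assumes survex accepts "quantiles" as the alternative to the default "uniform" for variable_splits_type; the object name cph_profile_q is illustrative.

```r
library(survival)
library(survex)

cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE)
cph_exp <- explain(cph)

# Quantile-based grid for karno; profiles centered on the average prediction
cph_profile_q <- predict_profile(
  cph_exp, veteran[2, -c(3, 4)],
  variables = "karno",
  variable_splits_type = "quantiles",
  center = TRUE
)

# The computed profiles are stored in the result element
print(head(cph_profile_q$result))
```

A quantile grid can be preferable for skewed variables, since a uniform grid may spend many points in sparsely populated regions of the data.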