ranger_surv.unify {treeshap}    R Documentation
Unify ranger survival model
Description
Convert your ranger survival model into a standardized representation.
The returned representation is easy to interpret and can be passed directly as an argument to the treeshap() function.
Usage
ranger_surv.unify(
rf_model,
data,
type = c("risk", "survival", "chf"),
times = NULL
)
Arguments
rf_model
An object of
data
Reference dataset. A
type
A character to define the type of model prediction to use. Either "risk", "survival", or "chf".
times
A numeric vector of unique death times at which the prediction should be evaluated. By default
Details
The survival forest implemented in the ranger package stores cumulative hazard functions (CHFs) in the leaves of survival trees, as proposed for Random Survival Forests (Ishwaran et al. 2008). The final model prediction is made by averaging these CHFs over all the trees. To provide explanations in the form of a survival function, the CHFs from the leaves are converted into survival functions (SFs) using the formula SF(t) = exp(-CHF(t)).
However, averaging these SFs does not reproduce the model prediction, because the model's survival prediction is obtained by first averaging the CHFs and only then applying the same transformation. Therefore, explanations based on the survival function are only proxies and may not be fully consistent with the model predictions obtained using, for example, the predict function.
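The gap between the two ways of averaging can be seen in a small base R sketch. The three CHF values are made-up numbers for a single time point, chosen only for illustration; they are not output from ranger, and the variable names are hypothetical.

```r
# Hypothetical CHF values from the leaves of three trees at one time point
# (illustrative numbers, not produced by ranger).
chf_leaves <- c(0.1, 0.5, 2.0)

# Model prediction: average the CHFs first, then transform to a survival probability.
sf_model <- exp(-mean(chf_leaves))

# Proxy behind the explanations: transform each leaf CHF to an SF, then average the SFs.
sf_proxy <- mean(exp(-chf_leaves))

# exp(-x) is convex, so by Jensen's inequality the proxy is never below the model value.
print(c(model = sf_model, proxy = sf_proxy, gap = sf_proxy - sf_model))
```

The two values coincide only when every tree stores the same CHF value; the more the leaf CHFs differ across trees, the larger the gap.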
Value
For type = "risk", a unified model representation is returned: a model_unified.object object.
For type = "survival" or type = "chf", a model_unified_multioutput.object object is returned, which is a list that contains a unified model representation (a model_unified.object object) for each time point. In this case, the list names are the time points at which the prediction was evaluated.
See Also
ranger.unify for regression and classification ranger models
lightgbm.unify for LightGBM models
gbm.unify for GBM models
xgboost.unify for XGBoost models
randomForest.unify for randomForest models
Examples
library(ranger)
library(treeshap)

# Keep only the death records (etype == 2) of the colon data and drop incomplete rows
data_colon <- data.table::data.table(survival::colon)
data_colon <- na.omit(data_colon[get("etype") == 2, ])
surv_cols <- c("status", "time", "rx")

# Candidate columns: the third through the second-to-last column
feature_cols <- colnames(data_colon)[3:(ncol(data_colon) - 1)]

# Design matrix without an intercept; status and time are excluded from the features
train_x <- model.matrix(
  ~ -1 + .,
  data_colon[, .SD, .SDcols = setdiff(feature_cols, surv_cols[1:2])]
)

# Right-censored survival outcome
train_y <- survival::Surv(
  event = (data_colon[, get("status")] |>
    as.character() |>
    as.integer()),
  time = data_colon[, get("time")],
  type = "right"
)

# Small forest to keep the example fast
rf <- ranger::ranger(
  x = train_x,
  y = train_y,
  data = data_colon,
  max.depth = 10,
  num.trees = 10
)

# single-output explanation of the risk score
unified_model_risk <- ranger_surv.unify(rf, train_x, type = "risk")
shaps <- treeshap(unified_model_risk, train_x[1:2, ])

# compute shaps for 3 selected time points
unified_model_surv <- ranger_surv.unify(rf, train_x, type = "survival", times = c(23, 50, 73))
shaps_surv <- treeshap(unified_model_surv, train_x[1:2, ])