eSHAP_plot_reg {explainer} | R Documentation
Enhanced SHAP Analysis for Regression Models
Description
The SHAP plot for regression models is a visualization tool that uses the Shapley value, an approach from cooperative game theory, to compute feature contributions for single predictions. The Shapley value fairly distributes the difference between the instance's prediction and the dataset's average prediction among the features. The Shapley computation is provided by the iml package.
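To make the "fair distribution" statement concrete, the following is a minimal base R sketch that computes exact Shapley values for a toy model by enumerating all feature coalitions. It is not the code used by explainer or iml; `predict_toy`, `background`, `instance`, `coalition_value`, and `shapley_value` are hypothetical names introduced only for this illustration.

```r
# Hypothetical stand-in for a trained model's prediction function.
predict_toy <- function(X) 2 * X[, "x1"] + X[, "x2"] * X[, "x3"]

set.seed(1)
background <- cbind(x1 = rnorm(50), x2 = rnorm(50), x3 = rnorm(50))
instance <- c(x1 = 1, x2 = 2, x3 = -1)
features <- colnames(background)

# Value of a coalition: mean prediction when the coalition's features are
# fixed to the instance's values and the others keep their background values.
coalition_value <- function(coalition) {
  X <- background
  for (f in coalition) X[, f] <- instance[[f]]
  mean(predict_toy(X))
}

# Exact Shapley value of one feature: weighted average of its marginal
# contribution over every coalition of the remaining features.
shapley_value <- function(feature) {
  others <- setdiff(features, feature)
  p <- length(features)
  phi <- 0
  for (mask in 0:(2^length(others) - 1)) {
    # Decode the bit pattern of `mask` into a subset of `others`.
    in_coalition <- (mask %/% 2^(seq_along(others) - 1)) %% 2 == 1
    coalition <- others[in_coalition]
    k <- length(coalition)
    weight <- factorial(k) * factorial(p - k - 1) / factorial(p)
    phi <- phi + weight *
      (coalition_value(c(coalition, feature)) - coalition_value(coalition))
  }
  phi
}

phi <- vapply(features, shapley_value, numeric(1))
prediction <- unname(predict_toy(rbind(instance)))
average_prediction <- mean(predict_toy(background))

# The contributions add up to the gap between this prediction and the average.
all.equal(sum(phi), prediction - average_prediction)
```

Exact enumeration scales as 2^p in the number of features, which is why practical tools usually estimate Shapley values by sampling instead.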
Usage
eSHAP_plot_reg(
  task,
  trained_model,
  splits,
  sample.size = 30,
  seed = 246,
  subset = 1
)
Arguments
task | mlr3 regression task object specifying the task details
trained_model | mlr3 trained learner (model) object obtained after training
splits | mlr3 object defining data splits for train and test sets
sample.size | numeric, number of samples to calculate SHAP values (default: 30)
seed | numeric, seed for reproducibility (default: 246)
subset | numeric, proportion of the test set to use for visualization (default: 1)
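The sketch below illustrates what a sample-size parameter typically controls in a sampling-based Shapley estimator: each sample draws a random background row and a random feature ordering, and the estimate is the mean marginal contribution over those samples. This is an assumption-laden illustration in base R, not the implementation behind `eSHAP_plot_reg`; `predict_toy`, `background`, `instance`, and `estimate_shapley` are hypothetical names.

```r
# Hypothetical stand-in for a trained model's prediction function.
predict_toy <- function(X) 2 * X[, "x1"] + X[, "x2"] * X[, "x3"]

set.seed(246)
background <- cbind(x1 = rnorm(200), x2 = rnorm(200), x3 = rnorm(200))
instance <- c(x1 = 1, x2 = 2, x3 = -1)

# Sampling-based estimate of one feature's Shapley value.
estimate_shapley <- function(feature, sample_size) {
  features <- colnames(background)
  contributions <- numeric(sample_size)
  for (i in seq_len(sample_size)) {
    z <- background[sample(nrow(background), 1), ]  # random background row
    ordering <- sample(features)                    # random feature order
    # Features that come before `feature` in this ordering.
    preceding <- ordering[seq_len(match(feature, ordering) - 1)]
    with_feature <- z
    with_feature[c(preceding, feature)] <- instance[c(preceding, feature)]
    without_feature <- z
    without_feature[preceding] <- instance[preceding]
    contributions[i] <- predict_toy(rbind(with_feature)) -
      predict_toy(rbind(without_feature))
  }
  mean(contributions)
}

# Repeat the estimate to compare its spread at two sample sizes.
estimates_small <- replicate(20, estimate_shapley("x1", sample_size = 2))
estimates_large <- replicate(20, estimate_shapley("x1", sample_size = 30))
c(sd_small = sd(estimates_small), sd_large = sd(estimates_large))
```

Larger sample sizes generally give more stable estimates at the cost of run time, which is the trade-off behind choosing a small value for a quick check and 30 or more for real use.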
Value
A list of two objects:
An enhanced SHAP plot with interactive elements
A matrix of SHAP values
Examples
library("explainer")
seed <- 246
set.seed(seed)
# Check that the packages needed for this example are installed
if (!requireNamespace("mlbench", quietly = TRUE)) stop("mlbench not installed.")
if (!requireNamespace("mlr3learners", quietly = TRUE)) stop("mlr3learners not installed.")
if (!requireNamespace("ranger", quietly = TRUE)) stop("ranger not installed.")
# Load the BreastCancer dataset and drop the Id column
utils::data("BreastCancer", package = "mlbench")
mydata <- BreastCancer[, -1]
mydata <- na.omit(mydata)
# Add simulated sex and age covariates for illustration
sex <- sample(c("Male", "Female"), size = nrow(mydata), replace = TRUE)
mydata$age <- sample(seq(18, 60), size = nrow(mydata), replace = TRUE)
mydata$sex <- factor(sex, levels = c("Male", "Female"), labels = c(1, 0))
# Drop the classification label and use Cl.thickness as a numeric target
mydata$Class <- NULL
mydata$Cl.thickness <- as.numeric(mydata$Cl.thickness)
target_col <- "Cl.thickness"
maintask <- mlr3::TaskRegr$new(
  id = "my_regression_task",
  backend = mydata,
  target = target_col
)
splits <- mlr3::partition(maintask)
mylrn <- mlr3::lrn("regr.ranger", predict_type = "response")
mylrn$train(maintask, splits$train)
reg_model_outputs <- mylrn$predict(maintask, splits$test)
SHAP_output <- eSHAP_plot_reg(
  task = maintask,
  trained_model = mylrn,
  splits = splits,
  sample.size = 2, # kept small for speed; use 30 or more in practice
  seed = seed,
  subset = 0.02 # proportion of the test set; can be up to 1
)
myplot <- SHAP_output[[1]]