SHAPclust {explainer}    R Documentation

SHAP clustering

Description

Clusters data samples by their SHAP values using the k-means method, to identify subgroups of individuals with distinct patterns of feature contributions.
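The clustering idea can be illustrated with base R's stats::kmeans(). The sketch below is a minimal illustration under stated assumptions, not SHAPclust's internal code: shap_wide is a hypothetical, simulated stand-in for a wide SHAP table (one row per sample, one numeric column per feature), and the centers, iter.max and algorithm values only mirror the corresponding SHAPclust arguments by analogy.

```r
# Minimal sketch: group samples by their SHAP profiles with stats::kmeans().
# Assumption: 'shap_wide' is a hypothetical wide SHAP table with simulated
# values (one row per sample, one column per feature).
set.seed(246)
shap_wide <- data.frame(
  f1 = c(rnorm(20, mean = 0.3, sd = 0.05), rnorm(20, mean = -0.3, sd = 0.05)),
  f2 = c(rnorm(20, mean = -0.1, sd = 0.05), rnorm(20, mean = 0.2, sd = 0.05))
)

km <- stats::kmeans(
  as.matrix(shap_wide),
  centers = 2,                 # analogous to num_of_clusters
  iter.max = 1000,             # analogous to iter.max
  algorithm = "Hartigan-Wong"  # analogous to algorithm
)

# Attach each sample's cluster label to its SHAP profile
shap_wide$cluster <- factor(km$cluster)

table(shap_wide$cluster)  # number of samples per cluster
km$centers                # mean SHAP contribution of each feature per cluster
```

Each row of km$centers summarises the typical pattern of feature contributions in one subgroup.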

Usage

SHAPclust(
  task,
  trained_model,
  splits,
  shap_Mean_wide,
  shap_Mean_long,
  num_of_clusters = 4,
  seed = 246,
  subset = 1,
  algorithm = "Hartigan-Wong",
  iter.max = 1000
)

Arguments

task

an mlr3 task for binary classification

trained_model

a trained mlr3 learner object

splits

an object defining the train and test splits, such as the list returned by mlr3::partition()

shap_Mean_wide

the data frame of SHAP values in wide format, as returned by eSHAP_plot() (second element of its output)

shap_Mean_long

the data frame of SHAP values in long format, as returned by eSHAP_plot() (third element of its output)

num_of_clusters

the number of clusters to form from the SHAP values. Default: 4.

seed

an integer random seed for reproducibility. Default: 246.

subset

the fraction of instances to use, from 0 to 1, where 1 means all instances. Default: 1. Use the same value as in eSHAP_plot().

algorithm

a character string naming the k-means algorithm: one of "Hartigan-Wong" (default), "Lloyd", "Forgy", or "MacQueen" (the options accepted by stats::kmeans()).

iter.max

the maximum number of k-means iterations allowed. Default: 1000.
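The subset and seed arguments together determine which instances are used. As a rough illustration only (the package's actual sampling procedure may differ), a reproducible fraction of rows could be drawn as follows; take_subset() and the toy data frame are hypothetical and not part of explainer.

```r
# Hypothetical helper (not part of explainer): keep a reproducible
# fraction of the rows of a data frame.
take_subset <- function(df, subset = 1, seed = 246) {
  stopifnot(subset > 0, subset <= 1)
  set.seed(seed)                                # same seed -> same rows (resets global RNG state)
  n_keep <- max(1L, round(subset * nrow(df)))   # keep at least one row
  df[sort(sample.int(nrow(df), n_keep)), , drop = FALSE]
}

# Toy data: 100 rows of simulated values
toy <- data.frame(id = 1:100, shap_f1 = rnorm(100))
small <- take_subset(toy, subset = 0.02)
nrow(small)  # 2% of 100 rows
```

Reusing the same subset and seed gives the same rows, which is why keeping these values consistent between eSHAP_plot() and SHAPclust() matters.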

Value

A list containing five elements:

shap_plot_onerow

An interactive plot of the SHAP values for each feature, with one facet per cluster.

combined_plot

A ggplot2 figure combining the confusion matrices of all clusters, showing the model's performance within each identified subgroup.

kmeans_fvals_desc

A summary table of descriptive statistics of the feature values within each cluster.

shap_Mean_wide_kmeans

A data frame of the SHAP values with their cluster assignments, predictions, and ground-truth labels.

kmeans_info

Information about the k-means clustering process, including cluster centers and assignment details.
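To make the per-cluster outputs more concrete, the sketch below shows one way to tabulate a confusion table and summarise a feature within each cluster using base R only. The data frame df and its columns (cluster, truth, response, age) are hypothetical, with made-up values; SHAPclust builds its own tables and plots internally and may organise them differently.

```r
# Sketch: per-cluster confusion tables and feature summaries in base R.
# Assumption: 'df' is a hypothetical data frame with one row per sample
# (cluster label, true class, predicted class, one feature); values are
# invented for illustration only.
cls <- c("malignant", "benign")
df <- data.frame(
  cluster  = factor(c(1, 1, 1, 2, 2, 2)),
  truth    = factor(c("malignant", "benign", "benign",
                      "malignant", "malignant", "benign"), levels = cls),
  response = factor(c("malignant", "benign", "malignant",
                      "malignant", "benign", "benign"), levels = cls),
  age      = c(34, 51, 47, 29, 60, 42)
)

# One confusion table (truth vs. prediction) per cluster
conf_by_cluster <- lapply(split(df, df$cluster), function(d) {
  table(truth = d$truth, response = d$response)
})

# Mean and standard deviation of a feature within each cluster
age_by_cluster <- aggregate(age ~ cluster, data = df,
                            FUN = function(x) c(mean = mean(x), sd = sd(x)))

conf_by_cluster[["1"]]
age_by_cluster
```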

References

Zargari Marandi, R. (2024). ExplaineR: an R package to explain machine learning models. Bioinformatics Advances, 4(1), vbae049. https://doi.org/10.1093/bioadv/vbae049

See Also

Other functions to visualize and interpret machine learning models: eSHAP_plot.

Examples


library("explainer")
seed <- 246
set.seed(seed)
# Check that the required packages are installed
if (!requireNamespace("mlbench", quietly = TRUE)) stop("mlbench not installed.")
if (!requireNamespace("mlr3learners", quietly = TRUE)) stop("mlr3learners not installed.")
if (!requireNamespace("ranger", quietly = TRUE)) stop("ranger not installed.")
# Load BreastCancer dataset
utils::data("BreastCancer", package = "mlbench")
target_col <- "Class"
positive_class <- "malignant"
# Drop the Id column and remove rows with missing values
mydata <- BreastCancer[, -1]
mydata <- na.omit(mydata)
# Add simulated sex and age columns (random values, for illustration only)
sex <- sample(
  c("Male", "Female"),
  size = nrow(mydata),
  replace = TRUE
)
mydata$age <- as.numeric(sample(
  seq(18,60),
  size = nrow(mydata),
  replace = TRUE
))
mydata$sex <- factor(
  sex,
  levels = c("Male", "Female"),
  labels = c(1, 0)
)
maintask <- mlr3::TaskClassif$new(
  id = "my_classification_task",
  backend = mydata,
  target = target_col,
  positive = positive_class
)
# Split the row ids into train and test sets
splits <- mlr3::partition(maintask)
mylrn <- mlr3::lrn(
  "classif.ranger",
  predict_type = "prob"
)
# Train a random forest on the training rows
mylrn$train(maintask, splits$train)
# Compute SHAP values on a small subset of instances
SHAP_output <- eSHAP_plot(
  task = maintask,
  trained_model = mylrn,
  splits = splits,
  sample.size = 2, # kept small for speed; 30 or more can also be used
  seed = seed,
  subset = 0.02 # up to 1
)
# Extract the wide- and long-format SHAP tables
shap_Mean_wide <- SHAP_output[[2]]
shap_Mean_long <- SHAP_output[[3]]
# Cluster the samples by their SHAP values
SHAP_plot_clusters <- SHAPclust(
  task = maintask,
  trained_model = mylrn,
  splits = splits,
  shap_Mean_wide = shap_Mean_wide,
  shap_Mean_long = shap_Mean_long,
  num_of_clusters = 3, # your choice
  seed = seed,
  subset = 0.02, # match with eSHAP_plot
  algorithm = "Hartigan-Wong",
  iter.max = 10
)



[Package explainer version 1.0.1 Index]