glob_nn_vis {viralx} | R Documentation |
Global Visualization of SHAP Values for Neural Network Model
Description
The glob_nn_vis function generates a global visualization of SHAP (SHapley
Additive exPlanations) values for a neural network model. It uses the
DALEXtra package to explain the model's predictions and then creates a global
SHAP visualization from those explanations.
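To make "global SHAP" concrete, here is a minimal base-R sketch that computes exact Shapley values for a toy model and averages their absolute values per feature. It is not the code glob_nn_vis runs: predict_toy, shapley_row, and the four-row background data are hypothetical and used only for illustration, and the brute-force enumeration over feature subsets is practical only for a handful of features.

```r
# Hypothetical stand-in for a fitted network: a fixed linear predictor.
predict_toy <- function(df) {
  0.5 * df$cd_2019 - 0.001 * df$vl_2019 + 0.2 * df$cd_2021
}

# Small background data set (illustrative values only).
bg <- data.frame(
  cd_2019 = c(824, 169, 342, 423),
  vl_2019 = c(40, 11388, 38961, 40),
  cd_2021 = c(992, 275, 331, 454)
)

# Exact Shapley values for one observation x (a one-row data frame).
shapley_row <- function(x, bg, f) {
  feats <- names(bg)
  p <- length(feats)
  # Value of a coalition S: mean prediction over the background data
  # after fixing the features in S at the values observed in x.
  value <- function(S) {
    d <- bg
    for (nm in S) d[[nm]] <- x[[nm]]
    mean(f(d))
  }
  phi <- setNames(numeric(p), feats)
  for (j in feats) {
    others <- setdiff(feats, j)
    for (k in 0:(p - 1)) {
      subsets <- if (k == 0) list(character(0)) else combn(others, k, simplify = FALSE)
      # Shapley weight for coalitions of size k
      w <- factorial(k) * factorial(p - k - 1) / factorial(p)
      for (S in subsets) {
        phi[j] <- phi[j] + w * (value(c(S, j)) - value(S))
      }
    }
  }
  phi
}

# Local explanations: one row of Shapley values per observation.
shap <- t(sapply(seq_len(nrow(bg)), function(i) shapley_row(bg[i, ], bg, predict_toy)))

# Global summary: mean absolute Shapley value per feature, largest first.
global_importance <- sort(colMeans(abs(shap)), decreasing = TRUE)
print(global_importance)
```

Each row of shap sums to that observation's prediction minus the average prediction over bg, which is the additivity property that makes per-feature averages comparable. Passing global_importance to barplot() would give a simple global plot of the same kind of summary.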
Usage
glob_nn_vis(vip_featured, hiv_data, hu, plty, epo, vip_train, v_train)
Arguments
vip_featured
    A character value naming the variable of interest (the response; "cd_2022" in the example below).
hiv_data
    A data frame containing the HIV research data used for model training.
hu
    A numeric value specifying the number of hidden units in the neural network model.
plty
    A numeric value specifying the penalty parameter for the neural network model.
epo
    A numeric value specifying the number of epochs (training iterations) for the neural network model.
vip_train
    A data frame containing the training data used to fit the neural network model.
v_train
    A numeric vector representing the response variable corresponding to the training data.
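The hu, plty, and epo arguments line up with the hidden_units, penalty, and epochs tuning parameters of a single-layer network in parsnip. The sketch below shows how such a specification could be written. It is an assumption that glob_nn_vis builds its model this way; the "nnet" engine and the regression mode are chosen here only for illustration.

```r
library(parsnip)

hu   <- 5             # hidden units
plty <- 1.131656e-09  # penalty (weight decay)
epo  <- 176           # epochs

# Model specification only; nothing is fitted at this point.
nn_spec <- mlp(hidden_units = hu, penalty = plty, epochs = epo) |>
  set_mode("regression") |>
  set_engine("nnet")

print(nn_spec)
```

A specification like this would then be fitted to the predictors and response before any explainer is built on top of it.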
Value
A global visualization of SHAP values for the specified neural network model.
Examples
library(dplyr)
library(rsample)
cd_2019 <- c(824, 169, 342, 423, 441, 507, 559,
173, 764, 780, 244, 527, 417, 800,
602, 494, 345, 780, 780, 527, 556,
559, 238, 288, 244, 353, 169, 556,
824, 169, 342, 423, 441, 507, 559)
vl_2019 <- c(40, 11388, 38961, 40, 75, 4095, 103,
11388, 46, 103, 11388, 40, 0, 11388,
0, 4095, 40, 93, 49, 49, 49,
4095, 6837, 38961, 38961, 0, 0, 93,
40, 11388, 38961, 40, 75, 4095, 103)
cd_2021 <- c(992, 275, 331, 454, 479, 553, 496,
230, 605, 432, 170, 670, 238, 238,
634, 422, 429, 513, 327, 465, 479,
661, 382, 364, 109, 398, 209, 1960,
992, 275, 331, 454, 479, 553, 496)
vl_2021 <- c(80, 1690, 5113, 71, 289, 3063, 0,
262, 0, 15089, 13016, 1513, 60, 60,
49248, 159308, 56, 0, 516675, 49, 237,
84, 292, 414, 26176, 62, 126, 93,
80, 1690, 5113, 71, 289, 3063, 0)
cd_2022 <- c(700, 127, 127, 547, 547, 547, 777,
149, 628, 614, 253, 918, 326, 326,
574, 361, 253, 726, 659, 596, 427,
447, 326, 253, 248, 326, 260, 918,
700, 127, 127, 547, 547, 547, 777)
vl_2022 <- c(0, 0, 53250, 0, 40, 1901, 0,
955, 0, 0, 0, 0, 40, 0,
49248, 159308, 56, 0, 516675, 49, 237,
0, 23601, 0, 40, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0)
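# Combine the yearly CD4 count and viral load vectors into one data frame
x <- cbind(cd_2019, vl_2019, cd_2021, vl_2021, cd_2022, vl_2022) |>
  as.data.frame()
# Split the data and keep the training portion
set.seed(123)
hi_data <- rsample::initial_split(x)
set.seed(123)
hiv_data <- hi_data |>
  rsample::training()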
hu <- 5
plty <- 1.131656e-09
epo <- 176
vip_featured <- c("cd_2022")
vip_features <- c("cd_2019", "vl_2019", "cd_2021", "vl_2021", "vl_2022")
# Rebuild the same split (same seed), then separate predictors from the response
set.seed(123)
vi_train <- rsample::initial_split(x)
set.seed(123)
vip_train <- vi_train |>
  rsample::training() |>
  dplyr::select(dplyr::all_of(vip_features))
v_train <- vi_train |>
  rsample::training() |>
  dplyr::select(dplyr::all_of(vip_featured))
glob_nn_vis(vip_featured, hiv_data, hu, plty, epo, vip_train, v_train)