HessFeaturePlot {NeuralSens}    R Documentation

Feature sensitivity plot

Description

Shows the distribution of the second-order sensitivities (Hessian) of the output in a geom_sina() plot, in which the color of each point depends on the corresponding input value.
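The idea behind this kind of plot can be sketched with base R plus ggplot2 and ggforce (the package that provides geom_sina()). This is a minimal illustration under stated assumptions, not the actual code of HessFeaturePlot: the matrices sens and x are hypothetical toy data standing in for the sensitivities and the input values, and the per-input min-max scaling of the color variable is an assumption borrowed from common SHAP feature plots.

```r
# Minimal sketch of a SHAP-style feature plot (assumption: general technique
# only, NOT the internal implementation of NeuralSens::HessFeaturePlot).
# 'x' and 'sens' are hypothetical toy matrices: rows = observations,
# columns = inputs.
set.seed(1)
x    <- matrix(rnorm(60), ncol = 3, dimnames = list(NULL, c("X1", "X2", "X3")))
sens <- matrix(rnorm(60), ncol = 3, dimnames = list(NULL, colnames(x)))

# Rescale each input to [0, 1] so a single color scale serves all inputs
# (min-max scaling is an assumption; other normalizations are possible)
minmax   <- function(v) (v - min(v)) / (max(v) - min(v))
x_scaled <- apply(x, 2, minmax)

# Reshape to long format: one row per (observation, input) pair.
# as.vector() reads matrices column by column, matching rep(..., each = n).
long <- data.frame(
  input       = factor(rep(colnames(sens), each = nrow(sens)),
                       levels = colnames(sens)),
  sensitivity = as.vector(sens),
  value       = as.vector(x_scaled)
)

# Draw the plot only if ggplot2 and ggforce (provides geom_sina) are installed
if (requireNamespace("ggplot2", quietly = TRUE) &&
    requireNamespace("ggforce", quietly = TRUE)) {
  p <- ggplot2::ggplot(long, ggplot2::aes(x = input, y = sensitivity,
                                          color = value)) +
    ggforce::geom_sina() +                       # jittered points shaped by density
    ggplot2::scale_color_gradient(low = "blue", high = "red") +
    ggplot2::coord_flip()                        # one horizontal band per input
  print(p)
}
```

Each horizontal band shows how the sensitivities spread for one input, and the color reveals whether large or small input values drive large sensitivities.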

Usage

HessFeaturePlot(object, fdata = NULL, ...)

Arguments

object

Fitted neural network model, or the object containing the raw sensitivities returned by HessianMLP (as in the example below).

fdata

data.frame containing the data on which to evaluate the sensitivities of the model. Not needed if the raw sensitivities have been passed as object.

...

Further arguments passed to the function that computes the sensitivities when object is a fitted model.

Value

A list of feature sensitivity plots, in the style described in https://www.r-bloggers.com/2019/03/a-gentle-introduction-to-shap-values-in-r/

Examples

## Load data -------------------------------------------------------------------
data("DAILY_DEMAND_TR", package = "NeuralSens")
fdata <- DAILY_DEMAND_TR

## Parameters of the NNET ------------------------------------------------------
hidden_neurons <- 5
iters <- 250
decay <- 0.1

################################################################################
#########################  REGRESSION NNET #####################################
################################################################################
## Regression dataframe --------------------------------------------------------
# Drop the first column and rescale the first and third remaining columns
fdata.Reg.tr <- fdata[,2:ncol(fdata)]
fdata.Reg.tr[,3] <- fdata.Reg.tr[,3]/10
fdata.Reg.tr[,1] <- fdata.Reg.tr[,1]/1000

# Center and scale all variables before training the network
preProc <- caret::preProcess(fdata.Reg.tr, method = c("center","scale"))
nntrData <- predict(preProc, fdata.Reg.tr)

## Train nnet NNET -------------------------------------------------------------
# Build the formula: first column ~ sum of all remaining columns
form <- paste(names(fdata.Reg.tr)[2:ncol(fdata.Reg.tr)], collapse = " + ")
form <- formula(paste(names(fdata.Reg.tr)[1], form, sep = " ~ "))

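set.seed(150)
# linout = TRUE gives a linear output unit, as needed for regression
# (linear.output is not an nnet argument and would be silently ignored)
nnetmod <- nnet::nnet(form,
                      data = nntrData,
                      linout = TRUE,
                      size = hidden_neurons,
                      decay = decay,
                      maxit = iters)
# Compute the second-order sensitivities (Hessian) and draw the feature plot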
hess <- NeuralSens::HessianMLP(nnetmod, trData = nntrData, plot = FALSE)
NeuralSens::HessFeaturePlot(hess)

[Package NeuralSens version 1.1.3 Index]