HessDotPlot {NeuralSens} | R Documentation
Second derivatives 3D scatter or surface plot against input values
Description
3D plot of the second derivatives of the neural network output with respect
to the inputs. This function uses plotly instead of ggplot2 to achieve
better visualization.
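As a reading aid, the plotted quantity can be written out explicitly. The notation below is introduced here for illustration and is not taken from the package documentation: the fitted network is treated as a function with inputs x_1, ..., x_p and outputs y_k, and the second derivatives are presumably evaluated at each data sample x.

```latex
h^{(k)}_{ij}(x) \;=\; \frac{\partial^{2} y_k}{\partial x_i \, \partial x_j}(x),
\qquad i, j \in \{1, \dots, p\}
```

Under this reading, each plot shows one such second derivative on the vertical axis against the values of the two inputs x_i and x_j.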
Usage
HessDotPlot(
  object,
  fdata = NULL,
  input_vars = "all",
  input_vars2 = "all",
  output_vars = "all",
  surface = FALSE,
  grid = FALSE,
  color = NULL,
  ...
)
Arguments
object |
fitted neural network model or |
fdata |
|
input_vars |
|
input_vars2 |
|
output_vars |
|
surface |
|
grid |
|
color |
|
... |
further arguments that should be passed to |
Value
List of 3D geom_point plots for the input variables, representing the
sensitivity of each output with respect to the inputs.
Examples
## Load data -------------------------------------------------------------------
data("DAILY_DEMAND_TR")
fdata <- DAILY_DEMAND_TR
## Parameters of the NNET ------------------------------------------------------
hidden_neurons <- 5
iters <- 250
decay <- 0.1
################################################################################
######################### REGRESSION NNET #####################################
################################################################################
## Regression dataframe --------------------------------------------------------
# Scale the data
fdata.Reg.tr <- fdata[,2:ncol(fdata)]
fdata.Reg.tr[,3] <- fdata.Reg.tr[,3]/10
fdata.Reg.tr[,1] <- fdata.Reg.tr[,1]/1000
# Normalize the data for some models
preProc <- caret::preProcess(fdata.Reg.tr, method = c("center","scale"))
nntrData <- predict(preProc, fdata.Reg.tr)
## TRAIN nnet NNET -------------------------------------------------------------
# Create a formula to train NNET
form <- paste(names(fdata.Reg.tr)[2:ncol(fdata.Reg.tr)], collapse = " + ")
form <- formula(paste(names(fdata.Reg.tr)[1], form, sep = " ~ "))
set.seed(150)
nnetmod <- nnet::nnet(form,
                      data = nntrData,
                      linout = TRUE,  # nnet's switch for linear output units (regression)
                      size = hidden_neurons,
                      decay = decay,
                      maxit = iters)
# Try HessDotPlot
NeuralSens::HessDotPlot(nnetmod, fdata = nntrData, surface = TRUE, color = "WD")