NNtrainPredict {NNbenchmark} | R Documentation |
Generic Functions for Training and Predicting
Description
An implementation with do.call
so that any neural network function that fits
the format can be tested.
In trainPredict_1mth1data, a neural network is trained on one dataset with one
method and then used for predictions, with optional plotting and output files.
The performance of the neural network is then summarized.
trainPredict_1data
serves as a wrapper function for trainPredict_1mth1data
for multiple methods.
trainPredict_1pkg
serves as a wrapper function for trainPredict_1mth1data
for multiple datasets.
Usage
trainPredict_1mth1data(dset, method, trainFUN, hyperparamFUN, predictFUN,
summaryFUN, prepareZZ.arg = list(), nrep = 5, doplot = FALSE,
plot.arg = list(col1 = 1:nrep, lwd1 = 1, col2 = 4, lwd2 = 3), pkgname,
pkgfun, csvfile = FALSE, rdafile = FALSE, odir = ".", echo = FALSE,
echoreport = FALSE, appendcsv = TRUE, ...)
trainPredict_1data(dset, methodlist, trainFUN, hyperparamFUN, predictFUN,
summaryFUN, closeFUN, startNN = NA, prepareZZ.arg = list(), nrep = 5,
doplot = FALSE, plot.arg = list(), pkgname = "pkg", pkgfun = "train",
csvfile = FALSE, rdafile = FALSE, odir = ".", echo = FALSE, ...)
trainPredict_1pkg(dsetnum, pkgname = "pkg", pkgfun = "train", methodvect,
prepareZZ.arg = list(), summaryFUN, nrep = 5, doplot = FALSE,
plot.arg = list(), csvfile = FALSE, rdafile = FALSE, odir = ".",
echo = FALSE, appendcsv = TRUE, ...)
Arguments
dset |
a number or string indicating which of the package's datasets to use |
method |
a method for a particular function |
trainFUN |
the training function used |
hyperparamFUN |
the function resulting in parameters needed for training |
predictFUN |
the prediction function used |
summaryFUN |
a function that measures performance from the observed and predicted y values, e.g. NNsummary |
prepareZZ.arg |
list of arguments passed to prepareZZ |
nrep |
a number for how many times a neural network should be trained with a package/function |
doplot |
logical value, TRUE executes plots and FALSE does not |
plot.arg |
list of arguments for plots |
pkgname |
package name |
pkgfun |
name of the package function to train neural network |
csvfile |
logical value, adds summary to csv files per dataset if TRUE |
rdafile |
logical value, outputs rdafile of predictions and summary if TRUE |
odir |
output directory |
echo |
logical value, separates training between packages with some text and enables echoreport if TRUE |
echoreport |
logical value, detailed reports are printed (such as model summaries and str(data)) if TRUE, will not work if echo is FALSE |
appendcsv |
logical value; if TRUE, the summary is appended to the csv file (relevant when csvfile is TRUE) |
... |
additional arguments |
methodlist |
list of methods per package/function |
closeFUN |
a function to detach packages or other necessary environment clearing |
startNN |
a function to start needed outside libraries, for example, h2o |
dsetnum |
a vector of numbers indicating which datasets to use |
methodvect |
vector of methods per package/function |
Value
An array with values as in NNsummary including each repetition, with options for plots and output files
Examples
# Shared settings for both examples: output goes to the per-session
# temporary directory, and each network is trained `nrep` times.
odir <- tempdir()
nrep <- 2

### Package with one method/optimization algorithm
library(brnn)
# brnn offers a single optimisation algorithm.
brnn.method <- "gaussNewton"
# Hyper-parameters for brnn training.
# optim_method: accepted for interface compatibility; brnn has one optimiser,
#   so the settings do not depend on it.
# Returns a list with the number of training epochs in `iter`.
hyperParams.brnn <- function(optim_method, ...) {
  list(iter = 200)
}
# Arguments forwarded to prepareZZ for brnn: x as a matrix, y as a vector,
# dataxy as a data frame, with scaling enabled.
brnn.prepareZZ <- list(
  xdmv  = "m",
  ydmv  = "v",
  zdm   = "d",
  scale = TRUE
)
# Train a Bayesian-regularised network with brnn::brnn().
# x, y:          inputs/targets prepared according to brnn.prepareZZ
# dataxy, formula: unused here; part of the common NNtrain.* interface
# neur:          number of hidden neurons
# optim_method:  optimisation method name, forwarded to `hyperParams`
# hyperParams:   function (or its name) returning a list with `iter`
# Returns the fitted brnn model.
NNtrain.brnn <- function(x, y, dataxy, formula, neur, optim_method,
                         hyperParams, ...) {
  # Use the hyper-parameter function and method supplied by the caller
  # (as NNtrain.validann does) instead of hard-coded global objects.
  hyper_params <- do.call(hyperParams, list(optim_method, ...))
  iter <- hyper_params$iter
  brnn::brnn(x, y, neurons = neur, normalize = FALSE, epochs = iter,
             verbose = FALSE)
}
# Predict from a fitted brnn model; `x` is passed as the second positional
# argument of predict().
NNpredict.brnn <- function(object, x, ...) {
  yhat <- predict(object, x)
  yhat
}
# Detach and unload brnn if it is currently on the search path.
NNclose.brnn <- function() {
  is_attached <- "package:brnn" %in% search()
  if (is_attached) {
    detach("package:brnn", unload = TRUE)
  }
}
# Train and evaluate brnn on datasets 1 and 2, with plots but no files.
res <- trainPredict_1pkg(
  1:2,
  pkgname = "brnn", pkgfun = "brnn", methodvect = brnn.method,
  prepareZZ.arg = brnn.prepareZZ,
  nrep = nrep, doplot = TRUE,
  csvfile = FALSE, rdafile = FALSE, odir = odir, echo = FALSE
)
### Package with more than one method/optimization algorithm
library(validann)
# optim() methods to benchmark with validann::ann().
validann.method <- c(
  "Nelder-Mead",
  "BFGS",
  "CG",
  "L-BFGS-B",
  "SANN"
)
# Hyper-parameters for validann::ann() for a given optim() method.
# optim_method: one of "Nelder-Mead", "BFGS", "CG", "L-BFGS-B", "SANN".
# Returns list(iter = <max iterations>, method = optim_method).
# Errors on an unsupported method; previously `maxiter` was left undefined
# in that case, and the return value referenced an undefined `params`
# object, which made every call fail.
hyperParams.validann <- function(optim_method, ...) {
  maxiter <- switch(optim_method,
    "Nelder-Mead" = 10000,
    "BFGS"        = 200,
    "CG"          = 1000,
    "L-BFGS-B"    = 200,
    "SANN"        = 1000,
    stop("unsupported optim_method: ", optim_method, call. = FALSE)
  )
  list(iter = maxiter, method = optim_method)
}
# Arguments forwarded to prepareZZ for validann: x and y as matrices,
# dataxy as a data frame, with scaling enabled.
validann.prepareZZ <- list(
  xdmv  = "m",
  ydmv  = "m",
  zdm   = "d",
  scale = TRUE
)
# Train a single-hidden-layer network with validann::ann().
# neur:         hidden layer size
# optim_method: optim() method name, forwarded to `hyperParams`
# hyperParams:  function (or its name) returning list(iter, method)
# Returns the fitted validann model.
NNtrain.validann <- function(x, y, dataxy, formula, neur, optim_method,
                             hyperParams, ...) {
  settings <- do.call(hyperParams, list(optim_method, ...))
  validann::ann(x, y, size = neur, method = settings$method,
                maxit = settings$iter)
}
# Predict from a fitted validann model; `x` is passed as the second
# positional argument of predict().
NNpredict.validann <- function(object, x, ...) {
  yhat <- predict(object, x)
  yhat
}
# Detach and unload validann if it is currently on the search path.
NNclose.validann <- function() {
  is_attached <- "package:validann" %in% search()
  if (is_attached) {
    detach("package:validann", unload = TRUE)
  }
}
# Train and evaluate validann (all methods) on datasets 1 and 2, writing
# csv and rda output to `odir`.
# Fix: the argument was misspelled `repareZZ.arg`, which matches no formal,
# so it fell into `...` and validann.prepareZZ was silently ignored.
res <- trainPredict_1pkg(
  1:2,
  pkgname = "validann", pkgfun = "ann", methodvect = validann.method,
  prepareZZ.arg = validann.prepareZZ,
  nrep = nrep, doplot = FALSE,
  csvfile = TRUE, rdafile = TRUE, odir = odir, echo = FALSE
)