train {deepNN}    R Documentation

train function

Description

A function to train a neural network defined using the network function.

Usage

train(
  dat,
  truth,
  net,
  loss = Qloss(),
  tol = 0.95,
  eps = 0.001,
  batchsize = NULL,
  dropout = dropoutProbs(),
  parinit = function(n) {
     return(runif(n, -0.01, 0.01))
 },
  monitor = TRUE,
  stopping = "default",
  update = "classification"
)

Arguments

dat

the input data, a list of vectors

truth

the truth, a list of vectors to compare with output from the feed-forward network

net

an object of class network, see ?network

loss

the loss function, see ?Qloss and ?multinomial

tol

the stopping tolerance for training. The default method monitors the quality of randomly chosen predictions from the data and terminates when the mean predictive probability of the last 20 randomly chosen points exceeds tol; the default is 0.95. When stopping is "maxit", tol instead gives the maximum number of iterations, as in the example below.

eps

step-size scaling constant for gradient descent or stochastic gradient descent

batchsize

size of minibatches to be used with stochastic gradient descent

dropout

optional list of dropout probabilities, see ?dropoutProbs
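
For instance, assuming the input and hidden arguments of dropoutProbs set the respective inclusion probabilities (an assumption here; see ?dropoutProbs), dropout could be switched on with:

    dropout = dropoutProbs(input = 0.8, hidden = 0.5)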

parinit

a function of a single argument, n, returning n random initial values for the weights; the default draws from a uniform distribution on (-0.01, 0.01)
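
An alternative initialiser can be supplied in the same form; a minimal sketch using Gaussian draws (an illustrative choice, not a package recommendation):

    parinit = function(n){
        return(rnorm(n, mean = 0, sd = 0.01))
    }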

monitor

logical, whether to produce learning/convergence diagnostic plots

stopping

method for stopping computation; the default, "default", calls the function stopping.default

update

method for updating the stopping criterion; the default, "classification", calls the function updateStopping.classification

Value

the optimal cost and parameters of the trained network. At present, diagnostic plots are also produced, illustrating the model parameters, the gradient, and the trace of the stopping criterion.

References

  1. Ian Goodfellow, Yoshua Bengio and Aaron Courville. Deep Learning. MIT Press (2016)

  2. Terrence J. Sejnowski. The Deep Learning Revolution. The MIT Press (2018)

  3. Neural Networks YouTube playlist by 3Blue1Brown: https://www.youtube.com/playlist?list=PLZHQObOWTQDNU6R1_67000Dx_ZCJB-3pi

  4. http://neuralnetworksanddeeplearning.com/

See Also

network, train, backprop_evaluate, MLP_net, backpropagation_MLP, logistic, ReLU, smoothReLU, ident, softmax, Qloss, multinomial, NNgrad_test, weights2list, bias2list, biasInit, memInit, gradInit, addGrad, nnetpar, nbiaspar, addList, no_regularisation, L1_regularisation, L2_regularisation

Examples


# Example 1 - mnist data

# See the example in the mnist repository of user bentaylor1 on GitHub

# Example 2

N <- 1000
d <- matrix(rnorm(5*N),ncol=5) # N examples with 5 standard normal covariates

fun <- function(x){
    lp <- 2*x[2]                  # linear predictor depends on the second covariate only
    pr <- exp(lp) / (1 + exp(lp)) # logistic probability
    ret <- c(0,0)
    ret[1+rbinom(1,1,pr)] <- 1    # one-hot encoded binary outcome
    return(ret)
}

d <- lapply(1:N,function(i){return(d[i,])}) # split the matrix into a list of row vectors

truth <- lapply(d,fun) # simulate the truth for each example

net <- network( dims = c(5,10,2), # 5 inputs, one hidden layer of 10 units, 2 outputs
                activ=list(ReLU(),softmax()))

netwts <- train( dat=d,
                 truth=truth,
                 net=net,
                 eps=0.01,
                 tol=100,            # run for 100 iterations
                 batchsize=10,       # note this is not enough
                 loss=multinomial(), # for convergence
                 stopping="maxit")

# predict on the training data and compare with the truth
pred <- NNpredict(  net=net,
                    param=netwts$opt,
                    newdata=d,
                    newtruth=truth,
                    record=TRUE,
                    plot=TRUE)
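
# The quadratic loss, Qloss(), is the default and could be substituted for
# multinomial() in the call above; a minimal sketch (netwts_q is an
# illustrative name):

netwts_q <- train( dat=d,
                   truth=truth,
                   net=net,
                   eps=0.01,
                   tol=100,            # run for 100 iterations
                   batchsize=10,
                   loss=Qloss(),       # the default quadratic loss
                   stopping="maxit")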

