train {deepNN}    R Documentation
train function
Description
A function to train a neural network defined using the network function.
Usage
train(
  dat,
  truth,
  net,
  loss = Qloss(),
  tol = 0.95,
  eps = 0.001,
  batchsize = NULL,
  dropout = dropoutProbs(),
  parinit = function(n) {
    return(runif(n, -0.01, 0.01))
  },
  monitor = TRUE,
  stopping = "default",
  update = "classification"
)
Arguments
dat        the input data, a list of vectors

truth      the truth, a list of vectors to compare with the output of the
           feed-forward network

net        an object of class network, see ?network

loss       the loss function, see ?Qloss and ?multinomial

tol        stopping criterion for training. The current method monitors the
           quality of randomly chosen predictions from the data and
           terminates when the mean predictive probability of the last 20
           randomly chosen points exceeds tol. Default is 0.95.

eps        step-size scaling constant for gradient descent or stochastic
           gradient descent

batchsize  size of the minibatches to be used with stochastic gradient
           descent

dropout    optional list of dropout probabilities, see ?dropoutProbs

parinit    a function of a single argument n returning n initial weights
           drawn from the initialisation distribution; the default is
           uniform on (-0.01, 0.01). A sketch of a custom initialiser
           follows this table.

monitor    logical, whether to produce learning/convergence diagnostic plots

stopping   method for stopping computation; the default, "default", calls
           the function stopping.default

update     method for updating the stopping criterion; the default,
           "classification", calls the function updateStopping.classification
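Since parinit may be any function of a single argument, the default
initialiser can be replaced, for example, by Gaussian draws. A minimal
sketch; the standard deviation of 0.05 is an arbitrary illustrative choice,
and d, truth and net are the objects constructed in the Examples below.

    # assumption: Gaussian initial weights in place of the uniform default
    netwts <- train(dat = d,
                    truth = truth,
                    net = net,
                    loss = multinomial(),
                    parinit = function(n) rnorm(n, 0, 0.05))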
Value
The optimal cost and parameters from the trained network, returned as a list. At present, diagnostic plots are also produced, illustrating the parameters of the model, the gradient, and the trace of the stopping criterion.
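In the Examples below the returned object is called netwts, and its opt
component supplies the param argument to NNpredict. Component names other
than opt are not documented here, so a direct inspection is the safest guide:

    # opt holds the trained parameters (used as param in NNpredict below);
    # inspect the full structure to see the remaining components
    str(netwts)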
References
Goodfellow, I., Bengio, Y., Courville, A. and Bach, F. (2016). Deep Learning. MIT Press.
Sejnowski, T. J. (2018). The Deep Learning Revolution. The MIT Press.
Neural Networks YouTube playlist by 3Blue1Brown: https://www.youtube.com/playlist?list=PLZHQObOWTQDNU6R1_67000Dx_ZCJB-3pi
Nielsen, M. Neural Networks and Deep Learning: http://neuralnetworksanddeeplearning.com/
See Also
network, train, backprop_evaluate, MLP_net, backpropagation_MLP, logistic, ReLU, smoothReLU, ident, softmax, Qloss, multinomial, NNgrad_test, weights2list, bias2list, biasInit, memInit, gradInit, addGrad, nnetpar, nbiaspar, addList, no_regularisation, L1_regularisation, L2_regularisation
Examples
# Example 1 - mnist data
# See the example at the mnist repository under user bentaylor1 on GitHub
# Example 2
N <- 1000
d <- matrix(rnorm(5 * N), ncol = 5)

# simulate a one-hot binary outcome whose class probability
# depends on the second covariate through a logistic link
fun <- function(x) {
    lp <- 2 * x[2]
    pr <- exp(lp) / (1 + exp(lp))
    ret <- c(0, 0)
    ret[1 + rbinom(1, 1, pr)] <- 1
    return(ret)
}

# convert the matrix rows into a list of input vectors
d <- lapply(1:N, function(i) {
    return(d[i, ])
})
truth <- lapply(d, fun)

net <- network(dims = c(5, 10, 2),
               activ = list(ReLU(), softmax()))

netwts <- train(dat = d,
                truth = truth,
                net = net,
                eps = 0.01,
                tol = 100,            # run for 100 iterations
                batchsize = 10,       # note this is not enough
                loss = multinomial(), # for convergence
                stopping = "maxit")

pred <- NNpredict(net = net,
                  param = netwts$opt,
                  newdata = d,
                  newtruth = truth,
                  record = TRUE,
                  plot = TRUE)
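A variation on Example 2, using the default predictive-probability stopping
rule rather than a fixed iteration count; a sketch under the documented
defaults (stopping = "default", update = "classification"), reusing d,
truth and net from above.

# stop once the mean predictive probability of the last 20 randomly
# monitored points exceeds tol = 0.95, instead of running a fixed
# number of iterations; reuses d, truth and net from Example 2
netwts2 <- train(dat = d,
                 truth = truth,
                 net = net,
                 eps = 0.01,
                 tol = 0.95,
                 batchsize = 10,
                 loss = multinomial())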