train {ANN2}    R Documentation
Continue training of a neural network object returned by neuralnetwork() or autoencoder()
train(
  object,
  X,
  y = NULL,
  n.epochs = 100,
  batch.size = 32,
  drop.last = TRUE,
  val.prop = 0.1,
  random.seed = NULL
)
object | object of class ANN, as returned by neuralnetwork() or autoencoder()
X | matrix with explanatory variables
y | matrix with dependent variables. Not required if object is an autoencoder.
n.epochs | the number of epochs to train for (one epoch is a single pass through the training data). This parameter largely determines the training time.
batch.size | the number of observations used in each batch. Batch learning is computationally faster than stochastic gradient descent on single observations. However, large batches may not result in optimal learning; see LeCun et al. (2012), "Efficient BackProp", for details.
drop.last | logical. Only applicable if the size of the training set is not perfectly divisible by the batch size. Determines whether the leftover observations are discarded in the current epoch or form a smaller final batch. Note that a smaller batch leads to a noisier approximation of the gradient.
val.prop | proportion of the training data to hold out as a validation set, on which the loss is tracked during training. Useful for assessing the training process and identifying possible overfitting. Set to zero to track the loss on the training data only.
random.seed | optional seed for the random number generator
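To make the interplay of batch.size and drop.last concrete, the following base-R sketch splits the training rows of one epoch into mini-batches. The helper make_batches() is hypothetical and is not part of ANN2; it only mirrors the behavior described for these two arguments, and ANN2's internal batching (for example, whether and how rows are shuffled) may differ.

```r
# Hypothetical helper (NOT part of ANN2): split the training rows of one
# epoch into mini-batches, mirroring the batch.size / drop.last description.
make_batches <- function(n.train, batch.size = 32, drop.last = TRUE) {
  idx <- sample.int(n.train)               # assumption: shuffle rows each epoch
  n.full <- n.train %/% batch.size         # number of complete batches
  batches <- split(idx[seq_len(n.full * batch.size)],
                   rep(seq_len(n.full), each = batch.size))
  batches <- unname(batches)
  n.left <- n.train %% batch.size          # rows that do not fill a batch
  if (!drop.last && n.left > 0) {
    # keep the leftovers as one smaller (noisier) final batch
    batches[[n.full + 1]] <- idx[(n.full * batch.size + 1):n.train]
  }
  batches
}

set.seed(42)
lengths(make_batches(105, batch.size = 32, drop.last = TRUE))   # 32 32 32
lengths(make_batches(105, batch.size = 32, drop.last = FALSE))  # 32 32 32 9
```

Under these assumptions, 105 training rows with batch.size = 32 give 3 batches per epoch when drop.last = TRUE (the 9 leftover rows are skipped in that epoch), so n.epochs = 100 would correspond to 300 mini-batch updates; drop.last = FALSE adds a fourth batch of 9 rows.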
Each call to train() randomly selects a new validation set. This can cause irregular jumps in the validation loss shown by plot.ANN() at the epoch where training was resumed.
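The resampling of the validation set can be illustrated with a short base-R sketch. split_validation() is a hypothetical helper, not ANN2's implementation, and rounding val.prop * n to a whole number of rows is an assumption. Because each call draws a fresh set of validation rows, the validation loss after resuming training is computed on partly different data than before, which is what produces the jump.

```r
# Hypothetical helper (NOT part of ANN2): draw a random train/validation
# split of n rows, holding out a proportion val.prop for validation.
split_validation <- function(n, val.prop = 0.1) {
  n.val <- round(val.prop * n)             # assumption: rounding rule
  val <- sort(sample.int(n, n.val))        # randomly chosen validation rows
  list(train = setdiff(seq_len(n), val), val = val)
}

# Two independent draws, as when training is started and later resumed
set.seed(1)
first <- split_validation(150, val.prop = 0.3)
set.seed(2)
second <- split_validation(150, val.prop = 0.3)

length(first$val)                          # 45 validation rows in each draw
length(intersect(first$val, second$val))   # typically only a partial overlap
```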
Returns an ANN object. Use plot(<object>) to assess the loss on the training data (and, optionally, the validation data) during the training process. Use predict(<object>, <newdata>) for prediction.
LeCun, Y. A., et al. (2012). "Efficient BackProp." In Neural Networks: Tricks of the Trade (pp. 9-48). Springer Berlin Heidelberg.
# Train a neural network on the iris dataset
X <- iris[,1:4]
y <- iris$Species
NN <- neuralnetwork(X, y, hidden.layers = 10, sgd.momentum = 0.9,
                    learn.rates = 0.01, val.prop = 0.3, n.epochs = 100)
# Plot training and validation loss during training
plot(NN)
# Continue training for another 200 epochs
train(NN, X, y, n.epochs = 200, val.prop = 0.3)
# Plot the loss again: note the jump in the validation loss at epoch 100,
# which is due to the random selection of a new validation set
plot(NN)