train {ANN2}    R Documentation

Continue training of a Neural Network

Description

Continue training a neural network object returned by neuralnetwork() or autoencoder().

Usage

train(
  object,
  X,
  y = NULL,
  n.epochs = 100,
  batch.size = 32,
  drop.last = TRUE,
  val.prop = 0.1,
  random.seed = NULL
)

Arguments

object

object of class ANN produced by neuralnetwork() or autoencoder()

X

matrix with explanatory variables

y

matrix with dependent variables. Not required if object is an autoencoder.

n.epochs

the number of epochs to train, where one epoch is a single pass through the training data. This parameter largely determines the training time.

batch.size

the number of observations to use in each batch. Learning on batches is computationally faster than pure stochastic gradient descent, which updates the weights after every single observation (batch.size = 1). However, large batches might not result in optimal learning; see "Efficient BackProp" by LeCun et al. for details.
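To make the roles of n.epochs and batch.size concrete, the following is a minimal base-R sketch of mini-batch gradient descent. It is not ANN2's implementation (ANN2 trains networks in C++ and supports several optimizers); a linear least-squares model stands in for the network, and the function minibatch_gd() and its arguments lr and seed are hypothetical. One epoch is one pass over a reshuffled copy of the rows, and every batch of batch.size rows produces one weight update. This sketch keeps the final, smaller batch, which corresponds to drop.last = FALSE.

```r
# Hypothetical sketch: mini-batch gradient descent for linear least squares.
# Not ANN2's implementation; it only illustrates epochs and batches.
minibatch_gd <- function(X, y, n.epochs = 100, batch.size = 32, lr = 0.1, seed = 1) {
  set.seed(seed)
  n <- nrow(X)
  w <- rep(0, ncol(X))                         # start from zero weights
  for (epoch in seq_len(n.epochs)) {           # one epoch = one pass over the data
    idx <- sample(n)                           # reshuffle the rows each epoch
    for (start in seq(1, n, by = batch.size)) {
      # rows in this batch; the last batch may be smaller (drop.last = FALSE)
      b <- idx[start:min(start + batch.size - 1, n)]
      Xb <- X[b, , drop = FALSE]
      resid <- Xb %*% w - y[b]
      # gradient of half the mean squared error over the batch
      grad <- crossprod(Xb, resid) / length(b)
      w <- w - lr * as.vector(grad)            # one weight update per batch
    }
  }
  w
}

# Usage: recover intercept 2 and slope 3 from noise-free data
set.seed(42)
x <- runif(200)
X <- cbind(1, x)
y <- 2 + 3 * x
w <- minibatch_gd(X, y, n.epochs = 500, batch.size = 32)
round(w, 3)  # close to 2 and 3
```

With 200 rows and batch.size = 32, each epoch performs 7 updates (six full batches and one batch of 8 rows), so 500 epochs amount to 3500 updates.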

drop.last

logical. Only applicable if the size of the training set is not perfectly divisible by the batch size. Determines whether the remaining observations are discarded (in the current epoch) or form a smaller final batch. Note that a smaller batch leads to a noisier approximation of the gradient.
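The effect of drop.last on how one epoch is split into batches can be sketched in base R as follows. The helper make_batches() is hypothetical and only mirrors the behaviour described above; how ANN2 orders and partitions the rows internally is an assumption here.

```r
# Hypothetical helper: split shuffled row indices into batches.
# drop.last = TRUE discards a final batch that is smaller than batch.size.
make_batches <- function(n, batch.size, drop.last = TRUE) {
  idx <- sample(n)                               # random row order for this epoch
  batch_id <- ceiling(seq_len(n) / batch.size)   # 1,1,...,1,2,2,... batch labels
  batches <- split(idx, batch_id)
  last <- batches[[length(batches)]]
  if (drop.last && length(last) < batch.size) {
    batches <- batches[-length(batches)]         # discard the incomplete batch
  }
  unname(batches)
}

# 150 observations (as in iris) with batch.size = 32
set.seed(1)
b_drop <- make_batches(150, 32, drop.last = TRUE)
b_keep <- make_batches(150, 32, drop.last = FALSE)
lengths(b_drop)  # 32 32 32 32
lengths(b_keep)  # 32 32 32 32 22
```

With drop.last = TRUE, 22 of the 150 observations are not used in that epoch; because the rows are reshuffled every epoch, different observations are left out each time.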

val.prop

proportion of the training data to hold out as a validation set, on which the loss is tracked during training. Useful for assessing the training process and identifying possible overfitting. Set to zero to track the loss on the training data only.
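A hold-out split driven by val.prop might look like the base-R sketch below. The helper split_validation() is hypothetical, and rounding val.prop * n to the nearest integer is an assumption rather than a documented ANN2 rule. Setting val.prop = 0 leaves the validation set empty, so only the training loss can be tracked.

```r
# Hypothetical helper: hold out a proportion of rows for validation.
# Rounding val.prop * n to the nearest integer is an assumption, not ANN2's rule.
split_validation <- function(n, val.prop) {
  n_val <- round(val.prop * n)
  val_idx <- if (n_val > 0) sample(n, n_val) else integer(0)
  list(train = setdiff(seq_len(n), val_idx), val = val_idx)
}

set.seed(1)
s <- split_validation(nrow(iris), val.prop = 0.3)
c(train = length(s$train), val = length(s$val))  # train 105, val 45

# With val.prop = 0 there is no validation set
length(split_validation(nrow(iris), val.prop = 0)$val)  # 0
```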

random.seed

optional seed for the random number generator

Details

Each call to train() randomly chooses a new validation set. This can result in irregular jumps in the validation loss shown by plot.ANN() at the epoch where the continued training starts.
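Why the validation loss can jump is easiest to see by drawing the validation rows twice. In the base-R sketch below, the helper draw_val_rows() is hypothetical, and the assumption that a fixed random.seed reproduces the same draw through set.seed() is an illustration rather than a statement about how ANN2 uses its seed internally. Two unseeded draws of 45 rows out of 150 share only about 45 * 45 / 150 = 13.5 rows on average, so after the new draw the validation loss is measured largely on different observations.

```r
# Hypothetical illustration: each call draws a fresh validation set,
# unless a seed is fixed (how ANN2 applies random.seed is an assumption).
draw_val_rows <- function(n, val.prop, random.seed = NULL) {
  if (!is.null(random.seed)) set.seed(random.seed)
  sort(sample(n, round(val.prop * n)))
}

first  <- draw_val_rows(150, 0.3)
second <- draw_val_rows(150, 0.3)
length(intersect(first, second))  # on average about 13.5 of 45 rows are shared

a <- draw_val_rows(150, 0.3, random.seed = 123)
b <- draw_val_rows(150, 0.3, random.seed = 123)
identical(a, b)  # same seed, same validation rows
```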

Value

An ANN object. Use plot(<object>) to assess the loss on the training data, and optionally on the validation data, during the training process. Use predict(<object>, <newdata>) for prediction.

References

LeCun, Yann A., et al. "Efficient BackProp." In Neural Networks: Tricks of the Trade, pp. 9-48. Springer Berlin Heidelberg, 2012.

Examples

# Train a neural network on the iris dataset
X <- iris[,1:4]
y <- iris$Species
NN <- neuralnetwork(X, y, hidden.layers = 10, sgd.momentum = 0.9, 
                    learn.rates = 0.01, val.prop = 0.3, n.epochs = 100)

# Plot training and validation loss during training
plot(NN)

# Continue training for 200 more epochs
NN <- train(NN, X, y, n.epochs = 200, val.prop = 0.3)

# Again plot the loss - note the jump in the validation loss at the 100th epoch
# This is due to the random selection of a new validation set
plot(NN)

[Package ANN2 version 2.3.4 Index]