train_nn {simpleMLP} | R Documentation |
Train Network
Description
Train the network with specified hyperparameters and return the trained model.
Usage
train_nn(
  train_data,
  train_target,
  validate_data,
  validate_target,
  model,
  alpha,
  epochs,
  batch_size = nrow(train_data),
  plot_acc = TRUE
)
Arguments
train_data | set of training data
train_target | set of training data targets in one-hot encoded form
validate_data | set of validation data
validate_target | set of validation data targets in one-hot encoded form
model | list of weights and biases
alpha | learning rate
epochs | number of epochs
batch_size | mini-batch size; defaults to nrow(train_data), i.e. the whole training set in a single batch
plot_acc | whether to plot training and validation accuracy; defaults to TRUE
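The target arguments are expected in one-hot encoded form. As a minimal sketch of what that layout looks like, the helper below builds a one-hot matrix from a vector of class labels in base R. The name one_hot and its classes parameter are hypothetical and not part of simpleMLP; the sketch assumes one row per observation and one column per class.

```r
# Hypothetical helper (not part of simpleMLP): convert a vector of class
# labels into a one-hot matrix, one row per observation, one column per class.
one_hot <- function(labels, classes = sort(unique(labels))) {
  encoded <- matrix(0, nrow = length(labels), ncol = length(classes))
  colnames(encoded) <- classes
  # match() gives each label's column index; cbind() pairs it with the row
  # index so that exactly one cell per row is set to 1.
  encoded[cbind(seq_along(labels), match(labels, classes))] <- 1
  encoded
}

labels <- c(2, 0, 1, 2)
targets <- one_hot(labels, classes = 0:2)
print(targets)
```

Passing classes explicitly keeps the column order stable even when a class is missing from a particular subset of labels.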
Value
list of weights and biases after training
Examples
## Not run:
mlp_model <- init_nn(784, 100, 50, 10)
mnist <- load_mnist()
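# Use [[ ]] to extract each element itself rather than a one-element list
train_data <- mnist[[1]]
train_target <- mnist[[2]]
validate_data <- mnist[[3]]
validate_target <- mnist[[4]]
mlp_model <- train_nn(train_data, train_target, validate_data,
                      validate_target, mlp_model,
                      alpha = 0.01, epochs = 1, batch_size = 64)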
## End(Not run)
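To illustrate what a batch_size smaller than nrow(train_data) means, the sketch below partitions row indices into mini-batches for a single epoch. It is not taken from the train_nn source: the shuffling step and the variable names are assumptions made for illustration only.

```r
# Illustrative only: how a batch_size could partition the rows of one epoch.
n_rows <- 10       # stand-in for nrow(train_data)
batch_size <- 4

set.seed(1)
# Shuffling the rows is an assumption; train_nn may or may not do this.
shuffled <- sample(n_rows)

# Assign consecutive groups of batch_size indices to the same batch id.
batch_id <- ceiling(seq_along(shuffled) / batch_size)
batches <- split(shuffled, batch_id)

# Number of rows in each mini-batch
print(lengths(batches))
```

When batch_size does not divide the number of rows evenly, the last batch is smaller than the others; with the default batch_size = nrow(train_data) there is exactly one batch per epoch.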
[Package simpleMLP version 1.0.0 Index]