train_nn {simpleMLP}    R Documentation

Train Network

Description

Train a multilayer perceptron with the specified hyperparameters and return the trained model.

Usage

train_nn(
  train_data,
  train_target,
  validate_data,
  validate_target,
  model,
  alpha,
  epochs,
  batch_size = nrow(train_data),
  plot_acc = TRUE
)

Arguments

train_data

training data, with one observation per row

train_target

training targets in one-hot encoded form, with one row per row of train_data
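If the labels start out as integer class codes, a one-hot target matrix can be built in base R. The sketch below uses a hypothetical helper, one_hot(), which is not part of simpleMLP, and assumes the classes are coded 0 to n_classes - 1 (as for MNIST digits).

```r
# Hypothetical helper (not part of simpleMLP): convert integer class labels
# 0..(n_classes - 1) into a one-hot matrix with one row per observation.
one_hot <- function(labels, n_classes = max(labels) + 1) {
  # Row k of the identity matrix is the one-hot vector for class k - 1
  diag(n_classes)[labels + 1, , drop = FALSE]
}

labels <- c(3, 0, 9, 3)
targets <- one_hot(labels, n_classes = 10)
dim(targets)              # 4 10
which(targets[1, ] == 1)  # 4, the column for digit 3
```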

validate_data

validation data, in the same format as train_data

validate_target

validation targets in one-hot encoded form, with one row per row of validate_data

model

model to train: a list of weights and biases, such as the one returned by init_nn()

alpha

learning rate
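The learning rate scales each parameter update. As an illustration only, assuming a plain gradient-descent update (this page does not state which optimiser train_nn uses), one step over a list of weights and biases could look like the sketch below. gd_step() and the gradient values are hypothetical.

```r
# Sketch of one plain gradient-descent step, assuming the model and its
# gradients are lists of numeric matrices/vectors with matching shapes.
# The gradient values here are made up for illustration only.
gd_step <- function(model, grads, alpha) {
  # Move every parameter a step of size alpha against its gradient
  Map(function(param, grad) param - alpha * grad, model, grads)
}

model <- list(W = matrix(1, 2, 2), b = c(0, 0))
grads <- list(W = matrix(0.5, 2, 2), b = c(1, -1))
updated <- gd_step(model, grads, alpha = 0.1)
updated$W  # every entry is 0.95
updated$b  # -0.1  0.1
```

A larger alpha takes bigger steps and can speed up training, but too large a value can make the loss diverge.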

epochs

number of epochs (full passes over the training data)

batch_size

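mini-batch size (number of training rows per batch); the default, nrow(train_data), uses the whole training set as a single batch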
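To show how batch_size partitions the training rows, here is a minimal sketch using a hypothetical helper, make_batches(), which is not taken from simpleMLP. It assumes the rows are shuffled once per epoch, which this page does not confirm.

```r
# Hypothetical helper: shuffle row indices once per epoch and split them into
# mini-batches of at most batch_size rows (the last batch may be smaller).
make_batches <- function(n_rows, batch_size = n_rows) {
  idx <- sample.int(n_rows)
  split(idx, ceiling(seq_along(idx) / batch_size))
}

set.seed(1)
batches <- make_batches(10, batch_size = 4)
lengths(batches)          # batch sizes 4, 4, 2
length(make_batches(10))  # 1: the default gives full-batch training
```

Assuming one parameter update per mini-batch, an epoch over n rows performs ceiling(n / batch_size) updates; for example, 1000 rows with batch_size = 64 give 16 updates per epoch.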

plot_acc

logical; if TRUE (the default), plot training and validation accuracy
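Accuracy for one-hot targets is usually the fraction of rows whose highest-scoring predicted class matches the target class. The sketch below assumes that definition; accuracy() is a hypothetical helper, and the scores are made-up values, not output from train_nn.

```r
# Hypothetical helper: classification accuracy from one-hot targets and a
# matrix of predicted class scores (one row per observation). The predicted
# class is the column with the largest score.
accuracy <- function(scores, targets) {
  pred  <- max.col(scores, ties.method = "first")
  truth <- max.col(targets, ties.method = "first")
  mean(pred == truth)
}

targets <- diag(3)[c(1, 2, 3, 2), ]
scores  <- rbind(c(0.7, 0.2, 0.1),
                 c(0.1, 0.8, 0.1),
                 c(0.6, 0.3, 0.1),
                 c(0.2, 0.5, 0.3))
accuracy(scores, targets)  # 0.75 (the third row is misclassified)
```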

Value

the trained model: a list of weights and biases in the same form as model

Examples

## Not run: 
mlp_model <- init_nn(784, 100, 50, 10)
mnist <- load_mnist()
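# Use [[ ]] to extract the elements themselves rather than one-element lists
train_data <- mnist[[1]]
train_target <- mnist[[2]]
validate_data <- mnist[[3]]
validate_target <- mnist[[4]]
mlp_model <- train_nn(train_data, train_target, validate_data,
                      validate_target, mlp_model,
                      alpha = 0.01, epochs = 1, batch_size = 64)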

## End(Not run)

[Package simpleMLP version 1.0.0 Index]