cnn {cito}    R Documentation

CNN

Description

Fits a custom convolutional neural network.

Usage

cnn(
  X,
  Y = NULL,
  architecture,
  loss = c("mse", "mae", "softmax", "cross-entropy", "gaussian", "binomial", "poisson"),
  optimizer = c("sgd", "adam", "adadelta", "adagrad", "rmsprop", "rprop"),
  lr = 0.01,
  alpha = 0.5,
  lambda = 0,
  validation = 0,
  batchsize = 32L,
  burnin = 10,
  shuffle = TRUE,
  epochs = 100,
  early_stopping = NULL,
  lr_scheduler = NULL,
  custom_parameters = NULL,
  device = c("cpu", "cuda", "mps"),
  plot = TRUE,
  verbose = TRUE
)

Arguments

X

predictor: array of dimension 3, 4 or 5 for 1D-, 2D- or 3D-convolutions, respectively. The first dimension contains the samples, the second dimension the channels, and the third to fifth dimensions are the input dimensions (an example sketch follows the Arguments list)

Y

response: vector, factor, numerical matrix or logical matrix

architecture

'citoarchitecture' object created by create_architecture

loss

loss function the network is optimized for. Can also be a distribution from the stats package or a custom function, see Details

optimizer

optimizer used for training the network; for further adjustments to the optimizer see config_optimizer

lr

learning rate given to the optimizer

alpha

adds L1/L2 (elastic net) regularization to the training: (1 - alpha) * |weights| + alpha * ||weights||^2 is added to the loss for each layer. Must be between 0 and 1

lambda

strength of the regularization: the penalty lambda * (L1 + L2) is added to the loss (see alpha)

validation

percentage of data set that should be taken as validation set (chosen randomly)

batchsize

number of samples used to calculate each optimization step

burnin

training is aborted if the training loss is not below the baseline loss after burnin epochs

shuffle

if TRUE, data in each batch gets reshuffled every epoch

epochs

number of epochs to train for

early_stopping

if set to an integer, training stops when the loss has increased for that many consecutive epochs; the validation loss is used if available.

lr_scheduler

learning rate scheduler created with config_lr_scheduler

custom_parameters

List of parameters/variables to be optimized. Can be used in a custom loss function. See the vignette for an example.

device

device on which the network should be trained.

plot

plot training loss

verbose

print the training and validation loss for each epoch
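
A minimal sketch of a call with these arguments (hypothetical data; the layer helpers conv(), maxPool() and linear() and their argument names are assumed from create_architecture, see ?create_architecture):

library(cito)

# 100 samples, 3 channels, 32 x 32 pixels: a 4-dimensional array for 2D convolutions
X <- array(runif(100 * 3 * 32 * 32), dim = c(100, 3, 32, 32))
Y <- factor(sample(c("a", "b", "c"), 100, replace = TRUE))   # multi-class response

# layer helpers and argument names assumed from ?create_architecture
arch <- create_architecture(conv(n_kernels = 8, kernel_size = 3),
                            maxPool(kernel_size = 2),
                            linear(n_neurons = 20))

m <- cnn(X, Y,
         architecture = arch,
         loss = "softmax",       # categorical cross entropy for the factor response
         lr = 0.01, epochs = 50, validation = 0.2, device = "cpu")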

Value

An S3 object of class "citocnn" is returned. It is a list containing everything there is to know about the model and its training process. The list consists of the following attributes:

net

An object of classes "nn_sequential" and "nn_module"; it originates from the torch package and represents the core object of this workflow.

call

The original function call

loss

A list which contains relevant information for the target variable and the used loss function

data

Contains data used for training the model

weights

List of weights for each training epoch

use_model_epoch

Integer defining which training epoch's model is used for prediction.

loaded_model_epoch

Integer indicating which epoch's model is currently loaded into model$net.

model_properties

A list of properties of the neural network, contains number of input nodes, number of output nodes, size of hidden layers, activation functions, whether bias is included and if dropout layers are included.

training_properties

A list of all training parameters used the last time the model was trained: the learning rate, information about the learning rate scheduler and the optimizer, the number of epochs, whether early stopping was used, whether plotting was active, lambda and alpha for L1/L2 regularization, the batch size, whether the data was shuffled, whether the data set was split into training and validation sets, the formula used for training, and the epoch at which training stopped.

losses

A data.frame containing training and validation losses of each epoch
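
For orientation, a short sketch of how some of these components can be inspected on a fitted model (here called m, as in the example above):

class(m)               # "citocnn"
m$call                 # the original function call
head(m$losses)         # training and validation loss per epoch
m$use_model_epoch      # epoch whose weights are used for prediction
m$loaded_model_epoch   # epoch currently loaded into m$net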

Convolutional neural networks:

Convolutional Neural Networks (CNNs) are a specialized type of neural network designed for processing structured grid data, such as images. The characterizing parts of the architecture are convolutional layers, pooling layers and fully-connected (linear) layers.

Loss functions / Likelihoods

We support loss functions and likelihoods for different tasks:

Name           Explanation                Example / Task
mse            mean squared error         Regression, predicting continuous values
mae            mean absolute error        Regression, predicting continuous values
softmax        categorical cross entropy  Multi-class, species classification
cross-entropy  categorical cross entropy  Multi-class, species classification
gaussian       Normal likelihood          Regression, residual error is also estimated (similar to stats::lm())
binomial       Binomial likelihood        Classification/logistic regression, mortality
poisson        Poisson likelihood         Regression, count data, e.g. species abundances
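
Hedged sketches of matching the loss to the response type, reusing the hypothetical X and arch from the example above:

Y_counts <- rpois(100, lambda = 2)                  # count response
m_pois <- cnn(X, Y_counts, architecture = arch, loss = "poisson")

Y_binary <- rbinom(100, size = 1, prob = 0.5)       # 0/1 response
m_bin <- cnn(X, Y_binary, architecture = arch, loss = "binomial")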

Training and convergence of neural networks

Ensuring convergence can be tricky when training neural networks. Their training is sensitive to a combination of the learning rate (how much the weights are updated in each optimization step), the batch size (a random subset of the data is used in each optimization step), and the number of epochs (number of optimization steps). Typically, the learning rate should be decreased with the size of the neural network (the number of learnable parameters). We provide a baseline loss (intercept-only model) that can give hints about an appropriate learning rate.

Learning rates

If the training loss of the model doesn't fall below the baseline loss, the learning rate is either too high or too low. If this happens, try higher and lower learning rates.

A common strategy is to try (manually) a few different learning rates to see if the learning rate is on the right scale.
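
For example, a small manual scan over learning rates could look like this (hypothetical, reusing X, Y and arch from the example above; the final losses are read from the losses component of the returned objects):

lrs <- c(0.1, 0.01, 0.001)
fits <- lapply(lrs, function(lr) {
  cnn(X, Y, architecture = arch, loss = "softmax",
      lr = lr, epochs = 50, plot = FALSE, verbose = FALSE)
})
lapply(fits, function(m) tail(m$losses, 1))   # last-epoch losses per learning rate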

See the troubleshooting vignette (vignette("B-Training_neural_networks")) for more help on training and debugging neural networks.

Finding the right architecture

As with the learning rate, there is no definitive guide to choosing the right architecture for the right task. However, there are some general rules/recommendations: In general, wider and deeper neural networks can improve generalization - but this is a double-edged sword because it also increases the risk of overfitting. So, if you increase the width and depth of the network, you should also add regularization (e.g., by increasing the lambda parameter, which corresponds to the regularization strength). Furthermore, in Pichler & Hartig, 2023, we investigated the effects of the hyperparameters on the prediction performance as a function of the data size. For example, we found that the selu activation function outperforms relu for small data sizes (<100 observations).

We recommend starting with moderate sizes (like the defaults), and if the model doesn't generalize/converge, try larger networks along with a regularization that helps minimize the risk of overfitting (see vignette("B-Training_neural_networks") ).

Overfitting

Overfitting means that the model fits the training data well but generalizes poorly to new observations. We can use the validation argument to detect overfitting: if the validation loss starts to increase again at some point, the model is likely starting to overfit the training data.
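
A sketch of combining validation with early_stopping (hypothetical, reusing X, Y and arch from the example above):

# hold out 20% of the samples; stop if the loss rises for 10 epochs in a row
m_val <- cnn(X, Y, architecture = arch, loss = "softmax",
             validation = 0.2, early_stopping = 10, epochs = 200)
# m_val$losses contains the training and validation loss of each epoch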

Solutions:

Regularization

Elastic Net regularization combines the strengths of L1 (Lasso) and L2 (Ridge) regularization. It introduces a penalty term that encourages sparse weight values while maintaining overall weight shrinkage. By controlling the sparsity of the learned model, Elastic Net regularization helps avoid overfitting while allowing for meaningful feature selection. We advise using elastic net (e.g. lambda = 0.001 and alpha = 0.2).
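
A sketch with the values suggested above (hypothetical, reusing X, Y and arch from the example above):

m_en <- cnn(X, Y, architecture = arch, loss = "softmax",
            lambda = 0.001,   # overall regularization strength
            alpha  = 0.2)     # mix: (1 - alpha) * L1 + alpha * L2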

Dropout regularization helps prevent overfitting by randomly disabling a portion of neurons during training. This technique encourages the network to learn more robust and generalized representations, as it prevents individual neurons from relying too heavily on specific input patterns. Dropout has been widely adopted as a simple yet effective regularization method in deep learning. For 2D and 3D inputs, whole feature maps are disabled. Since the torch package does not currently support feature-map-wise dropout for 1D inputs, random neurons within the feature maps are disabled instead, similar to dropout in linear layers.
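
A sketch of enabling dropout through the architecture, assuming the layer helpers accept a dropout argument (check ?create_architecture for the actual layer arguments):

arch_drop <- create_architecture(conv(n_kernels = 8, kernel_size = 3, dropout = 0.2),
                                 maxPool(kernel_size = 2),
                                 linear(n_neurons = 20, dropout = 0.3))
m_drop <- cnn(X, Y, architecture = arch_drop, loss = "softmax")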

By utilizing these regularization methods in your neural network training with the cito package, you can improve generalization performance and enhance the network's ability to handle unseen data. These techniques act as valuable tools in mitigating overfitting and promoting more robust and reliable model performance.

Custom Optimizer and Learning Rate Schedulers

When training a network, you have the flexibility to customize the optimizer settings and learning rate scheduler to optimize the learning process. In the cito package, you can initialize these configurations using the config_lr_scheduler and config_optimizer functions.

config_lr_scheduler allows you to define a specific learning rate scheduler that controls how the learning rate changes over time during training. This is beneficial in scenarios where you want to adaptively adjust the learning rate to improve convergence or avoid getting stuck in local optima.

Similarly, the config_optimizer function enables you to specify the optimizer for your network. Different optimizers, such as stochastic gradient descent (SGD), Adam, or RMSprop, offer various strategies for updating the network's weights and biases during training. Choosing the right optimizer can significantly impact the training process and the final performance of your neural network.
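
A sketch of passing both configurations to cnn(); the argument names other than type are assumptions, see ?config_optimizer and ?config_lr_scheduler for the supported options:

opt   <- config_optimizer(type = "adam")
sched <- config_lr_scheduler(type = "step", step_size = 25, gamma = 0.5)

m_cfg <- cnn(X, Y, architecture = arch, loss = "softmax",
             optimizer = opt, lr = 0.01, lr_scheduler = sched)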

Training on graphic cards

If you have an NVIDIA CUDA-enabled device and have installed the CUDA toolkit version 11.3 and cuDNN 8.4, you can take advantage of GPU acceleration for training your neural networks. It is crucial to have these specific versions installed, as other versions may not be compatible. For detailed installation instructions and more information on utilizing GPUs for training, please refer to the mlverse: 'torch' documentation.

Note: GPU training is optional, and the package can still be used for training on CPU even without CUDA and cuDNN installations.
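
A sketch of selecting the device at run time (torch::cuda_is_available() is part of the torch package; X, Y and arch as in the example above):

dev <- if (torch::cuda_is_available()) "cuda" else "cpu"
m_gpu <- cnn(X, Y, architecture = arch, loss = "softmax", device = dev)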

Author(s)

Armin Schenk, Maximilian Pichler

See Also

predict.citocnn, plot.citocnn, coef.citocnn, print.citocnn, summary.citocnn, continue_training, analyze_training


[Package cito version 1.1 Index]