continue_training {cito}    R Documentation

Continues training of a model generated with dnn or cnn for additional epochs.

Description

If the training or validation loss is still decreasing at the end of training, this is often a sign that the neural network has not yet converged. Instead of retraining the entire model from scratch, you can use this function to continue training the already fitted network.
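One way to judge whether a loss curve is "still decreasing" is to fit a straight line to the losses of the last few epochs and check whether the slope is clearly negative relative to the loss level. The sketch below does this in base R on a plain numeric vector of per-epoch losses. The helper name `loss_still_decreasing`, the window size and the relative tolerance are hypothetical choices for illustration, not part of cito. How to extract the per-epoch losses from a fitted model is not covered here; inspect the fitted object (for example with `str()`) to locate them.

```r
# Hypothetical helper (not part of cito): returns TRUE if the losses of the
# last `window` epochs still show a clearly negative linear trend.
loss_still_decreasing <- function(loss, window = 5, rel_tol = 0.005) {
  stopifnot(is.numeric(loss), length(loss) >= 2)
  recent <- tail(loss, min(window, length(loss)))
  epoch <- seq_along(recent)
  # Slope of the least-squares line through the most recent losses
  slope <- unname(coef(lm(recent ~ epoch))[2])
  # Compare the relative change per epoch against the tolerance
  (slope / mean(recent)) < -rel_tol
}

# Made-up loss curves for illustration only
converging <- c(1.00, 0.60, 0.45, 0.41, 0.40, 0.40, 0.40, 0.40, 0.40)
improving  <- c(1.00, 0.80, 0.65, 0.55, 0.47, 0.41, 0.36, 0.32, 0.29)

loss_still_decreasing(converging)  # flat tail: little to gain from more epochs
loss_still_decreasing(improving)   # still dropping: a candidate for continue_training()
```

A linear fit over a short window is a deliberately simple heuristic; noisy losses (for example with small batch sizes) may call for a longer window or a larger tolerance.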

Usage

continue_training(model, ...)

## S3 method for class 'citodnn'
continue_training(
  model,
  epochs = 32,
  data = NULL,
  device = NULL,
  verbose = TRUE,
  changed_params = NULL,
  ...
)

## S3 method for class 'citodnnBootstrap'
continue_training(
  model,
  epochs = 32,
  data = NULL,
  device = NULL,
  verbose = TRUE,
  changed_params = NULL,
  parallel = FALSE,
  ...
)

## S3 method for class 'citocnn'
continue_training(
  model,
  epochs = 32,
  X = NULL,
  Y = NULL,
  device = c("cpu", "cuda", "mps"),
  verbose = TRUE,
  changed_params = NULL,
  ...
)

Arguments

model

a model created by dnn or cnn

...

class-specific arguments

epochs

number of additional epochs to continue training for

data

matrix or data.frame. If not provided, the data from the original training is used.

device

can be used to override the device used in the previous training

verbose

print the training and validation loss of each epoch

changed_params

list of arguments to change compared to the original training setup; see dnn for which parameters can be changed
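As a sketch of how changed_params might be used, the example below continues training with a smaller learning rate. It assumes that `lr`, the learning-rate argument of dnn, is one of the training settings that changed_params can override; the variable name `new_params` is hypothetical. The model-fitting part only runs when cito and a working torch installation are available, mirroring the guard used in the Examples section.

```r
# Settings to override for the continued training (assumption: `lr` is
# accepted by changed_params, as it is a dnn() training argument)
new_params <- list(lr = 0.001)

if (requireNamespace("cito", quietly = TRUE) &&
    requireNamespace("torch", quietly = TRUE) &&
    torch::torch_is_installed()) {
  library(cito)
  set.seed(222)
  # Initial training with the default-sized learning rate
  nn.fit <- dnn(Sepal.Length ~ ., data = datasets::iris, epochs = 32,
                lr = 0.01, verbose = FALSE)
  # Continue for 16 more epochs with the smaller learning rate
  nn.fit <- continue_training(nn.fit, epochs = 16,
                              changed_params = new_params, verbose = FALSE)
}
```

Lowering the learning rate for the continued epochs is a common choice when the loss is still decreasing but fluctuating; whether it helps depends on the model and data.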

parallel

train the bootstrapped models in parallel

X

array. If not provided, X from the original training is used.

Y

vector, factor, numerical matrix or logical matrix. If not provided, Y from the original training is used.

Value

a model of class citodnn, citodnnBootstrap or citocnn created by dnn or cnn

Examples


if(torch::torch_is_installed()){
library(cito)

set.seed(222)
validation_set <- sample(c(1:nrow(datasets::iris)), 25)

# Build and train network
nn.fit <- dnn(Sepal.Length ~ ., data = datasets::iris[-validation_set,], epochs = 32)

# Continue training for another 32 epochs
nn.fit <- continue_training(nn.fit, epochs = 32)

# Use model on validation set
predictions <- predict(nn.fit, datasets::iris[validation_set,])
}


[Package cito version 1.1 Index]