continue_training {cito}    R Documentation
Continues training of a model generated with dnn or cnn for additional epochs.
Description
If the training and/or validation loss is still decreasing at the end of training, this is often a sign that the neural network has not yet converged. In that case, you can use this function to continue training the existing model for additional epochs instead of re-training it from scratch.
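The "is the loss still decreasing?" check described above can be made explicit. The following is a minimal sketch of a hypothetical helper, still_decreasing(), which is not part of cito: it takes any numeric vector of per-epoch losses and tests whether the least-squares slope over the last few epochs is negative. The parameter names window and tol are illustrative assumptions.

```r
# Hypothetical helper, not part of cito: given a numeric vector of per-epoch
# losses, report whether the loss is still trending downward over the last
# `window` epochs. The trend is the least-squares slope of loss on epoch.
still_decreasing <- function(losses, window = 5, tol = 0) {
  stopifnot(is.numeric(losses), window >= 2)
  # Too few epochs to judge a trend
  if (length(losses) < window) return(NA)
  recent <- utils::tail(losses, window)
  epoch <- seq_along(recent)
  # Slope of the least-squares line through (epoch, loss)
  slope <- stats::cov(epoch, recent) / stats::var(epoch)
  # A slope below -tol counts as "still decreasing"
  slope < -tol
}
```

For example, if val_loss holds the validation losses of a fitted model, if (isTRUE(still_decreasing(val_loss))) nn.fit <- continue_training(nn.fit, epochs = 32) would continue training only while the trend is downward. How to extract val_loss from a fitted model is not covered here and may depend on the cito version.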
Usage
continue_training(model, ...)
## S3 method for class 'citodnn'
continue_training(
model,
epochs = 32,
data = NULL,
device = NULL,
verbose = TRUE,
changed_params = NULL,
...
)
## S3 method for class 'citodnnBootstrap'
continue_training(
model,
epochs = 32,
data = NULL,
device = NULL,
verbose = TRUE,
changed_params = NULL,
parallel = FALSE,
...
)
## S3 method for class 'citocnn'
continue_training(
model,
epochs = 32,
X = NULL,
Y = NULL,
device = c("cpu", "cuda", "mps"),
verbose = TRUE,
changed_params = NULL,
...
)
Arguments
model | a model created by dnn or cnn
... | class-specific arguments
epochs | number of additional epochs to continue training for
data | matrix or data.frame. If not provided, the data from the original training is used
device | device to train on; overrides the device used in the original training
verbose | if TRUE, print the training and validation loss of each epoch
changed_params | list of arguments to change compared to the original training setup; see the arguments of dnn or cnn
parallel | train the bootstrapped models in parallel
X | array. If not provided, X from the original training is used
Y | vector, factor, numerical matrix or logical matrix. If not provided, Y from the original training is used
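The sketch below combines several of these arguments in a single call: it continues a citodnn model on the CPU, suppresses per-epoch output, and lowers the learning rate through changed_params. It assumes that lr, an argument of dnn(), can be overridden this way; the epoch counts and the value 0.001 are arbitrary illustrative choices.

```r
if (requireNamespace("torch", quietly = TRUE) && torch::torch_is_installed()) {
  library(cito)
  set.seed(42)
  # Initial, deliberately short training run
  nn.fit <- dnn(Sepal.Length ~ ., data = datasets::iris, epochs = 10, verbose = FALSE)
  # Continue for 10 more epochs on the CPU, without per-epoch output, and
  # with a smaller learning rate. Assumption: `lr` (an argument of dnn())
  # can be overridden through changed_params.
  nn.fit <- continue_training(nn.fit,
                              epochs = 10,
                              device = "cpu",
                              verbose = FALSE,
                              changed_params = list(lr = 0.001))
}
```

Because data is left as NULL, the second call reuses the iris data from the original training run.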
Value
a model of class citodnn, citodnnBootstrap or citocnn, as created by dnn or cnn
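The class of the returned model matches the class of the input model. As a sketch for the citodnnBootstrap method, assuming that dnn()'s bootstrap argument sets the number of bootstrap samples and produces a citodnnBootstrap object:

```r
if (requireNamespace("torch", quietly = TRUE) && torch::torch_is_installed()) {
  library(cito)
  set.seed(42)
  # Assumption: `bootstrap = 3` fits three bootstrap models and yields an
  # object of class citodnnBootstrap
  boot.fit <- dnn(Sepal.Length ~ ., data = datasets::iris,
                  epochs = 5, bootstrap = 3, verbose = FALSE)
  # Continue all bootstrap models sequentially for 5 more epochs
  boot.fit <- continue_training(boot.fit, epochs = 5,
                                verbose = FALSE, parallel = FALSE)
  # The result should still be a citodnnBootstrap object
  print(class(boot.fit))
}
```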
Examples
if(torch::torch_is_installed()){
library(cito)
set.seed(222)
validation_set <- sample(1:nrow(datasets::iris), 25)
# Build and train the network on the remaining rows
nn.fit <- dnn(Sepal.Length~., data = datasets::iris[-validation_set,], epochs = 32)
# Continue training for another 32 epochs
nn.fit <- continue_training(nn.fit, epochs = 32)
# Use the model on the validation set
predictions <- predict(nn.fit, datasets::iris[validation_set,])
}