luz_callback_auto_resume {luz} | R Documentation
Resume training callback
Description
This callback allows you to resume training a model.
Usage
luz_callback_auto_resume(path = "./state.pt")
Arguments
path
Path to save state files for the model.
Details
When this callback is used, the model weights and the optimizer state are serialized at the end of each epoch. If something fails during training, simply re-running the same script will restart training from the epoch right after the last epoch that was serialized.
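The resume behaviour can be illustrated with a toy model in plain R that needs neither torch nor luz. This is only a minimal sketch under stated assumptions: run_epochs() is a hypothetical helper, saveRDS()/readRDS() stand in for luz's own serialization, and incrementing a number stands in for one epoch of training. It is not how luz_callback_auto_resume() is implemented.

```r
# Hypothetical helper: a toy model of epoch-level auto-resume in plain R.
# It mimics the documented behaviour; it is NOT luz's implementation.
run_epochs <- function(n_epochs, path, fail_at = NULL) {
  state <- list(epoch = 0, weight = 0)
  # If a state file exists, restore it and continue after the saved epoch.
  if (file.exists(path)) {
    state <- readRDS(path)
  }
  start <- state$epoch + 1
  epochs <- if (start <= n_epochs) start:n_epochs else integer(0)
  completed <- integer(0)
  for (epoch in epochs) {
    state$weight <- state$weight + 1  # stand-in for one epoch of training
    if (!is.null(fail_at) && epoch == fail_at) {
      stop("Error on epoch ", epoch)  # fails before this epoch is saved
    }
    state$epoch <- epoch
    saveRDS(state, path)              # serialize at the end of every epoch
    completed <- c(completed, epoch)
  }
  list(state = state, completed = completed)
}

path <- tempfile(fileext = ".rds")
try(run_epochs(10, path, fail_at = 5), silent = TRUE)  # epochs 1-4 are saved
res <- run_epochs(10, path)  # resumes at epoch 5 and runs to epoch 10
res$completed
```

Because the failure happens before epoch 5 is written, the state file still holds epoch 4, so the second call runs epochs 5 to 10 only.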
Customizing serialization
By default, the model, the optimizer state and the records are serialized. Callbacks can customize serialization by implementing the state_dict() and load_state_dict() methods. If those methods are implemented, state_dict() is called at the end of each epoch and load_state_dict() is called when the model is resumed.
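A callback that keeps its own state can opt into this mechanism as sketched below. The callback name "epoch_counter" and its count field are hypothetical, and the sketch assumes that luz passes the object returned by state_dict() back to load_state_dict() as its single argument.

```r
library(luz)

# Hypothetical callback that counts finished epochs and asks
# luz_callback_auto_resume() to save and restore that counter.
callback_epoch_counter <- luz_callback(
  "epoch_counter",
  count = 0,
  on_epoch_end = function() {
    self$count <- self$count + 1
  },
  # Called at the end of each epoch; return any serializable R object.
  state_dict = function() {
    list(count = self$count)
  },
  # Called on resume; assumed to receive what state_dict() returned.
  load_state_dict = function(state_dict) {
    self$count <- state_dict$count
  }
)

counter <- callback_epoch_counter()
```

The instance is then passed to fit() together with the auto-resume callback, for example callbacks = list(counter, luz_callback_auto_resume(path = path)).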
Note
In general, you will want to add this callback as the last one in the callbacks list. This way, the serialized state is likely to contain all changes that other callbacks could have made at 'on_epoch_end'. The default weight attribute of this callback is Inf.
Read the checkpointing article on the pkgdown website for more information.
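A callbacks list following this advice could look like the sketch below. Early stopping is used only as an example of another callback, and the monitor and patience values are arbitrary choices for illustration.

```r
library(luz)

# Other callbacks first; the auto-resume callback goes last so the state
# it serializes reflects changes made by the others in 'on_epoch_end'.
path <- tempfile(fileext = ".pt")
callbacks <- list(
  luz_callback_early_stopping(monitor = "train_loss", patience = 2),
  luz_callback_auto_resume(path = path)
)
# Then, e.g.: results <- model %>% fit(data, callbacks = callbacks)
```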
See Also
Other luz_callbacks: luz_callback_csv_logger(), luz_callback_early_stopping(), luz_callback_interrupt(), luz_callback_keep_best_model(), luz_callback_lr_scheduler(), luz_callback_metrics(), luz_callback_mixup(), luz_callback_model_checkpoint(), luz_callback_profile(), luz_callback_progress(), luz_callback_resume_from_checkpoint(), luz_callback_train_valid(), luz_callback()
Examples
if (torch::torch_is_installed()) {
  library(torch)
  library(luz)

  x <- torch_randn(1000, 10)
  y <- torch_randn(1000, 1)

  model <- nn_linear %>%
    setup(optimizer = optim_sgd, loss = nnf_mse_loss) %>%
    set_hparams(in_features = 10, out_features = 1) %>%
    set_opt_hparams(lr = 0.01)

  # Simulate a failure at the end of epoch 5 that happens only once.
  callback_stop <- luz_callback(
    "interrupt",
    failed = FALSE,
    on_epoch_end = function() {
      if (ctx$epoch == 5 && !self$failed) {
        self$failed <- TRUE
        stop("Error on epoch 5")
      }
    }
  )

  path <- tempfile()
  autoresume <- luz_callback_auto_resume(path = path)
  interrupt <- callback_stop()

  # The first attempt fails at epoch 5.
  try({
    results <- model %>% fit(
      list(x, y),
      callbacks = list(autoresume, interrupt),
      verbose = FALSE
    )
  })

  # The second attempt resumes from the saved state and completes.
  results <- model %>% fit(
    list(x, y),
    callbacks = list(autoresume, interrupt),
    verbose = FALSE
  )

  get_metrics(results)
}