mlr_callback_set {mlr3torch} | R Documentation
Base Class for Callbacks
Description
Base class from which callbacks should inherit (see section Inheriting). A callback set is a collection of functions that are executed at different stages of the training loop. They can be used to gain more control over the training process of a neural network without having to write everything from scratch.
When used in a torch learner, the CallbackSet is wrapped in a TorchCallback. The latter's parameter set represents the arguments of the CallbackSet's $initialize() method.
Inheriting
For each available stage (see section Stages), a public method $on_<stage>() can be defined.
The evaluation context (a ContextTorch) can be accessed via self$ctx, which contains the current state of the training loop.
This context is assigned at the beginning of the training loop and removed afterwards.
Different stages of a callback can communicate with each other by assigning values to fields of self.
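As a minimal sketch, the class below inherits from CallbackSet with R6 directly, defines three $on_<stage>() methods, and lets them share counters stored on self. The class name CallbackSetCounter and its fields are hypothetical, and the ContextTorch fields epoch and total_epochs are assumptions not documented on this page.

```r
library(R6)
library(mlr3torch)

# Hypothetical callback set: counts epochs and optimizer steps.
CallbackSetCounter <- R6Class("CallbackSetCounter",
  inherit = CallbackSet,
  public = list(
    # Fields shared between stages, declared up front
    n_epochs = 0L,
    n_batches = 0L,
    # Runs once before training: reset the counters
    on_begin = function() {
      self$n_epochs <- 0L
      self$n_batches <- 0L
    },
    # Runs after every optimizer step
    on_batch_end = function() {
      self$n_batches <- self$n_batches + 1L
    },
    # Runs at the end of each epoch; self$ctx is the ContextTorch.
    # ASSUMPTION: ctx$epoch and ctx$total_epochs exist.
    on_epoch_end = function() {
      self$n_epochs <- self$n_epochs + 1L
      message(sprintf("epoch %d of %d done, %d batches so far",
        self$ctx$epoch, self$ctx$total_epochs, self$n_batches))
    }
  )
)

cbs <- CallbackSetCounter$new()
cbs$stages  # the stages this callback set is active in
```

Outside of $train(), cbs$ctx is NULL. To be used in a learner, the class still has to be wrapped in a TorchCallback.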
State:
To be able to store information in the $model slot of a LearnerTorch, callbacks support a state API.
You can overload the public method $state_dict() to define what will be stored in learner$model$callbacks$<id> after training finishes.
This also requires implementing a $load_state_dict(state_dict) method that defines how a previously saved callback state is loaded into a different callback.
Note that $state_dict() should not include the parameter values that were used to initialize the callback.
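A sketch of the state API under simple assumptions: the hypothetical callback below records how long each epoch took, returns only that training-time record from $state_dict() (no initialization parameters), and restores it in $load_state_dict(). All class and field names are illustrative.

```r
library(R6)
library(mlr3torch)

# Hypothetical callback set that records per-epoch durations (seconds).
CallbackSetEpochTimes <- R6Class("CallbackSetEpochTimes",
  inherit = CallbackSet,
  public = list(
    times = numeric(0),
    start = NULL,
    on_epoch_begin = function() {
      self$start <- Sys.time()
    },
    on_epoch_end = function() {
      elapsed <- as.numeric(difftime(Sys.time(), self$start, units = "secs"))
      self$times <- c(self$times, elapsed)
    },
    # Only information gathered during training goes into the state
    state_dict = function() {
      list(times = self$times)
    },
    # Restore a state previously returned by $state_dict()
    load_state_dict = function(state_dict) {
      stopifnot(is.list(state_dict), is.numeric(state_dict$times))
      self$times <- state_dict$times
      invisible(NULL)
    }
  )
)
```

A freshly created instance that loads this state continues appending to the same record when training resumes.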
For creating custom callbacks, the function torch_callback() is recommended; it creates a CallbackSet and then wraps it in a TorchCallback.
To create a CallbackSet without wrapping it, the convenience function callback_set() can be used.
Both functions perform checks, for example that no stage name is misspelled.
Stages
- begin :: Run before the training loop begins.
- epoch_begin :: Run at the beginning of each epoch.
- batch_begin :: Run before the forward call.
- after_backward :: Run after the backward call.
- batch_end :: Run after the optimizer step.
- batch_valid_begin :: Run before the forward call in the validation loop.
- batch_valid_end :: Run after the forward call in the validation loop.
- valid_end :: Run at the end of validation.
- epoch_end :: Run at the end of each epoch.
- end :: Run after the last epoch.
- exit :: Run last, using on.exit().
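To make the order of the stages concrete, here is a plain-R sketch that walks through them as listed above. simulate_stages() is hypothetical and much simpler than the real training loop (for instance, it assumes validation runs in every epoch). It calls an on_<stage>() method on cb whenever one exists, honours ctx$terminate at the end of each epoch, and returns the sequence of stages reached.

```r
# Hypothetical simulation of the documented stage order (not mlr3torch code).
simulate_stages <- function(cb, n_epochs = 2L, n_batches = 2L, n_valid_batches = 1L) {
  trace <- character(0)
  run <- function(stage) {
    trace <<- c(trace, stage)
    hook <- cb[[paste0("on_", stage)]]
    if (is.function(hook)) hook()
  }
  train_loop <- function() {
    # 'exit' is registered via on.exit(), so it runs even if training fails
    on.exit(run("exit"), add = TRUE)
    run("begin")
    for (epoch in seq_len(n_epochs)) {
      run("epoch_begin")
      for (b in seq_len(n_batches)) {
        run("batch_begin")        # forward call would follow
        run("after_backward")     # after the backward call
        run("batch_end")          # after the optimizer step
      }
      for (b in seq_len(n_valid_batches)) {
        run("batch_valid_begin")  # before the validation forward call
        run("batch_valid_end")    # after the validation forward call
      }
      run("valid_end")
      run("epoch_end")
      if (isTRUE(cb$ctx$terminate)) break  # checked at the end of every epoch
    }
    run("end")
  }
  train_loop()
  trace
}

simulate_stages(list(), n_epochs = 1L, n_batches = 1L)
```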
Terminate Training
To stop training early, set the field $terminate of the ContextTorch (available as self$ctx) to TRUE.
This field is checked at the end of every epoch, and if it is TRUE, training stops.
This can, for example, be used to implement custom early stopping.
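A minimal sketch of custom early stopping, using only the $terminate field described above: the hypothetical callback stops training once a wall-clock budget is used up. The class name and the budget_secs argument are illustrative.

```r
library(R6)
library(mlr3torch)

# Hypothetical callback set: stop after budget_secs seconds of training.
CallbackSetTimeBudget <- R6Class("CallbackSetTimeBudget",
  inherit = CallbackSet,
  public = list(
    budget_secs = NULL,
    start = NULL,
    initialize = function(budget_secs) {
      stopifnot(is.numeric(budget_secs), length(budget_secs) == 1L)
      self$budget_secs <- budget_secs
    },
    on_begin = function() {
      self$start <- Sys.time()
    },
    # The training loop checks ctx$terminate after every epoch
    on_epoch_end = function() {
      elapsed <- as.numeric(difftime(Sys.time(), self$start, units = "secs"))
      if (elapsed >= self$budget_secs) {
        self$ctx$terminate <- TRUE
      }
    }
  )
)
```

The same three functions could instead be passed to torch_callback(), which would expose budget_secs as a parameter of the resulting TorchCallback.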
Public fields
ctx
(ContextTorch or NULL)
The evaluation context for the callback. This field should always be NULL except during the $train() call of the torch learner.
Active bindings
stages
(character())
The active stages of this callback set.
Methods
Public methods
Method print()
Prints the object.
Usage
CallbackSet$print(...)
Arguments
...
(any)
Currently unused.
Method state_dict()
Returns information that is kept in the LearnerTorch's state after training.
This information should be loadable into the callback using $load_state_dict() to be able to continue training.
This returns NULL by default.
Usage
CallbackSet$state_dict()
Method load_state_dict()
Loads the state dict into the callback to continue training.
Usage
CallbackSet$load_state_dict(state_dict)
Arguments
state_dict
(any)
The state dict as retrieved via $state_dict().
Method clone()
The objects of this class are cloneable with this method.
Usage
CallbackSet$clone(deep = FALSE)
Arguments
deep
Whether to make a deep clone.
See Also
Other Callback: TorchCallback, as_torch_callback(), as_torch_callbacks(), callback_set(), mlr3torch_callbacks, mlr_callback_set.checkpoint, mlr_callback_set.progress, mlr_context_torch, t_clbk(), torch_callback()