mlr_callback_set {mlr3torch} R Documentation

Base Class for Callbacks

Description

Base class from which callbacks should inherit (see section Inheriting). A callback set is a collection of functions that are executed at different stages of the training loop. They can be used to gain more control over the training process of a neural network without having to write everything from scratch.

When used in a torch learner, the CallbackSet is wrapped in a TorchCallback. The latter's parameter set represents the arguments of the CallbackSet's $initialize() method.
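A minimal sketch of this wrapping, assuming torch_callback() accepts an initialize function and a public list of fields, and that the TorchCallback infers its parameter set from the arguments of $initialize() when none is given. The id "epoch_printer" and the argument prefix are hypothetical.

```r
library(mlr3torch)

# Hypothetical callback whose $initialize() takes one argument, `prefix`.
# torch_callback() builds the CallbackSet and wraps it in a TorchCallback.
tcb <- torch_callback("epoch_printer",
  initialize = function(prefix = "epoch") {
    self$prefix <- prefix
  },
  on_epoch_end = function() {
    # Only runs during training, when self$ctx holds the ContextTorch
    cat(self$prefix, self$ctx$epoch, "\n")
  },
  public = list(prefix = NULL)
)

# The wrapper's parameter set mirrors the arguments of $initialize()
tcb$param_set$ids()
```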

Inheriting

For each available stage (see section Stages), a public method $on_<stage>() can be defined. The evaluation context (a ContextTorch), which contains the current state of the training loop, can be accessed via self$ctx. This context is assigned at the beginning of the training loop and removed afterwards. Different stages of a callback can communicate with each other by assigning values to fields of self.
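A minimal sketch of two stages sharing information through self, assuming callback_set() takes the stage hooks and a public list of fields as named arguments. The class name CallbackSetTimer and its fields start and elapsed are hypothetical. Outside of a learner's $train() call the context is NULL, so the hooks are invoked by hand purely for illustration.

```r
library(mlr3torch)

# Hypothetical callback: on_begin records the start time and on_end stores
# the elapsed time. The two stages communicate through fields on `self`.
CallbackSetTimer <- callback_set("CallbackSetTimer",
  on_begin = function() {
    self$start <- Sys.time()
  },
  on_end = function() {
    self$elapsed <- difftime(Sys.time(), self$start, units = "secs")
  },
  public = list(start = NULL, elapsed = NULL)
)

cb <- CallbackSetTimer$new()
# No training loop here, so call the stage methods directly
cb$on_begin()
cb$on_end()
cb$elapsed
```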

State: To store information in the $model slot of a LearnerTorch, callbacks support a state API. You can overload the public $state_dict() method to define what is stored in learner$model$callbacks$<id> after training finishes. This also requires implementing a $load_state_dict(state_dict) method that defines how a previously saved callback state is loaded into a different callback. Note that $state_dict() should not include the parameter values that were used to initialize the callback.
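A sketch of the state API under the assumption that callback_set() accepts state_dict and load_state_dict functions directly. The class CallbackSetEpochCounter and its field n_epochs are hypothetical; the saved state contains only what was learned during training, not initialization parameters.

```r
library(mlr3torch)

# Hypothetical callback that counts completed epochs and exposes the count
# through the state API.
CallbackSetEpochCounter <- callback_set("CallbackSetEpochCounter",
  on_epoch_end = function() {
    self$n_epochs <- self$n_epochs + 1
  },
  state_dict = function() {
    # Training state only; no initialization parameters
    list(n_epochs = self$n_epochs)
  },
  load_state_dict = function(state_dict) {
    self$n_epochs <- state_dict$n_epochs
  },
  public = list(n_epochs = 0)
)

cb <- CallbackSetEpochCounter$new()
cb$on_epoch_end()
cb$on_epoch_end()
state <- cb$state_dict()

# Restore the saved state into a fresh instance
cb2 <- CallbackSetEpochCounter$new()
cb2$load_state_dict(state)
cb2$n_epochs
```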

For creating custom callbacks, the function torch_callback() is recommended; it creates a CallbackSet and then wraps it in a TorchCallback. To create only a CallbackSet, the convenience function callback_set() can be used. Both functions perform checks, for example that stage names are not accidentally misspelled.
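As an illustration of the stage-name check, the sketch below passes a misspelled hook to callback_set(). The class name CallbackSetTypo is hypothetical, and since the exact error message may differ between package versions, only the failure itself is checked.

```r
library(mlr3torch)

# "on_epoch_ned" is a typo for "on_epoch_end". Creating the class fails,
# instead of yielding a callback whose hook would silently never run.
res <- try(
  callback_set("CallbackSetTypo", on_epoch_ned = function() NULL),
  silent = TRUE
)
inherits(res, "try-error")
```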

Stages

Terminate Training

To stop training early, set the field $terminate of the ContextTorch. This field is checked at the end of every epoch, and if it is TRUE, training stops. This can, for example, be used to implement custom early stopping.
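A minimal sketch of patience-based early stopping. It assumes, beyond what is stated here, that the context exposes validation scores as a named list in last_scores_valid and that the first score is one where lower is better. CallbackSetPatience, patience, best, and bad_epochs are hypothetical names. To run without training, the hook is exercised with a stand-in environment that holds only the two context fields it uses.

```r
library(mlr3torch)

# Hypothetical early stopping: terminate once the first validation score
# (assumed lower-is-better) has not improved for `patience` epochs.
CallbackSetPatience <- callback_set("CallbackSetPatience",
  initialize = function(patience = 2) {
    self$patience <- patience
  },
  on_epoch_end = function() {
    # Assumption: last_scores_valid is a named list of validation scores
    score <- self$ctx$last_scores_valid[[1L]]
    if (is.null(self$best) || score < self$best) {
      self$best <- score
      self$bad_epochs <- 0
    } else {
      self$bad_epochs <- self$bad_epochs + 1
    }
    # The training loop checks this flag at the end of every epoch
    if (self$bad_epochs >= self$patience) {
      self$ctx$terminate <- TRUE
    }
  },
  public = list(patience = NULL, best = NULL, bad_epochs = 0)
)

# Stand-in context (not a real ContextTorch) with only the fields used above
cb <- CallbackSetPatience$new(patience = 2)
ctx <- new.env()
ctx$terminate <- FALSE
cb$ctx <- ctx
for (s in c(0.50, 0.40, 0.45, 0.42)) {
  ctx$last_scores_valid <- list(classif.ce = s)
  cb$on_epoch_end()
}
ctx$terminate
```

The score improves in the second epoch and then fails to improve twice, so the flag is set after the fourth call.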

Public fields

ctx

(ContextTorch or NULL)
The evaluation context for the callback. This field should always be NULL except during the $train() call of the torch learner.

Active bindings

stages

(character())
The active stages of this callback set.

Methods

Public methods


Method print()

Prints the object.

Usage
CallbackSet$print(...)
Arguments
...

(any)
Currently unused.


Method state_dict()

Returns information that is kept in the LearnerTorch's state after training. This information should be loadable into the callback using $load_state_dict() to be able to continue training. This returns NULL by default.

Usage
CallbackSet$state_dict()

Method load_state_dict()

Loads the state dict into the callback to continue training.

Usage
CallbackSet$load_state_dict(state_dict)
Arguments
state_dict

(any)
The state dict as retrieved via $state_dict().


Method clone()

The objects of this class are cloneable with this method.

Usage
CallbackSet$clone(deep = FALSE)
Arguments
deep

Whether to make a deep clone.

See Also

Other Callback: TorchCallback, as_torch_callback(), as_torch_callbacks(), callback_set(), mlr3torch_callbacks, mlr_callback_set.checkpoint, mlr_callback_set.progress, mlr_context_torch, t_clbk(), torch_callback()


[Package mlr3torch version 0.1.0 Index]