| TorchCallback {mlr3torch} | R Documentation |
Torch Callback
Description
This wraps a CallbackSet and annotates it with metadata, most importantly a ParamSet.
The callback is created for the given parameter values by calling the $generate() method.
This class is usually used to configure the callbacks of a torch learner, e.g. when constructing
a learner or a ModelDescriptor.
For a list of available callbacks, see mlr3torch_callbacks.
To conveniently retrieve a TorchCallback, use t_clbk().
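The following is a minimal sketch of this lookup. It assumes that mlr3torch and a working torch installation are available, and that mlr3torch_callbacks behaves like an mlr3misc Dictionary (with $keys() and $get()). The variable names are illustrative only.

```r
library(mlr3torch)

# All keys registered in the callback dictionary (e.g. "checkpoint")
keys = mlr3torch_callbacks$keys()
print(keys)

# Two ways to retrieve a TorchCallback wrapping the checkpoint callback set:
# the sugar function t_clbk() and the dictionary's own $get() method
tc1 = t_clbk("checkpoint")
tc2 = mlr3torch_callbacks$get("checkpoint")
print(tc1)
```

t_clbk() is the more convenient form, because it also lets you set parameter values directly, e.g. t_clbk("checkpoint", freq = 1).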
Parameters
Defined by the constructor argument param_set.
If no parameter set is provided during construction, one is constructed by creating a parameter
for each argument of the $initialize() method of the wrapped callback generator. These parameters are of type ParamUty.
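To illustrate the inference rule, the sketch below defines a small custom callback set with callback_set() and wraps it without passing a param_set. CallbackSetPrintEpoch and its arguments prefix and every are hypothetical and are not part of mlr3torch. The sketch assumes mlr3torch and torch are installed. Because no training is run, the on_epoch_end body is never executed here.

```r
library(mlr3torch)

# Hypothetical callback set (not part of mlr3torch): prints a message every
# `every` epochs. By convention, the class name starts with "CallbackSet".
CallbackSetPrintEpoch = callback_set("CallbackSetPrintEpoch",
  initialize = function(prefix = "epoch", every = 1) {
    self$prefix = prefix
    self$every = every
  },
  on_epoch_end = function() {
    # self$ctx is the training context; self$ctx$epoch is the current epoch
    if (self$ctx$epoch %% self$every == 0) {
      cat(self$prefix, self$ctx$epoch, "\n")
    }
  }
)

# No param_set is passed, so one untyped (ParamUty) parameter is inferred
# per argument of $initialize(), i.e. `prefix` and `every`
tc = TorchCallback$new(CallbackSetPrintEpoch, id = "print_epoch", label = "Print Epoch")
print(tc$param_set)
```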
Super class
mlr3torch::TorchDescriptor -> TorchCallback
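Methods
Public methods
TorchCallback$new()
TorchCallback$clone()
Inherited methods
mlr3torch::TorchDescriptor$generate()
mlr3torch::TorchDescriptor$help()
mlr3torch::TorchDescriptor$print()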
Method new()
Creates a new instance of this R6 class.
Usage
TorchCallback$new( callback_generator, param_set = NULL, id = NULL, label = NULL, packages = NULL, man = NULL )
Arguments
callback_generator (R6ClassGenerator)
The class generator for the callback that is being wrapped.
param_set (ParamSet or NULL)
The parameter set. If NULL (default), it is inferred from callback_generator.
id (character(1))
The id of the new object.
label (character(1))
Label for the new instance.
packages (character())
The R packages this object depends on.
man (character(1))
String in the format [pkg]::[topic] pointing to a manual page for this object. The referenced help page can be opened via the method $help().
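The sketch below calls the constructor with an explicit parameter set instead of relying on inference. It wraps the existing CallbackSetCheckpoint and declares typed parameters for its path and freq arguments using paradox. The id "my_checkpoint" and label "My Checkpoint" are made up for illustration. The sketch assumes that mlr3torch, paradox and torch are installed.

```r
library(mlr3torch)
library(paradox)

# Explicit, typed parameter set for the checkpoint callback set
# (instead of the untyped ParamUty parameters that would be inferred)
param_set = ps(
  path = p_uty(tags = c("train", "required")),
  freq = p_int(lower = 1L, tags = "train")
)

tc = TorchCallback$new(
  callback_generator = CallbackSetCheckpoint,
  param_set = param_set,
  id = "my_checkpoint",
  label = "My Checkpoint",
  packages = "mlr3torch"
)

# Set the parameter values, then instantiate the wrapped CallbackSet
tc$param_set$values = list(path = tempfile(), freq = 2L)
cb = tc$generate()
print(cb)
```

Arguments of $initialize() that are not declared in the explicit parameter set keep their default values when $generate() is called.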
Method clone()
The objects of this class are cloneable with this method.
Usage
TorchCallback$clone(deep = FALSE)
Arguments
deep
Whether to make a deep clone.
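The deep argument matters because the parameter set is itself an R6 object. In the sketch below, a deep clone gets its own copy of the ParamSet, while a shallow clone shares it with the original. This relies on standard R6 cloning semantics and assumes that mlr3torch and torch are installed.

```r
library(mlr3torch)

tc = t_clbk("checkpoint", path = tempfile(), freq = 1)

# Deep clone: the ParamSet is copied, so changing the clone's values
# does not affect the original
tc_deep = tc$clone(deep = TRUE)
tc_deep$param_set$values$freq = 5

# Shallow clone: R6 copies field references, so the clone and the
# original share the same ParamSet object
tc_shallow = tc$clone(deep = FALSE)

print(tc$param_set$values$freq)
print(tc_deep$param_set$values$freq)
```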
See Also
Other Callback:
as_torch_callback(),
as_torch_callbacks(),
callback_set(),
mlr3torch_callbacks,
mlr_callback_set,
mlr_callback_set.checkpoint,
mlr_callback_set.progress,
mlr_context_torch,
t_clbk(),
torch_callback()
Other Torch Descriptor:
TorchDescriptor,
TorchLoss,
TorchOptimizer,
as_torch_callbacks(),
as_torch_loss(),
as_torch_optimizer(),
mlr3torch_losses,
mlr3torch_optimizers,
t_clbk(),
t_loss(),
t_opt()
Examples
library(mlr3torch)
# Create a new torch callback from an existing callback set
torch_callback = TorchCallback$new(CallbackSetCheckpoint)
# The parameters are inferred
torch_callback$param_set
# Retrieve a torch callback from the dictionary
torch_callback = t_clbk("checkpoint", path = tempfile(), freq = 1)
torch_callback
torch_callback$label
torch_callback$id
# open the help page of the wrapped callback set
# torch_callback$help()
# Create the callback set
callback = torch_callback$generate()
callback
# is the same as
CallbackSetCheckpoint$new(path = tempfile(), freq = 1)
# Use in a learner
learner = lrn("regr.mlp", callbacks = t_clbk("checkpoint"))
# the parameters of the callback are added to the learner's parameter set
learner$param_set