TorchCallback {mlr3torch}    R Documentation

Torch Callback

Description

This class wraps a CallbackSet and annotates it with metadata, most importantly a ParamSet. The callback is created for the given parameter values by calling the $generate() method.
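A minimal sketch of this workflow, assuming mlr3torch is installed. The parameter values are set on the wrapped ParamSet through its values field (using the common `$values$<id> =` idiom so that other stored values are kept), and $generate() then instantiates the CallbackSet with them.

```r
library(mlr3torch)

# Retrieve the checkpoint callback descriptor from the dictionary
torch_callback = t_clbk("checkpoint")

# Set the parameter values that $generate() passes on to the callback set
torch_callback$param_set$values$path = tempfile()
torch_callback$param_set$values$freq = 1L

# Instantiate the wrapped CallbackSet with these values
callback = torch_callback$generate()
print(class(callback))
```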

This class is usually used to configure the callbacks of a torch learner, e.g. when constructing a learner or in a ModelDescriptor.

For a list of available callbacks, see mlr3torch_callbacks. To conveniently retrieve a TorchCallback, use t_clbk().
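A short sketch of both ways of finding a callback, assuming the dictionary mlr3torch_callbacks provides the usual mlr3misc Dictionary method $keys(). As in the examples below, additional arguments to t_clbk() are stored as parameter values of the returned TorchCallback.

```r
library(mlr3torch)

# Keys of all callbacks registered in the dictionary
keys = mlr3torch_callbacks$keys()
print(keys)

# Retrieve one entry by key and set a parameter value at the same time
clbk = t_clbk("checkpoint", freq = 2L)
print(clbk$id)
print(clbk$param_set$values$freq)
```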

Parameters

Defined by the constructor argument param_set. If no parameter set is provided during construction, it is inferred from callback_generator: one parameter is created for each construction argument of the wrapped callback set, and these parameters are of type ParamUty.
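A sketch of the inferred parameter set for the checkpoint callback. It assumes that the inference reads the arguments of the generator's initialize() method, which are accessed here through the R6 generator field public_methods; only the arguments path and freq, which also appear in the examples below, are checked.

```r
library(mlr3torch)

# No param_set given: the parameter set is inferred from the generator
torch_callback = TorchCallback$new(CallbackSetCheckpoint)
inferred_ids = torch_callback$param_set$ids()
print(inferred_ids)

# For comparison (assumption): the arguments of the generator's initialize() method
init_args = names(formals(CallbackSetCheckpoint$public_methods$initialize))
print(init_args)
```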

Super class

mlr3torch::TorchDescriptor -> TorchCallback

Methods

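Public methods

TorchCallback$new()
TorchCallback$clone()

Inherited methods

mlr3torch::TorchDescriptor$generate()
mlr3torch::TorchDescriptor$help()
mlr3torch::TorchDescriptor$print()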

Method new()

Creates a new instance of this R6 class.

Usage
TorchCallback$new(
  callback_generator,
  param_set = NULL,
  id = NULL,
  label = NULL,
  packages = NULL,
  man = NULL
)
Arguments
callback_generator

(R6ClassGenerator)
The class generator for the callback that is being wrapped.

param_set

(ParamSet or NULL)
The parameter set. If NULL (default) it is inferred from callback_generator.

id

(character(1))
The id of the new object.

label

(character(1))
Label for the new instance.

packages

(character())
The R packages this object depends on.

man

(character(1))
String in the format [pkg]::[topic] pointing to a manual page for this object. The referenced help page can be opened via the method $help().
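A sketch of a constructor call that sets the optional metadata explicitly. The id "my_checkpoint" and the label "My Checkpoint" are hypothetical values chosen for illustration; the man string points to the existing help topic mlr_callback_set.checkpoint, and param_set is left at NULL so that it is inferred.

```r
library(mlr3torch)

torch_callback = TorchCallback$new(
  callback_generator = CallbackSetCheckpoint,
  id = "my_checkpoint",     # hypothetical id
  label = "My Checkpoint",  # hypothetical label
  man = "mlr3torch::mlr_callback_set.checkpoint"
)
print(torch_callback$id)
print(torch_callback$label)
```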


Method clone()

The objects of this class are cloneable with this method.

Usage
TorchCallback$clone(deep = FALSE)
Arguments
deep

Whether to make a deep clone.
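A sketch of why deep = TRUE matters, assuming the parameter set is stored as an R6 field of the object (so that R6's default deep clone copies it rather than sharing it): changing parameter values on the deep copy leaves the original untouched.

```r
library(mlr3torch)

original = TorchCallback$new(CallbackSetCheckpoint)

# deep = TRUE also copies R6 fields such as the parameter set
copy = original$clone(deep = TRUE)
copy$param_set$values$freq = 5L

# The original's parameter values are unchanged
print(original$param_set$values)
print(copy$param_set$values)
```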

See Also

Other Callback: as_torch_callback(), as_torch_callbacks(), callback_set(), mlr3torch_callbacks, mlr_callback_set, mlr_callback_set.checkpoint, mlr_callback_set.progress, mlr_context_torch, t_clbk(), torch_callback()

Other Torch Descriptor: TorchDescriptor, TorchLoss, TorchOptimizer, as_torch_callbacks(), as_torch_loss(), as_torch_optimizer(), mlr3torch_losses, mlr3torch_optimizers, t_clbk(), t_loss(), t_opt()

Examples


library(mlr3torch)

# Create a new torch callback from an existing callback set
torch_callback = TorchCallback$new(CallbackSetCheckpoint)
# The parameters are inferred
torch_callback$param_set

# Retrieve a torch callback from the dictionary
torch_callback = t_clbk("checkpoint",
  path = tempfile(), freq = 1
)
torch_callback
torch_callback$label
torch_callback$id

# open the help page of the wrapped callback set
# torch_callback$help()

# Create the callback set
callback = torch_callback$generate()
callback
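# is equivalent to constructing the callback set directly
# (tempfile() returns a new path on each call, so the paths differ)
CallbackSetCheckpoint$new(
  path = tempfile(), freq = 1
)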

# Use in a learner
learner = lrn("regr.mlp", callbacks = t_clbk("checkpoint"))
# the parameters of the callback are added to the learner's parameter set
learner$param_set


[Package mlr3torch version 0.1.0 Index]