torch_callback {mlr3torch}    R Documentation

Create a Callback Descriptor

Description

Convenience function to create a custom TorchCallback. All arguments that are available in callback_set() are also available here. For more information on how to correctly implement a new callback, see CallbackSet.

Usage

torch_callback(
  id,
  classname = paste0("CallbackSet", capitalize(id)),
  param_set = NULL,
  packages = NULL,
  label = capitalize(id),
  man = NULL,
  on_begin = NULL,
  on_end = NULL,
  on_exit = NULL,
  on_epoch_begin = NULL,
  on_before_valid = NULL,
  on_epoch_end = NULL,
  on_batch_begin = NULL,
  on_batch_end = NULL,
  on_after_backward = NULL,
  on_batch_valid_begin = NULL,
  on_batch_valid_end = NULL,
  on_valid_end = NULL,
  state_dict = NULL,
  load_state_dict = NULL,
  initialize = NULL,
  public = NULL,
  private = NULL,
  active = NULL,
  parent_env = parent.frame(),
  inherit = CallbackSet,
  lock_objects = FALSE
)

Arguments

id

(character(1))
The id for the torch callback.

classname

(character(1))
The class name.

param_set

(ParamSet)
The parameter set. If not present, it is inferred from the $initialize() method.

packages

(character())
The packages the callback depends on. Default is NULL.

label

(character(1))
The label for the torch callback. Defaults to the capitalized id.

man

(character(1))
String in the format [pkg]::[topic] pointing to a manual page for this object. The referenced help page can be opened via method $help(). The default is NULL.

on_begin, on_end, on_epoch_begin, on_before_valid, on_epoch_end, on_batch_begin, on_batch_end, on_after_backward, on_batch_valid_begin, on_batch_valid_end, on_valid_end, on_exit

(function)
Function to execute at the given stage, see section Stages.

state_dict

(function())
The function that retrieves the state dict from the callback. This is what will be available in the learner after training.

load_state_dict

(function(state_dict))
Function that loads a callback state.

initialize

(function())
The initialization method of the callback.

public, private, active

(list())
Additional public, private, and active fields to add to the callback.

parent_env

(environment())
The parent environment for the R6Class.

inherit

(R6ClassGenerator)
From which class to inherit. This class must either be CallbackSet (default) or inherit from it.

lock_objects

(logical(1))
Whether to lock the objects of the resulting R6Class. If FALSE (default), values can be freely assigned to self without declaring them in the class definition.
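
The state_dict and load_state_dict arguments can be illustrated with a small stateful callback. The following is a minimal sketch, not a definitive implementation: the callback id "counter" and the field n_batches are hypothetical names chosen for illustration, and the sketch relies on the default lock_objects = FALSE so that new fields can be assigned to self.

```r
library(mlr3torch)

# Hypothetical callback that counts the training batches it has seen.
counter_tcb = torch_callback("counter",
  on_begin = function() {
    # With lock_objects = FALSE (default), fields can be added to self freely.
    self$n_batches = 0L
  },
  on_batch_end = function() {
    self$n_batches = self$n_batches + 1L
  },
  # Returns the state that is kept in the learner after training.
  state_dict = function() {
    list(n_batches = self$n_batches)
  },
  # Restores a state previously produced by state_dict().
  load_state_dict = function(state_dict) {
    self$n_batches = state_dict$n_batches
  }
)

learner = lrn("classif.torch_featureless",
  batch_size = 16,
  epochs = 1,
  callbacks = counter_tcb,
  device = "cpu"
)
learner$train(tsk("iris"))

# Assumption: the callback state is stored in the trained model under
# the callback id.
learner$model$callbacks$counter
```

Keeping the state a plain list of R values, as here, makes it straightforward to restore in load_state_dict.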

Value

TorchCallback

Internals

It first creates an R6 class inheriting from CallbackSet (using callback_set()) and then wraps this generator in a TorchCallback that can be passed to a torch learner.
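
The two steps described above can be sketched by hand. This is an illustration under stated assumptions rather than the function's actual source: it assumes that TorchCallback$new() accepts the arguments callback_generator, id and label, and the variable names generator and manual_tcb are hypothetical.

```r
library(mlr3torch)

# Step 1: create the R6 class generator that inherits from CallbackSet.
generator = callback_set("CallbackSetCustom",
  on_end = function() {
    cat("Training is done.\n")
  }
)

# Step 2: wrap the generator in a TorchCallback (argument names assumed).
manual_tcb = TorchCallback$new(
  callback_generator = generator,
  id = "custom",
  label = "Custom"
)
```

The resulting object can be passed to the callbacks parameter of a torch learner in the same way as the result of torch_callback().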

Stages

Each on_<stage> function is run at the stage of the training loop named by its suffix:

begin: Run before the training loop begins.
epoch_begin: Run at the beginning of each epoch.
batch_begin: Run before the forward call.
after_backward: Run after the backward call.
batch_end: Run after the optimizer step.
before_valid: Run before validation.
batch_valid_begin: Run before the forward call in the validation loop.
batch_valid_end: Run after the forward call in the validation loop.
valid_end: Run at the end of validation.
epoch_end: Run at the end of each epoch.
end: Run after the last epoch.
exit: Run last, using on.exit().

See Also

Other Callback: TorchCallback, as_torch_callback(), as_torch_callbacks(), callback_set(), mlr3torch_callbacks, mlr_callback_set, mlr_callback_set.checkpoint, mlr_callback_set.progress, mlr_context_torch, t_clbk()

Examples


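library(mlr3torch)

# Define a callback; the argument of initialize() becomes a parameter.
custom_tcb = torch_callback("custom",
  initialize = function(name) {
    self$name = name
  },
  on_begin = function() {
    # self$ctx gives access to the training context.
    cat("Hello ", self$name, ", we will train for ", self$ctx$total_epochs,
      " epochs.\n", sep = "")
  },
  on_end = function() {
    cat("Training is done.\n")
  }
)

# The callback parameter is set on the learner with the prefix cb.<id>.
learner = lrn("classif.torch_featureless",
  batch_size = 16,
  epochs = 1,
  callbacks = custom_tcb,
  cb.custom.name = "Marie",
  device = "cpu"
)
task = tsk("iris")
learner$train(task)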


[Package mlr3torch version 0.1.0 Index]