callback_set {mlr3torch}    R Documentation

Create a Set of Callbacks for Torch

Description

Creates an R6ClassGenerator that inherits from CallbackSet. It additionally performs checks, for example that the stage names are not misspelled. To create a TorchCallback, use torch_callback().
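A minimal sketch of defining and instantiating a callback set, assuming mlr3torch is installed. The class name CallbackSetHello and the message text are illustrative. The last lines show that a misspelled stage argument does not go unnoticed. In this particular case R itself raises the error, because callback_set() has no ... argument.

```r
library(mlr3torch)

# callback_set() returns a class generator, not a callback object
CallbackSetHello = callback_set("CallbackSetHello",
  on_begin = function() {
    message("Training is about to start")
  }
)

# Instantiate it like any other R6 class
cb = CallbackSetHello$new()
print(class(cb))

# A misspelled stage name ("on_begn") raises an error instead of being ignored
typo_error = tryCatch(
  callback_set("CallbackSetTypo", on_begn = function() NULL),
  error = function(e) conditionMessage(e)
)
print(typo_error)
```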

In order for the resulting class to be cloneable, the private method $deep_clone() must be provided.
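As a sketch of what such a method can look like: R6 calls private$deep_clone(name, value) for every field when $clone(deep = TRUE) is used. The hypothetical class CallbackSetBuffer below deep-clones nested R6 objects, copies plain environments (which have reference semantics) into fresh ones, and returns all other values unchanged. A real callback may need different handling, for example for torch tensors.

```r
library(mlr3torch)

CallbackSetBuffer = callback_set("CallbackSetBuffer",
  initialize = function() {
    # An environment would be shared between shallow copies
    self$buffer = new.env()
  },
  private = list(
    deep_clone = function(name, value) {
      if (inherits(value, "R6")) {
        # Nested R6 objects are environments too; let them clone themselves
        value$clone(deep = TRUE)
      } else if (is.environment(value)) {
        # Copy the contents into a fresh environment
        list2env(as.list(value, all.names = TRUE))
      } else {
        # Plain R values are copied on modification anyway
        value
      }
    }
  )
)

cb = CallbackSetBuffer$new()
cb$buffer$n = 1
cb_copy = cb$clone(deep = TRUE)
cb_copy$buffer$n = 2
# The original is unaffected by changes to the deep copy
print(cb$buffer$n)
```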

Usage

callback_set(
  classname,
  on_begin = NULL,
  on_end = NULL,
  on_exit = NULL,
  on_epoch_begin = NULL,
  on_before_valid = NULL,
  on_epoch_end = NULL,
  on_batch_begin = NULL,
  on_batch_end = NULL,
  on_after_backward = NULL,
  on_batch_valid_begin = NULL,
  on_batch_valid_end = NULL,
  on_valid_end = NULL,
  state_dict = NULL,
  load_state_dict = NULL,
  initialize = NULL,
  public = NULL,
  private = NULL,
  active = NULL,
  parent_env = parent.frame(),
  inherit = CallbackSet,
  lock_objects = FALSE
)

Arguments

classname

(character(1))
The class name.

on_begin, on_end, on_epoch_begin, on_before_valid, on_epoch_end, on_batch_begin, on_batch_end, on_after_backward, on_batch_valid_begin, on_batch_valid_end, on_valid_end, on_exit

(function)
Function to execute at the given stage; see section Stages.
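A sketch of a callback that uses several stages, assuming that stage functions take no arguments and that, during training, self$ctx holds the training context (see mlr_context_torch) with fields such as $epoch and $total_epochs. The class name and the fields start and elapsed are illustrative. Outside of training the stage methods can be called directly for quick checks; only the stages that do not touch self$ctx are called below.

```r
library(mlr3torch)

CallbackSetTimer = callback_set("CallbackSetTimer",
  on_begin = function() {
    # Remember when training started
    self$start = Sys.time()
  },
  on_epoch_end = function() {
    # Assumption: self$ctx is only populated while the learner is training
    cat(sprintf("Finished epoch %s of %s\n", self$ctx$epoch, self$ctx$total_epochs))
  },
  on_end = function() {
    # Store the total training time in seconds
    self$elapsed = difftime(Sys.time(), self$start, units = "secs")
  }
)

cb = CallbackSetTimer$new()
cb$on_begin()
cb$on_end()
print(cb$elapsed)
```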

state_dict

(function())
Function that retrieves the state dict of the callback. Its return value is what is available in the learner after training.

load_state_dict

(function(state_dict))
Function that loads a callback state, as returned by state_dict.
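The following sketch, with the hypothetical class CallbackSetEpochCounter, shows a matching pair of state_dict and load_state_dict: the state is a plain list, and loading it into a fresh instance restores the counter. During training the stage would be triggered by the learner; here it is called by hand.

```r
library(mlr3torch)

CallbackSetEpochCounter = callback_set("CallbackSetEpochCounter",
  initialize = function() {
    self$n_epochs = 0
  },
  on_epoch_end = function() {
    self$n_epochs = self$n_epochs + 1
  },
  state_dict = function() {
    # Everything needed to restore the callback later
    list(n_epochs = self$n_epochs)
  },
  load_state_dict = function(state_dict) {
    self$n_epochs = state_dict$n_epochs
  }
)

cb = CallbackSetEpochCounter$new()
cb$on_epoch_end()
cb$on_epoch_end()
sd = cb$state_dict()

# Restore the state into a new instance
cb_restored = CallbackSetEpochCounter$new()
cb_restored$load_state_dict(sd)
print(cb_restored$n_epochs)
```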

initialize

(function())
The initialization method of the callback.
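A sketch of an initialize method. As with any R6 class, it is run by $new(). This example assumes that it may take arguments, which are then passed through $new(); the class CallbackSetThreshold and its threshold field are hypothetical.

```r
library(mlr3torch)

CallbackSetThreshold = callback_set("CallbackSetThreshold",
  initialize = function(threshold = 0.1) {
    # New fields can be created here because lock_objects defaults to FALSE
    self$threshold = threshold
  }
)

cb_default = CallbackSetThreshold$new()
cb_custom = CallbackSetThreshold$new(threshold = 0.5)
```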

public, private, active

(list())
Additional public, private, and active fields to add to the callback.
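A sketch combining the three kinds of additional members: a public helper method (record, not a training stage), a private field (.losses), and a read-only active binding (losses). All names are hypothetical.

```r
library(mlr3torch)

CallbackSetLossLog = callback_set("CallbackSetLossLog",
  initialize = function() {
    private$.losses = numeric(0)
  },
  public = list(
    # Helper method that is not a training stage
    record = function(loss) {
      private$.losses = c(private$.losses, loss)
      invisible(self)
    }
  ),
  private = list(
    .losses = NULL
  ),
  active = list(
    # Read-only view of the recorded values
    losses = function(rhs) {
      if (!missing(rhs)) stop("`losses` is read-only")
      private$.losses
    }
  )
)

cb = CallbackSetLossLog$new()
cb$record(0.9)$record(0.7)
print(cb$losses)
```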

parent_env

(environment())
The parent environment for the R6Class.
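A sketch of why parent_env matters: R6 methods are evaluated in an environment whose parent is parent_env, so free variables such as prefix below are looked up there. The factory function make_labelled_callback is hypothetical, and environment() is passed explicitly to make this visible.

```r
library(mlr3torch)

make_labelled_callback = function(prefix) {
  callback_set("CallbackSetLabelled",
    on_begin = function() {
      # `prefix` is not a field; it is found through parent_env
      self$label = paste(prefix, "training started")
    },
    parent_env = environment()
  )
}

CallbackSetLabelled = make_labelled_callback("[run 1]")
cb = CallbackSetLabelled$new()
cb$on_begin()
print(cb$label)
```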

inherit

(R6ClassGenerator)
The class to inherit from. This class must either be CallbackSet (default) or inherit from it.
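A sketch of inheriting from an existing callback set, assuming that the usual R6 super$ mechanism is available in the generated class. Both class names are hypothetical; the child extends the parent's on_begin stage instead of replacing it.

```r
library(mlr3torch)

CallbackSetBase = callback_set("CallbackSetBase",
  on_begin = function() {
    self$log = "base"
  }
)

CallbackSetChild = callback_set("CallbackSetChild",
  on_begin = function() {
    # Run the parent's stage first, then extend it
    super$on_begin()
    self$log = c(self$log, "child")
  },
  inherit = CallbackSetBase
)

cb = CallbackSetChild$new()
cb$on_begin()
print(cb$log)
```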

lock_objects

(logical(1))
Whether to lock the objects of the resulting R6Class. If FALSE (default), values can be freely assigned to self without declaring them in the class definition.
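A sketch of the effect of lock_objects = TRUE: fields declared in public can still be modified, but assigning a new, undeclared field to the object fails. The class name and the field n are hypothetical.

```r
library(mlr3torch)

CallbackSetLocked = callback_set("CallbackSetLocked",
  on_epoch_end = function() {
    # Modifying a declared field is allowed
    self$n = self$n + 1
  },
  public = list(n = 0),
  lock_objects = TRUE
)

cb = CallbackSetLocked$new()
cb$on_epoch_end()

# Adding an undeclared field fails on a locked object
add_failed = tryCatch({
  cb$undeclared = 1
  FALSE
}, error = function(e) TRUE)
```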

Value

An R6ClassGenerator that inherits from CallbackSet.

See Also

Other Callback: TorchCallback, as_torch_callback(), as_torch_callbacks(), mlr3torch_callbacks, mlr_callback_set, mlr_callback_set.checkpoint, mlr_callback_set.progress, mlr_context_torch, t_clbk(), torch_callback()


[Package mlr3torch version 0.1.0 Index]