torch_callback {mlr3torch}    R Documentation
Create a Callback Descriptor
Description
Convenience function to create a custom TorchCallback. All arguments that are available in callback_set() are also available here. For more information on how to correctly implement a new callback, see CallbackSet.
Usage
torch_callback(
id,
classname = paste0("CallbackSet", capitalize(id)),
param_set = NULL,
packages = NULL,
label = capitalize(id),
man = NULL,
on_begin = NULL,
on_end = NULL,
on_exit = NULL,
on_epoch_begin = NULL,
on_before_valid = NULL,
on_epoch_end = NULL,
on_batch_begin = NULL,
on_batch_end = NULL,
on_after_backward = NULL,
on_batch_valid_begin = NULL,
on_batch_valid_end = NULL,
on_valid_end = NULL,
state_dict = NULL,
load_state_dict = NULL,
initialize = NULL,
public = NULL,
private = NULL,
active = NULL,
parent_env = parent.frame(),
inherit = CallbackSet,
lock_objects = FALSE
)
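The following sketch shows a minimal call that relies on the defaults above: only id and a single stage function are supplied, so label is derived from capitalize(id). It assumes that the returned TorchCallback exposes the fields id, label and generator (the wrapped R6 class generator); the callback id "minimal" is purely illustrative.

```r
library(mlr3torch)

# Only `id` and one stage function are given; `classname` and `label`
# fall back to their defaults, which are derived from capitalize(id).
tcb = torch_callback("minimal",
  on_epoch_end = function() {
    # self$ctx is the training context available inside the callback
    cat("Finished epoch", self$ctx$epoch, "\n")
  }
)

tcb$id                   # "minimal"
tcb$label                # "Minimal"
tcb$generator$classname  # name of the generated R6 class (default for `classname`)
```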
Arguments
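id: (character(1)) Identifier of the callback. The callback's parameters are set in the learner with the prefix cb.<id>. (for example cb.custom.name in the Examples).
classname: (character(1)) Name of the generated R6 class. Defaults to "CallbackSet" followed by the capitalized id.
param_set: (ParamSet or NULL) Parameter set of the callback. If NULL, it is inferred from the arguments of initialize.
packages: (character() or NULL) Packages the callback depends on.
label: (character(1)) Label of the callback. Defaults to the capitalized id.
man: (character(1) or NULL) Reference to a help page, in the format "pkg::topic".
on_begin, on_end, on_epoch_begin, on_before_valid, on_epoch_end, on_batch_begin, on_batch_end, on_after_backward, on_batch_valid_begin, on_batch_valid_end, on_valid_end, on_exit: (function or NULL) Function to execute at the corresponding stage, see section Stages.
state_dict: (function or NULL) Function that returns the state of the callback. This state is available in the learner after training.
load_state_dict: (function(state_dict) or NULL) Function that restores a state previously returned by state_dict.
initialize: (function or NULL) Initialization method of the callback.
public, private, active: (list() or NULL) Additional public, private and active members of the generated R6 class.
parent_env: (environment) Parent environment of the generated R6 class. Defaults to the calling frame.
inherit: (R6ClassGenerator) Class to inherit from. Must be CallbackSet (the default) or a subclass of it.
lock_objects: (logical(1)) Whether to lock the objects of the generated R6 class. With the default FALSE, new fields can be assigned to self without declaring them first.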
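As an illustration of private, state_dict and load_state_dict working together, here is a sketch of a callback that counts training batches and exposes the count as its state. The id "counter" and the field n_batches are hypothetical; the assumption that the learner stores each callback's state in learner$model$callbacks under the callback id should be checked against the installed version of mlr3torch.

```r
library(mlr3torch)

# Hypothetical callback: counts the training batches it has seen.
counter_tcb = torch_callback("counter",
  private = list(n_batches = 0L),
  on_batch_end = function() {
    private$n_batches = private$n_batches + 1L
  },
  # The value returned here is the state kept for this callback after training.
  state_dict = function() {
    list(n_batches = private$n_batches)
  },
  # Restores a state previously produced by state_dict().
  load_state_dict = function(state_dict) {
    private$n_batches = state_dict$n_batches
  }
)

learner = lrn("classif.torch_featureless",
  batch_size = 16, epochs = 2, callbacks = counter_tcb, device = "cpu"
)
learner$train(tsk("iris"))

# Assumption: callback states are stored under the callback id.
learner$model$callbacks$counter
```

Both state_dict and load_state_dict are supplied, so that a stored state can be written back into a fresh instance of the callback.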
Value
(TorchCallback) The created callback descriptor.
Internals
It first creates an R6 class inheriting from CallbackSet (using callback_set()) and then wraps this generator in a TorchCallback that can be passed to a torch learner.
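The two internal steps can be sketched by hand. This assumes that TorchCallback$new() accepts the generator as callback_generator together with id and label; the class name CallbackSetGreeting and the id "greeting" are hypothetical.

```r
library(mlr3torch)

# Step 1: create an R6 class generator inheriting from CallbackSet.
CallbackSetGreeting = callback_set("CallbackSetGreeting",
  on_begin = function() cat("Training starts.\n")
)

# Step 2: wrap the generator in a TorchCallback descriptor.
# Assumption: the constructor arguments are named as below.
greeting_tcb = TorchCallback$new(
  callback_generator = CallbackSetGreeting,
  id = "greeting",
  label = "Greeting"
)

# greeting_tcb can now be passed to a torch learner via `callbacks`.
```

torch_callback() bundles both steps into a single call, which is why it accepts the arguments of callback_set() as well as descriptor arguments such as id and label.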
Stages
- begin :: Run before the training loop begins.
- epoch_begin :: Run at the beginning of each epoch.
- batch_begin :: Run before the forward call.
- after_backward :: Run after the backward call.
- batch_end :: Run after the optimizer step.
- before_valid :: Run before the validation loop.
- batch_valid_begin :: Run before the forward call in the validation loop.
- batch_valid_end :: Run after the forward call in the validation loop.
- valid_end :: Run at the end of validation.
- epoch_end :: Run at the end of each epoch.
- end :: Run after the last epoch.
- exit :: Run at last, using on.exit().
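To see the order in which the stages run, a tracing callback can print the name of each stage it reaches. This is an illustrative sketch with the hypothetical id "trace"; the validation stages are left out because no validation data is configured for the learner.

```r
library(mlr3torch)

# Prints the name of each stage when it is reached.
trace_tcb = torch_callback("trace",
  on_begin = function() cat("begin\n"),
  on_epoch_begin = function() cat("epoch_begin", self$ctx$epoch, "\n"),
  on_batch_end = function() cat("  batch_end\n"),
  on_epoch_end = function() cat("epoch_end\n"),
  on_end = function() cat("end\n"),
  on_exit = function() cat("exit\n")
)

learner = lrn("classif.torch_featureless",
  batch_size = 50, epochs = 2, callbacks = trace_tcb, device = "cpu"
)
learner$train(tsk("iris"))
```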
See Also
Other Callback: TorchCallback, as_torch_callback(), as_torch_callbacks(), callback_set(), mlr3torch_callbacks, mlr_callback_set, mlr_callback_set.checkpoint, mlr_callback_set.progress, mlr_context_torch, t_clbk()
Examples
library(mlr3torch)

# Define a callback with id "custom".
custom_tcb = torch_callback("custom",
  # The arguments of initialize() become parameters of the callback.
  initialize = function(name) {
    self$name = name
  },
  on_begin = function() {
    cat("Hello ", self$name, ", we will train for ", self$ctx$total_epochs, " epochs.\n", sep = "")
  },
  on_end = function() {
    cat("Training is done.\n")
  }
)

learner = lrn("classif.torch_featureless",
  batch_size = 16,
  epochs = 1,
  callbacks = custom_tcb,
  # Callback parameters are prefixed with "cb.<id>."
  cb.custom.name = "Marie",
  device = "cpu"
)
task = tsk("iris")
learner$train(task)