| mlr_context_torch {mlr3torch} | R Documentation |
Context for Torch Learner
Description
Context for training a torch learner.
It holds the (mostly read-only) information that callbacks can access through the argument ctx.
For more information on callbacks, see CallbackSet.
Public fields
learner (Learner)
The torch learner.
task_train (Task)
The training task.
task_valid (Task or NULL)
The validation task.
loader_train (torch::dataloader)
The data loader for training.
loader_valid (torch::dataloader or NULL)
The data loader for validation.
measures_train (list() of Measures)
Measures used for training.
measures_valid (list() of Measures)
Measures used for validation.
network (torch::nn_module)
The torch network.
optimizer (torch::optimizer)
The optimizer.
loss_fn (torch::nn_module)
The loss function.
total_epochs (integer(1))
The total number of epochs the learner is trained for.
last_scores_train (named list() or NULL)
The scores from the last training batch. Names are the ids of the training measures. If LearnerTorch sets eval_freq to a value other than 1, this is NULL in all epochs that don't evaluate the model.
last_scores_valid (named list() or NULL)
The scores from the last validation batch. Names are the ids of the validation measures. If LearnerTorch sets eval_freq to a value other than 1, this is NULL in all epochs that don't evaluate the model.
epoch (integer(1))
The current epoch.
step (integer(1))
The current iteration.
prediction_encoder (function())
The learner's prediction encoder.
batch (named list() of torch_tensors)
The current batch.
terminate (logical(1))
If this field is set to TRUE at the end of an epoch, training stops.
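The interplay between last_scores_valid and terminate can be illustrated with a small base-R simulation. This is a sketch under stated assumptions, not mlr3torch's training loop: the context is mocked by a plain environment, the end-of-epoch hook is an ordinary function that receives ctx explicitly, and the measure id classif.ce and the threshold 0.2 are illustrative placeholders. How a real callback obtains the context is documented in CallbackSet.

```r
# Hypothetical stand-in for the training context: a plain environment holding
# a few of the fields documented above. This is NOT a ContextTorch object.
ctx = new.env()
ctx$total_epochs = 10L
ctx$epoch = 0L
ctx$terminate = FALSE
ctx$last_scores_valid = NULL

# Hypothetical end-of-epoch hook: request a stop once the simulated validation
# error drops below 0.2. "classif.ce" is only a placeholder measure id.
on_epoch_end = function(ctx) {
  scores = ctx$last_scores_valid
  if (!is.null(scores) && scores[["classif.ce"]] < 0.2) {
    ctx$terminate = TRUE  # environments have reference semantics
  }
}

# Simulated training loop: the validation error halves every epoch, and the
# loop checks ctx$terminate only after the end-of-epoch hook has run.
valid_error = 1
for (epoch in seq_len(ctx$total_epochs)) {
  ctx$epoch = epoch
  valid_error = valid_error / 2
  ctx$last_scores_valid = list(classif.ce = valid_error)
  on_epoch_end(ctx)
  if (ctx$terminate) break
}

cat("Training stopped after epoch", ctx$epoch, "of", ctx$total_epochs, "\n")
```

With the error halving from 1, it first falls below 0.2 in epoch 3 (0.125), so the simulated loop stops there instead of running all 10 epochs.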
Methods
Public methods
Method new()
Creates a new instance of this R6 class.
Usage
ContextTorch$new(
  learner,
  task_train,
  task_valid = NULL,
  loader_train,
  loader_valid = NULL,
  measures_train = NULL,
  measures_valid = NULL,
  network,
  optimizer,
  loss_fn,
  total_epochs,
  prediction_encoder,
  eval_freq = 1L
)
Arguments
learner (Learner)
The torch learner.
task_train (Task)
The training task.
task_valid (Task or NULL)
The validation task. Default is NULL.
loader_train (torch::dataloader)
The data loader for training.
loader_valid (torch::dataloader or NULL)
The data loader for validation. Default is NULL.
measures_train (list() of Measures or NULL)
Measures used for training. Default is NULL.
measures_valid (list() of Measures or NULL)
Measures used for validation. Default is NULL.
network (torch::nn_module)
The torch network.
optimizer (torch::optimizer)
The optimizer.
loss_fn (torch::nn_module)
The loss function.
total_epochs (integer(1))
The total number of epochs the learner is trained for.
prediction_encoder (function())
The learner's prediction encoder.
eval_freq (integer(1))
The evaluation frequency in epochs. Default is 1L, i.e. the measures are evaluated after every epoch.
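To make the effect of eval_freq on last_scores_train and last_scores_valid concrete, the following base-R sketch lists the epochs in which scores would be available. It assumes (this page does not confirm it) that the measures are evaluated in every epoch divisible by eval_freq and additionally in the final epoch. is_eval_epoch() is a hypothetical helper, not an mlr3torch function, and the score values are placeholders.

```r
# Hypothetical helper (not part of mlr3torch). Assumption: measures are
# evaluated every `eval_freq`-th epoch and, additionally, in the final epoch.
is_eval_epoch = function(epoch, eval_freq, total_epochs) {
  epoch %% eval_freq == 0L || epoch == total_epochs
}

total_epochs = 7L
eval_freq = 3L
epochs = seq_len(total_epochs)

# Simulated value of last_scores_valid after each epoch: a named list of
# (placeholder) scores in evaluated epochs, NULL otherwise.
last_scores = lapply(epochs, function(epoch) {
  if (is_eval_epoch(epoch, eval_freq, total_epochs)) {
    list(classif.ce = NA_real_)
  } else {
    NULL
  }
})

evaluated = epochs[!vapply(last_scores, is.null, logical(1))]
print(evaluated)
# [1] 3 6 7
```

Under this assumption, a callback reading the scores in epochs 1, 2, 4 or 5 would find NULL and should check for it before using the values.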
Method clone()
The objects of this class are cloneable with this method.
Usage
ContextTorch$clone(deep = FALSE)
Arguments
deep
Whether to make a deep clone.
See Also
Other Callback:
TorchCallback,
as_torch_callback(),
as_torch_callbacks(),
callback_set(),
mlr3torch_callbacks,
mlr_callback_set,
mlr_callback_set.checkpoint,
mlr_callback_set.progress,
t_clbk(),
torch_callback()