mlr_context_torch {mlr3torch}    R Documentation
Context for Torch Learner
Description
Context for training a torch learner.
This is the (mostly read-only) information that callbacks can access through ctx.
For more information on callbacks, see CallbackSet.
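As a minimal sketch, the callback below only reads from the context: once before training and once at the end of every epoch. It assumes a recent mlr3torch version in which callback methods reach the context as self$ctx. The id "log_context" and the object name cb_log are made up for illustration.

```r
library(mlr3torch)

# Hypothetical read-only callback: it inspects the context once before
# training and then reports progress at the end of every epoch.
cb_log = torch_callback("log_context",
  on_begin = function() {
    ctx = self$ctx  # the ContextTorch of the running training loop
    # number of scalar entries across all network parameters
    n_par = sum(vapply(ctx$network$parameters, function(p) p$numel(), numeric(1)))
    message(sprintf("task '%s': %d epochs, %d batches per epoch, %.0f parameters",
      ctx$task_train$id, ctx$total_epochs, length(ctx$loader_train), n_par))
  },
  on_epoch_end = function() {
    ctx = self$ctx
    msg = sprintf("epoch %d/%d", ctx$epoch, ctx$total_epochs)
    # last_scores_train is NULL in epochs in which the model is not evaluated
    if (length(ctx$last_scores_train) > 0) {
      scores = unlist(ctx$last_scores_train)
      msg = paste0(msg, ": ",
        paste(names(scores), signif(scores, 3), sep = " = ", collapse = ", "))
    }
    message(msg)
  }
)
```

Passing callbacks = cb_log (together with, for example, measures_train = msrs("classif.acc")) to a torch learner would print one summary line before training and one line per epoch.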
Public fields
- learner
  (Learner)
  The torch learner.
- task_train
  (Task)
  The training task.
- task_valid
  (Task or NULL)
  The validation task.
- loader_train
  (torch::dataloader)
  The data loader for training.
- loader_valid
  (torch::dataloader or NULL)
  The data loader for validation.
- measures_train
  (list() of Measures)
  Measures used for training.
- measures_valid
  (list() of Measures)
  Measures used for validation.
- network
  (torch::nn_module)
  The torch network.
- optimizer
  (torch::optimizer)
  The optimizer.
- loss_fn
  (torch::nn_module)
  The loss function.
- total_epochs
  (integer(1))
  The total number of epochs the learner is trained for.
- last_scores_train
  (named list() or NULL)
  The scores from the last training batch. Names are the ids of the training measures. If LearnerTorch sets eval_freq different from 1, this is NULL in all epochs that don't evaluate the model.
- last_scores_valid
  (named list() or NULL)
  The scores from the last validation batch. Names are the ids of the validation measures. If LearnerTorch sets eval_freq different from 1, this is NULL in all epochs that don't evaluate the model.
- epoch
  (integer(1))
  The current epoch.
- step
  (integer(1))
  The current iteration.
- prediction_encoder
  (function())
  The learner's prediction encoder.
- batch
  (named list() of torch_tensors)
  The current batch.
- terminate
  (logical(1))
  If this field is set to TRUE at the end of an epoch, training stops.
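Unlike most fields, terminate is meant to be written by callbacks. The following is a minimal sketch of a custom stopping rule built on last_scores_valid and terminate. It assumes that callbacks reach the context as self$ctx and that the tracked validation measure is one where lower values are better (such as classif.ce). The callback id "stop_on_plateau" and the names measure_id, patience, best and n_bad are illustrative, not part of mlr3torch.

```r
library(mlr3torch)

# Hypothetical callback: stop training once the validation score `measure_id`
# has not improved for `patience` evaluated epochs.
# Assumes lower values of the measure are better (e.g. classif.ce).
cb_stop = torch_callback("stop_on_plateau",
  initialize = function(measure_id = "classif.ce", patience = 3L) {
    self$measure_id = measure_id
    self$patience = patience
    self$best = Inf
    self$n_bad = 0L
  },
  on_epoch_end = function() {
    scores = self$ctx$last_scores_valid
    # empty or NULL if this epoch was not evaluated or there is no validation data
    if (length(scores) == 0) return(invisible(NULL))
    score = scores[[self$measure_id]]
    if (is.null(score)) {
      stop("Validation measure '", self$measure_id, "' is not being tracked.")
    }
    if (score < self$best) {
      self$best = score
      self$n_bad = 0L
    } else {
      self$n_bad = self$n_bad + 1L
    }
    # setting terminate at the end of an epoch ends training
    if (self$n_bad >= self$patience) self$ctx$terminate = TRUE
  }
)
```

It would be attached like any other callback, for example through the callbacks argument of a torch learner that has validation data and measures_valid = msrs("classif.ce"). The learner may also provide its own early-stopping settings, which are usually preferable; this sketch only illustrates how terminate is used.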
Methods
Public methods
Method new()
Creates a new instance of this R6 class.
Usage
ContextTorch$new(
  learner,
  task_train,
  task_valid = NULL,
  loader_train,
  loader_valid = NULL,
  measures_train = NULL,
  measures_valid = NULL,
  network,
  optimizer,
  loss_fn,
  total_epochs,
  prediction_encoder,
  eval_freq = 1L
)
Arguments
- learner
  (Learner)
  The torch learner.
- task_train
  (Task)
  The training task.
- task_valid
  (Task or NULL)
  The validation task.
- loader_train
  (torch::dataloader)
  The data loader for training.
- loader_valid
  (torch::dataloader or NULL)
  The data loader for validation.
- measures_train
  (list() of Measures or NULL)
  Measures used for training. Default is NULL.
- measures_valid
  (list() of Measures or NULL)
  Measures used for validation. Default is NULL.
- network
  (torch::nn_module)
  The torch network.
- optimizer
  (torch::optimizer)
  The optimizer.
- loss_fn
  (torch::nn_module)
  The loss function.
- total_epochs
  (integer(1))
  The total number of epochs the learner is trained for.
- prediction_encoder
  (function())
  The learner's prediction encoder.
- eval_freq
  (integer(1))
  The evaluation frequency, i.e. how often (in epochs) the training and validation measures are computed. Default is 1L.
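The context is typically created by the torch learner rather than by calling ContextTorch$new() directly, so eval_freq and the measures are usually set on the learner. A sketch, assuming a recent mlr3torch release in which classif.mlp accepts eval_freq, measures_train and the callbacks construction argument (as in the package README): the training accuracy is evaluated only every second epoch, so ctx$last_scores_train is NULL in the epochs in between. The task, layer size and epoch count are arbitrary choices for illustration.

```r
library(mlr3torch)

# Arbitrary small configuration for illustration. eval_freq = 2 is assumed to be
# forwarded by the learner to the ContextTorch it creates internally.
learner = lrn("classif.mlp",
  neurons = 10,
  epochs = 4,
  batch_size = 32,
  eval_freq = 2,
  measures_train = msrs("classif.acc"),
  callbacks = t_clbk("history")  # records the evaluated scores
)
learner$train(tsk("iris"))

# Scores collected by the history callback
print(learner$model$callbacks$history)
```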
Method clone()
The objects of this class are cloneable with this method.
Usage
ContextTorch$clone(deep = FALSE)
Arguments
- deep
  Whether to make a deep clone.
See Also
Other Callback: 
TorchCallback,
as_torch_callback(),
as_torch_callbacks(),
callback_set(),
mlr3torch_callbacks,
mlr_callback_set,
mlr_callback_set.checkpoint,
mlr_callback_set.progress,
t_clbk(),
torch_callback()