mlr_context_torch {mlr3torch}    R Documentation
Context for Torch Learner
Description
Context for training a torch learner.
This is the (mostly read-only) information that callbacks have access to through the argument ctx. For more information on callbacks, see CallbackSet.
Public fields
learner (Learner)
The torch learner.
task_train (Task)
The training task.
task_valid (Task or NULL)
The validation task.
loader_train (torch::dataloader)
The data loader for training.
loader_valid (torch::dataloader)
The data loader for validation.
measures_train (list() of Measures)
Measures used for training.
measures_valid (list() of Measures)
Measures used for validation.
network (torch::nn_module)
The torch network.
optimizer (torch::optimizer)
The optimizer.
loss_fn (torch::nn_module)
The loss function.
total_epochs (integer(1))
The total number of epochs the learner is trained for.
last_scores_train (named list() or NULL)
The scores from the last training batch. Names are the ids of the training measures. If LearnerTorch sets eval_freq to a value other than 1, this is NULL in all epochs in which the model is not evaluated.
last_scores_valid (named list() or NULL)
The scores from the last validation batch. Names are the ids of the validation measures. If LearnerTorch sets eval_freq to a value other than 1, this is NULL in all epochs in which the model is not evaluated.
epoch (integer(1))
The current epoch.
step (integer(1))
The current iteration.
prediction_encoder (function())
The learner's prediction encoder.
batch (named list() of torch_tensors)
The current batch.
terminate (logical(1))
If this field is set to TRUE at the end of an epoch, training stops.
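The effect of the terminate field can be illustrated with a small base-R simulation. This is a hypothetical sketch, not mlr3torch code: the context is modeled as a plain environment holding only epoch, total_epochs and terminate, and the helpers make_ctx(), run_training() and stop_after() are invented for illustration. The only behavior it takes from this page is that terminate is checked at the end of each epoch.

```r
# Hypothetical sketch: a plain environment stands in for the ContextTorch object.
make_ctx <- function(total_epochs) {
  ctx <- new.env()
  ctx$total_epochs <- total_epochs
  ctx$epoch <- 0L
  ctx$terminate <- FALSE
  ctx
}

# Hypothetical training loop: updates ctx$epoch, runs an epoch-end callback,
# and stops as soon as the callback has set ctx$terminate to TRUE.
run_training <- function(ctx, on_epoch_end) {
  for (epoch in seq_len(ctx$total_epochs)) {
    ctx$epoch <- epoch
    # ... one pass over the training batches would happen here ...
    on_epoch_end(ctx)
    if (isTRUE(ctx$terminate)) break
  }
  invisible(ctx)
}

# Hypothetical callback factory: requests termination once max_epoch is reached.
stop_after <- function(max_epoch) {
  function(ctx) {
    if (ctx$epoch >= max_epoch) ctx$terminate <- TRUE
  }
}

ctx <- run_training(make_ctx(total_epochs = 10L), stop_after(3L))
ctx$epoch      # [1] 3
ctx$terminate  # [1] TRUE
```

Because environments have reference semantics in R, the callback's assignment to ctx$terminate is visible to the loop without returning anything. R6 objects such as ContextTorch are environments as well, so they share this reference semantics.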
Methods
Public methods
Method new()
Creates a new instance of this R6 class.
Usage
ContextTorch$new(
  learner,
  task_train,
  task_valid = NULL,
  loader_train,
  loader_valid = NULL,
  measures_train = NULL,
  measures_valid = NULL,
  network,
  optimizer,
  loss_fn,
  total_epochs,
  prediction_encoder,
  eval_freq = 1L
)
Arguments
learner (Learner)
The torch learner.
task_train (Task)
The training task.
task_valid (Task or NULL)
The validation task. Default is NULL.
loader_train (torch::dataloader)
The data loader for training.
loader_valid (torch::dataloader or NULL)
The data loader for validation. Default is NULL.
measures_train (list() of Measures or NULL)
Measures used for training. Default is NULL.
measures_valid (list() of Measures or NULL)
Measures used for validation. Default is NULL.
network (torch::nn_module)
The torch network.
optimizer (torch::optimizer)
The optimizer.
loss_fn (torch::nn_module)
The loss function.
total_epochs (integer(1))
The total number of epochs the learner is trained for.
prediction_encoder (function())
The learner's prediction encoder.
eval_freq (integer(1))
The evaluation frequency, i.e. how often (in epochs) the model is evaluated. Default is 1L.
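The relationship between eval_freq and the last_scores_train and last_scores_valid fields can be sketched in base R. This sketch assumes, hypothetically, that the model is evaluated in every epoch whose number is a multiple of eval_freq; the actual schedule is determined by LearnerTorch and may differ (for example in how the final epoch is handled). The helpers evaluated_epochs() and simulate_last_scores(), the measure id, and the NA placeholder score are invented for illustration.

```r
# Hypothetical helper: epochs in which the model would be evaluated, assuming
# evaluation happens whenever the epoch number is a multiple of eval_freq.
evaluated_epochs <- function(total_epochs, eval_freq = 1L) {
  epochs <- seq_len(total_epochs)
  epochs[epochs %% eval_freq == 0L]
}

# Hypothetical simulation of a last_scores_* field over training: a list named
# by measure id in evaluation epochs, NULL in all other epochs.
simulate_last_scores <- function(total_epochs, eval_freq = 1L,
                                 measure_id = "classif.acc") {
  eval_at <- evaluated_epochs(total_epochs, eval_freq)
  lapply(seq_len(total_epochs), function(epoch) {
    # NA_real_ is a placeholder; a real run would compute the measure.
    if (epoch %in% eval_at) setNames(list(NA_real_), measure_id) else NULL
  })
}

evaluated_epochs(total_epochs = 10L, eval_freq = 3L)
# [1] 3 6 9

scores <- simulate_last_scores(total_epochs = 4L, eval_freq = 2L)
vapply(scores, is.null, logical(1))
# [1]  TRUE FALSE  TRUE FALSE
```

With the default eval_freq = 1L every epoch is evaluated under this assumption, so none of the simulated entries is NULL.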
Method clone()
The objects of this class are cloneable with this method.
Usage
ContextTorch$clone(deep = FALSE)
Arguments
deep
Whether to make a deep clone.
See Also
Other Callback: TorchCallback, as_torch_callback(), as_torch_callbacks(), callback_set(), mlr3torch_callbacks, mlr_callback_set, mlr_callback_set.checkpoint, mlr_callback_set.progress, t_clbk(), torch_callback()