| mlr_learners_torch {mlr3torch} | R Documentation |
Base Class for Torch Learners
Description
This base class provides the basic functionality for training and prediction of a neural network. All torch learners should inherit from this class.
Validation
To specify the validation data, you can set the $validate field of the Learner, which can be set to:
- NULL: no validation.
- ratio: only the proportion 1 - ratio of the task is used for training and ratio is used for validation.
- "test": the "test" task of a resampling is used. This is not possible when calling $train() manually.
- "predefined": uses the predefined $internal_valid_task of a mlr3::Task, which can, for example, be created using the $divide() method of Task.
This validation data can also be used for early stopping, see the description of the Learner's parameters.
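As a minimal sketch of the ratio option, assuming the classif.mlp learner and the built-in iris task (both picked only for illustration, as are the hyperparameter values), the validation data could be configured like this:

```r
library(mlr3)
library(mlr3torch)

# classif.mlp, the epochs and the batch size are illustrative choices
learner = lrn("classif.mlp",
  epochs = 5,
  batch_size = 32,
  measures_valid = msr("classif.ce")
)

# hold out 30% of the training task for validation
learner$validate = 0.3

learner$train(tsk("iris"))

# named list with the validation scores of the final model
learner$internal_valid_scores
```

Setting learner$validate = NULL again disables validation.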
Saving a Learner
In order to save a LearnerTorch for later usage, it is necessary to call the $marshal() method on the Learner
before writing it to disk, as the object will otherwise not be saved correctly.
After loading a marshaled LearnerTorch into R again, you then need to call $unmarshal() to transform it
into a usable state.
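The round trip described above might look as follows. This is a sketch that assumes the regr.mlp learner, the mtcars task and saveRDS() / readRDS() as the storage mechanism; none of these are prescribed by this page.

```r
library(mlr3)
library(mlr3torch)

# illustrative learner and task
learner = lrn("regr.mlp", epochs = 1, batch_size = 16)
learner$train(tsk("mtcars"))

path = tempfile(fileext = ".rds")

# marshal before writing to disk
learner$marshal()
saveRDS(learner, path)

# the in-memory learner stays marshaled until it is unmarshaled again
learner$unmarshal()

# load in a later session and make the learner usable again
restored = readRDS(path)
restored$unmarshal()
prediction = restored$predict(tsk("mtcars"))
```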
Early Stopping and Tuning
To prevent overfitting, the LearnerTorch class supports early stopping via the patience
and min_delta parameters, see the Learner's parameters.
When tuning a LearnerTorch, it is also possible to combine the explicit tuning via mlr3tuning
with the LearnerTorch's internal tuning of the epochs via early stopping.
To do so, include epochs = to_tune(upper = <upper>, internal = TRUE) in the search space,
where <upper> is the maximum allowed number of epochs, and configure the early stopping.
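The combination of explicit and internal tuning could be sketched as below. The learner (classif.mlp), the tuned learning rate opt.lr of the default optimizer, the random search tuner, the holdout resampling and all numeric values are assumptions made for the example.

```r
library(mlr3)
library(mlr3torch)
library(mlr3tuning)

learner = lrn("classif.mlp",
  # internally tuned: 50 is only the upper bound, early stopping picks the value
  epochs = to_tune(upper = 50, internal = TRUE),
  # explicitly tuned via mlr3tuning
  opt.lr = to_tune(1e-4, 1e-1, logscale = TRUE),
  batch_size = 32,
  # early stopping configuration
  patience = 3,
  measures_valid = msr("classif.ce")
)
# early stopping needs validation data
learner$validate = 0.3

instance = tune(
  tuner = tnr("random_search"),
  task = tsk("iris"),
  learner = learner,
  resampling = rsmp("holdout"),
  measures = msr("classif.ce"),
  term_evals = 3
)

instance$result
```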
Model
The Model is a list of class "learner_torch_model" with the following elements:
- network :: The trained network.
- optimizer :: The $state_dict() of the optimizer used to train the network.
- loss_fn :: The $state_dict() of the loss used to train the network.
- callbacks :: The callbacks used to train the network.
- seed :: The seed that was / is used for training and prediction.
- epochs :: How many epochs the model was trained for (early stopping).
- task_col_info :: A data.table() containing information about the train-task.
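To illustrate, the model of a trained learner can be inspected like any other list. The sketch assumes the regr.mlp learner and the mtcars task; the seed is fixed only to make the stored value predictable.

```r
library(mlr3)
library(mlr3torch)

# illustrative learner, task and hyperparameter values
learner = lrn("regr.mlp", epochs = 2, batch_size = 16, seed = 1)
learner$train(tsk("mtcars"))

model = learner$model
class(model)
names(model)

# the seed that was used for training and is reused for prediction
model$seed

# shortcut for learner$model$network
learner$network
```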
Parameters
General:
The parameters of the optimizer, loss and callbacks,
prefixed with "opt.", "loss." and "cb.<callback id>." respectively, as well as:
- epochs :: integer(1)
  The number of epochs.
- device :: character(1)
  The device. One of "auto", "cpu", or "cuda" or other values defined in mlr_reflections$torch$devices. The value is initialized to "auto", which will select "cuda" if possible, then try "mps" and otherwise fall back to "cpu".
- num_threads :: integer(1)
  The number of threads for intraop parallelization (if device is "cpu"). This value is initialized to 1.
- seed :: integer(1) or "random"
  The seed that is used during training and prediction. This value is initialized to "random", which means that a random seed will be sampled at the beginning of the training phase. This seed (either set or randomly sampled) is available via $model$seed after training and used during prediction. Note that because the seed is set during the training phase, clones of the learner will by default (i.e. when seed is "random") use a different seed.
Evaluation:
- measures_train :: Measure or list() of Measures
  Measures to be evaluated during training.
- measures_valid :: Measure or list() of Measures
  Measures to be evaluated during validation.
- eval_freq :: integer(1)
  How often the train / validation predictions are evaluated using measures_train / measures_valid. This is initialized to 1. Note that the final model is always evaluated.
Early Stopping:
- patience :: integer(1)
  This activates early stopping using the validation scores. If the performance of a model does not improve for patience evaluation steps, training is ended. Note that the final model is stored in the learner, not the best model. This is initialized to 0, which means no early stopping. The first entry from measures_valid is used as the metric. This also requires specifying the $validate field of the Learner, as well as measures_valid.
- min_delta :: double(1)
  The minimum improvement threshold (>) for early stopping. This is initialized to 0.
Dataloader:
- batch_size :: integer(1)
  The batch size (required).
- shuffle :: logical(1)
  Whether to shuffle the instances in the dataset. Default is FALSE. This does not impact validation.
- sampler :: torch::sampler
  Object that defines how the dataloader draws samples.
- batch_sampler :: torch::sampler
  Object that defines how the dataloader draws batches.
- num_workers :: integer(1)
  The number of workers for data loading (batches are loaded in parallel). The default is 0, which means that data will be loaded in the main process.
- collate_fn :: function
  How to merge a list of samples to form a batch.
- pin_memory :: logical(1)
  Whether the dataloader copies tensors into CUDA pinned memory before returning them.
- drop_last :: logical(1)
  Whether to drop the last training batch in each epoch during training. Default is FALSE.
- timeout :: numeric(1)
  The timeout value for collecting a batch from workers. Negative values mean no timeout and the default is -1.
- worker_init_fn :: function(id)
  A function that receives the worker id (in [1, num_workers]) and is executed after seeding on the worker but before data loading.
- worker_globals :: list() | character()
  When loading data in parallel, this allows exporting globals to the workers. If this is a character vector, the objects in the global environment with those names are copied to the workers.
- worker_packages :: character()
  Which packages to load on the workers.
Also see torch::dataloader for more information.
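As a brief sketch of how these parameters are set in practice, the following constructs a learner with general, optimizer and dataloader parameters. The classif.mlp learner and all values are assumptions for the example; opt.lr refers to the learning rate of the default optimizer.

```r
library(mlr3)
library(mlr3torch)

learner = lrn("classif.mlp",
  # general parameters
  epochs = 3,
  device = "cpu",
  seed = 42,
  # optimizer parameter, prefixed with "opt."
  opt.lr = 0.01,
  # dataloader parameters
  batch_size = 16,
  shuffle = TRUE
)

# list the dynamically constructed optimizer and loss parameters
grep("^(opt|loss)\\.", learner$param_set$ids(), value = TRUE)
```

Parameters can also be changed after construction via learner$param_set$set_values().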
Inheriting
There are no separate classes for classification and regression to inherit from.
Instead, the task_type must be specified as a construction argument.
Currently, only classification and regression are supported.
When inheriting from this class, one should overload two private methods:
- .network(task, param_vals)
  (Task, list()) -> nn_module
  Construct a torch::nn_module object for the given task and parameter values, i.e. the neural network that is trained by the learner. For classification, the output of this network is expected to be the scores before the application of the final softmax layer.
- .dataset(task, param_vals)
  (Task, list()) -> torch::dataset
  Create the dataset for the task. Must respect the parameter value of the device. Moreover, one needs to respect the row ids of the provided task.
It is also possible to overwrite the private .dataloader() method instead of the .dataset() method.
By default, a dataloader is constructed using the output from the .dataset() method.
However, an overwritten method should respect the dataloader parameters from the ParamSet.
- .dataloader(task, param_vals)
  (Task, list()) -> torch::dataloader
  Create a dataloader from the task. Needs to respect at least batch_size and shuffle (otherwise predictions can be permuted).
To change the predict types, the private .encode_prediction() method can be overwritten:
- .encode_prediction(predict_tensor, task, param_vals)
  (torch_tensor, Task, list()) -> list()
  Take in the raw predictions from self$network(predict_tensor) and encode them into a format that can be converted to valid mlr3 predictions using mlr3::as_prediction_data(). This method must take self$predict_type into account.
While it is possible to add parameters by specifying the param_set construction argument, it is currently
not possible to remove existing parameters, i.e. those listed in section Parameters.
None of the parameters provided in param_set can have an id that starts with "loss.", "opt.", or "cb.", as these are reserved for the dynamically constructed parameters of the optimizer, the loss function,
and the callbacks.
To perform additional input checks on the task, the private methods .verify_train_task(task, param_vals) and
.verify_predict_task(task, param_vals) can be overwritten.
For learners that have other construction arguments that should change the hash of a learner, it is required
to implement the private $.additional_phash_input().
Super class
mlr3::Learner -> LearnerTorch
Active bindings
validate
  How to construct the internal validation data. This parameter can be either NULL, a ratio in (0, 1), "test", or "predefined".
loss (TorchLoss)
  The torch loss.
optimizer (TorchOptimizer)
  The torch optimizer.
callbacks (list() of TorchCallbacks)
  List of torch callbacks. The ids will be set as the names.
internal_valid_scores
  Retrieves the internal validation scores as a named list(). Specify the $validate field and the measures_valid parameter to configure this. Returns NULL if the learner is not trained yet.
internal_tuned_values
  When early stopping is active, this returns a named list with the early-stopped epochs, otherwise an empty list is returned. Returns NULL if the learner is not trained yet.
marshaled (logical(1))
  Whether the learner is marshaled.
network (nn_module())
  Shortcut for learner$model$network.
param_set (ParamSet)
  The parameter set.
hash (character(1))
  Hash (unique identifier) for this object.
phash (character(1))
  Hash (unique identifier) for this partial object, excluding some components which are varied systematically during tuning (parameter values).
Methods
Method new()
Creates a new instance of this R6 class.
Usage
LearnerTorch$new(
  id,
  task_type,
  param_set,
  properties,
  man,
  label,
  feature_types,
  optimizer = NULL,
  loss = NULL,
  packages = character(),
  predict_types = NULL,
  callbacks = list()
)
Arguments
id (character(1))
  The id of the new object.
task_type (character(1))
  The task type.
param_set (ParamSet or alist())
  Either a parameter set, or an alist() containing different values of self, e.g. alist(private$.param_set1, private$.param_set2), from which a ParamSet collection should be created.
properties (character())
  The properties of the object. See mlr_reflections$learner_properties for available values.
man (character(1))
  String in the format [pkg]::[topic] pointing to a manual page for this object. The referenced help page can be opened via method $help().
label (character(1))
  Label for the new instance.
feature_types (character())
  The feature types. See mlr_reflections$task_feature_types for available values. Additionally, "lazy_tensor" is supported.
optimizer (NULL or TorchOptimizer)
  The optimizer to use for training. Defaults to adam.
loss (NULL or TorchLoss)
  The loss to use for training. Defaults to MSE for regression and cross entropy for classification.
packages (character())
  The R packages this object depends on.
predict_types (character())
  The predict types. See mlr_reflections$learner_predict_types for available values. For regression, the default is "response". For classification, this defaults to "response" and "prob". To deviate from the defaults, it is necessary to overwrite the private $.encode_prediction() method, see section Inheriting.
callbacks (list() of TorchCallbacks)
  The callbacks to use for training. Defaults to an empty list(), i.e. no callbacks.
Method format()
Helper for print outputs.
Usage
LearnerTorch$format(...)
Arguments
... (ignored).
Method print()
Prints the object.
Usage
LearnerTorch$print(...)
Arguments
... (any)
  Currently unused.
Method marshal()
Marshal the learner.
Usage
LearnerTorch$marshal(...)
Arguments
... (any)
  Additional parameters.
Returns
self
Method unmarshal()
Unmarshal the learner.
Usage
LearnerTorch$unmarshal(...)
Arguments
... (any)
  Additional parameters.
Returns
self
Method dataset()
Create the dataset for a task.
Usage
LearnerTorch$dataset(task)
Arguments
task (Task)
  The task.
Returns
Method clone()
The objects of this class are cloneable with this method.
Usage
LearnerTorch$clone(deep = FALSE)
Arguments
deep
  Whether to make a deep clone.
See Also
Other Learner:
mlr_learners.mlp,
mlr_learners.tab_resnet,
mlr_learners.torch_featureless,
mlr_learners_torch_image,
mlr_learners_torch_model