mlr_learners_torch {mlr3torch}    R Documentation

Base Class for Torch Learners

Description

This base class provides the basic functionality for training and prediction of a neural network. All torch learners should inherit from this class.

Validation

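To specify the validation data, set the $validate field of the Learner to one of: NULL (no internal validation, the default), a ratio in (0, 1) (this proportion of the training data is held out as validation data), "test" (the test set of the current resampling iteration is used), or "predefined" (the task's $internal_valid_task is used).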

This validation data can also be used for early stopping, see the description of the Learner's parameters.
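A minimal sketch of internal validation combined with early stopping, assuming the MLP learner from mlr_learners.mlp (lrn("classif.mlp")) and the built-in iris task. The network size, batch size, epoch budget, and patience are arbitrary illustrative values, not recommendations.

```r
library(mlr3)
library(mlr3torch)

task = tsk("iris")

learner = lrn("classif.mlp",
  neurons = c(20, 20),   # two hidden layers (illustrative values)
  batch_size = 16,
  epochs = 100,          # upper bound; early stopping may end training sooner
  measures_valid = msr("classif.ce"),
  patience = 5,          # stop after 5 validation evaluations without improvement
  min_delta = 0          # minimum change of the validation score that counts as improvement
)

# hold out 30% of the training rows as internal validation data
learner$validate = 0.3

learner$train(task)

# named list of internal validation scores, with one entry per validation measure
learner$internal_valid_scores

# named list holding the number of epochs selected by early stopping
learner$internal_tuned_values
```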

Saving a Learner

To save a LearnerTorch for later use, you must call its $marshal() method before writing it to disk; otherwise, the object is not saved correctly. After loading a marshaled LearnerTorch back into R, call $unmarshal() to restore it to a usable state.
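The save/load round trip described above can be sketched as follows, assuming the regression MLP (lrn("regr.mlp")) and the built-in mtcars task; the temporary file path and all hyperparameter values are illustrative only.

```r
library(mlr3)
library(mlr3torch)

task = tsk("mtcars")
learner = lrn("regr.mlp", neurons = 10, batch_size = 8, epochs = 5)
learner$train(task)

path = tempfile(fileext = ".rds")

# convert the trained torch network into a serializable representation, then save
learner$marshal()
saveRDS(learner, path)

# e.g. in a later R session: load the learner and restore a usable state
learner2 = readRDS(path)
learner2$unmarshal()
prediction = learner2$predict(task)
```

Only the unmarshaled copy is used for prediction; the marshaled learner is merely a container suitable for saveRDS().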

Early Stopping and Tuning

To prevent overfitting, the LearnerTorch class supports early stopping via the patience and min_delta parameters, see the Learner's parameters. When tuning a LearnerTorch, it is also possible to combine explicit tuning via mlr3tuning with the LearnerTorch's internal tuning of the epochs via early stopping. To do so, include epochs = to_tune(upper = <upper>, internal = TRUE) in the search space, where <upper> is the maximum allowed number of epochs, and configure early stopping.
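A minimal sketch of combining both kinds of tuning, assuming mlr3tuning's tune() with random search, a single holdout split, and the MLP learner on the iris task. Tuning the Adam learning rate (opt.lr) is an illustrative choice, and the range, budget, and patience are arbitrary. Setting $validate to "test" makes early stopping use the test set of each resampling iteration.

```r
library(mlr3)
library(mlr3torch)
library(mlr3tuning)

learner = lrn("classif.mlp",
  neurons = c(20, 20),
  batch_size = 16,
  # explicitly tuned by mlr3tuning
  opt.lr = to_tune(1e-4, 1e-2, logscale = TRUE),
  # internally tuned via early stopping, with at most 50 epochs
  epochs = to_tune(upper = 50, internal = TRUE),
  measures_valid = msr("classif.ce"),
  patience = 5
)
# use the test set of each resampling iteration for early stopping
learner$validate = "test"

instance = tune(
  tuner = tnr("random_search"),
  task = tsk("iris"),
  learner = learner,
  resampling = rsmp("holdout"),
  measures = msr("classif.ce"),
  term_evals = 3
)

# best learning rate together with the early-stopped number of epochs
instance$result_learner_param_vals[c("opt.lr", "epochs")]
```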

Model

The Model is a list of class "learner_torch_model" with the following elements:

Parameters

General:

The parameters of the optimizer, loss and callbacks, prefixed with "opt.", "loss." and "cb.<callback id>." respectively, as well as:

Evaluation:

Early Stopping:

Dataloader:

See also torch::dataloader for more information.
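A short sketch of how the prefixed parameters sit next to the general and dataloader parameters in one ParamSet, assuming lrn("classif.mlp") with its default optimizer (Adam) and default classification loss (cross entropy). The values are illustrative only.

```r
library(mlr3)
library(mlr3torch)

# default optimizer (Adam) and default loss (cross entropy)
learner = lrn("classif.mlp")

learner$param_set$set_values(
  epochs = 10,       # general parameter
  batch_size = 32,   # dataloader parameter
  shuffle = TRUE,    # dataloader parameter
  opt.lr = 1e-3      # optimizer parameter, hence the "opt." prefix
)

# ids of the dynamically added loss parameters
grep("^loss\\.", learner$param_set$ids(), value = TRUE)
```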

Inheriting

There are no separate classes for classification and regression to inherit from. Instead, the task_type must be specified as a construction argument. Currently, only classification and regression are supported.

When inheriting from this class, one should overwrite two private methods:

It is also possible to overwrite the private .dataloader() method instead of the .dataset() method. By default, a dataloader is constructed from the output of the .dataset() method. A custom .dataloader() must respect the dataloader parameters from the ParamSet.

To change the predict types, the private .encode_prediction() method can be overwritten:

While it is possible to add parameters by specifying the param_set construction argument, it is currently not possible to remove existing parameters, i.e. those listed in section Parameters. None of the parameters provided in param_set can have an id that starts with "loss.", "opt.", or "cb.", as these prefixes are reserved for the dynamically constructed parameters of the loss function, the optimizer, and the callbacks.

To perform additional input checks on the task, the private methods .verify_train_task(task, param_vals) and .verify_predict_task(task, param_vals) can be overwritten.

For learners with other construction arguments that should change the hash of the learner, the private method $.additional_phash_input() must be implemented.

Super class

mlr3::Learner -> LearnerTorch

Active bindings

validate

How to construct the internal validation data. This parameter can be either NULL, a ratio in (0, 1), "test", or "predefined".

loss

(TorchLoss)
The torch loss.

optimizer

(TorchOptimizer)
The torch optimizer.

callbacks

(list() of TorchCallbacks)
List of torch callbacks. The ids will be set as the names.

internal_valid_scores

Retrieves the internal validation scores as a named list(). Specify the $validate field and the measures_valid parameter to configure this. Returns NULL if the learner has not been trained yet.

internal_tuned_values

When early stopping is active, this returns a named list with the early-stopped epochs; otherwise, an empty list is returned. Returns NULL if the learner has not been trained yet.

marshaled

(logical(1))
Whether the learner is marshaled.

network

(nn_module())
Shortcut for learner$model$network.

param_set

(ParamSet)
The parameter set.

hash

(character(1))
Hash (unique identifier) for this object.

phash

(character(1))
Hash (unique identifier) for this partial object, excluding some components which are varied systematically during tuning (parameter values).

Methods

Public methods

Inherited methods

Method new()

Creates a new instance of this R6 class.

Usage
LearnerTorch$new(
  id,
  task_type,
  param_set,
  properties,
  man,
  label,
  feature_types,
  optimizer = NULL,
  loss = NULL,
  packages = character(),
  predict_types = NULL,
  callbacks = list()
)
Arguments
id

(character(1))
The id of the new object.

task_type

(character(1))
The task type.

param_set

(ParamSet or alist())
Either a parameter set, or an alist() containing different values of self, e.g. alist(private$.param_set1, private$.param_set2), from which a ParamSet collection should be created.

properties

(character())
The properties of the object. See mlr_reflections$learner_properties for available values.

man

(character(1))
String in the format [pkg]::[topic] pointing to a manual page for this object. The referenced help package can be opened via method $help().

label

(character(1))
Label for the new instance.

feature_types

(character())
The feature types. See mlr_reflections$task_feature_types for available values. Additionally, "lazy_tensor" is supported.

optimizer

(NULL or TorchOptimizer)
The optimizer to use for training. Defaults to adam.

loss

(NULL or TorchLoss)
The loss to use for training. Defaults to MSE for regression and cross entropy for classification.

packages

(character())
The R packages this object depends on.

predict_types

(character())
The predict types. See mlr_reflections$learner_predict_types for available values. For regression, the default is "response". For classification, this defaults to "response" and "prob". To deviate from the defaults, it is necessary to overwrite the private $.encode_prediction() method, see section Inheriting.

callbacks

(list() of TorchCallbacks)
The callbacks to use for training. Defaults to an empty list(), i.e. no callbacks.
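A sketch of supplying a non-default optimizer, loss, and callbacks, assuming (as for the MLP learner) that the subclass forwards these arguments to LearnerTorch$new() and that lrn() passes them through to the constructor. SGD with lr = 0.1 and the "history" callback are illustrative choices only.

```r
library(mlr3)
library(mlr3torch)

learner = lrn("classif.mlp",
  optimizer = t_opt("sgd", lr = 0.1),   # replaces the default Adam optimizer
  loss = t_loss("cross_entropy"),        # same as the classification default, set explicitly
  callbacks = list(t_clbk("history")),
  epochs = 5,
  batch_size = 16
)

learner$optimizer          # TorchOptimizer wrapping SGD
names(learner$callbacks)   # callback ids are used as names
```

The optimizer's lr now appears in the learner's ParamSet as opt.lr.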


Method format()

Helper for print outputs.

Usage
LearnerTorch$format(...)
Arguments
...

(ignored).


Method print()

Prints the object.

Usage
LearnerTorch$print(...)
Arguments
...

(any)
Currently unused.


Method marshal()

Marshal the learner.

Usage
LearnerTorch$marshal(...)
Arguments
...

(any)
Additional parameters.

Returns

self


Method unmarshal()

Unmarshal the learner.

Usage
LearnerTorch$unmarshal(...)
Arguments
...

(any)
Additional parameters.

Returns

self


Method dataset()

Create the dataset for a task.

Usage
LearnerTorch$dataset(task)
Arguments
task

Task
The task

Returns

dataset


Method clone()

The objects of this class are cloneable with this method.

Usage
LearnerTorch$clone(deep = FALSE)
Arguments
deep

Whether to make a deep clone.

See Also

Other Learner: mlr_learners.mlp, mlr_learners.tab_resnet, mlr_learners.torch_featureless, mlr_learners_torch_image, mlr_learners_torch_model


[Package mlr3torch version 0.1.0 Index]