mlr_pipeops_torch_model {mlr3torch}R Documentation

PipeOp Torch Model

Description

Builds a Torch Learner from a ModelDescriptor and trains it with the given parameter specification. The task type must be specified during construction.

Input and Output Channels

There is one input channel "input" that takes in ModelDescriptor during traing and a Task of the specified task_type during prediction. The output is NULL during training and a Prediction of given task_type during prediction.

State

A trained LearnerTorchModel.

Parameters

General:

The parameters of the optimizer, loss and callbacks, prefixed with "opt.", "loss." and "cb.<callback id>." respectively, as well as:

Evaluation:

Early Stopping:

Dataloader:

Also see torch::dataloder for more information.

Internals

A LearnerTorchModel is created by calling model_descriptor_to_learner() on the provided ModelDescriptor that is received through the input channel. Then the parameters are set according to the parameters specified in PipeOpTorchModel and its '$train()⁠ method is called on the [⁠Task⁠][mlr3::Task] stored in the [⁠ModelDescriptor'].

Super classes

mlr3pipelines::PipeOp -> mlr3pipelines::PipeOpLearner -> PipeOpTorchModel

Methods

Public methods

Inherited methods

Method new()

Creates a new instance of this R6 class.

Usage
PipeOpTorchModel$new(task_type, id = "torch_model", param_vals = list())
Arguments
task_type

(character(1))
The task type of the model.

id

(character(1))
Identifier of the resulting object.

param_vals

(list())
List of hyperparameter settings, overwriting the hyperparameter settings that would otherwise be set during construction.


Method clone()

The objects of this class are cloneable with this method.

Usage
PipeOpTorchModel$clone(deep = FALSE)
Arguments
deep

Whether to make a deep clone.

See Also

Other PipeOps: mlr_pipeops_nn_avg_pool1d, mlr_pipeops_nn_avg_pool2d, mlr_pipeops_nn_avg_pool3d, mlr_pipeops_nn_batch_norm1d, mlr_pipeops_nn_batch_norm2d, mlr_pipeops_nn_batch_norm3d, mlr_pipeops_nn_block, mlr_pipeops_nn_celu, mlr_pipeops_nn_conv1d, mlr_pipeops_nn_conv2d, mlr_pipeops_nn_conv3d, mlr_pipeops_nn_conv_transpose1d, mlr_pipeops_nn_conv_transpose2d, mlr_pipeops_nn_conv_transpose3d, mlr_pipeops_nn_dropout, mlr_pipeops_nn_elu, mlr_pipeops_nn_flatten, mlr_pipeops_nn_gelu, mlr_pipeops_nn_glu, mlr_pipeops_nn_hardshrink, mlr_pipeops_nn_hardsigmoid, mlr_pipeops_nn_hardtanh, mlr_pipeops_nn_head, mlr_pipeops_nn_layer_norm, mlr_pipeops_nn_leaky_relu, mlr_pipeops_nn_linear, mlr_pipeops_nn_log_sigmoid, mlr_pipeops_nn_max_pool1d, mlr_pipeops_nn_max_pool2d, mlr_pipeops_nn_max_pool3d, mlr_pipeops_nn_merge, mlr_pipeops_nn_merge_cat, mlr_pipeops_nn_merge_prod, mlr_pipeops_nn_merge_sum, mlr_pipeops_nn_prelu, mlr_pipeops_nn_relu, mlr_pipeops_nn_relu6, mlr_pipeops_nn_reshape, mlr_pipeops_nn_rrelu, mlr_pipeops_nn_selu, mlr_pipeops_nn_sigmoid, mlr_pipeops_nn_softmax, mlr_pipeops_nn_softplus, mlr_pipeops_nn_softshrink, mlr_pipeops_nn_softsign, mlr_pipeops_nn_squeeze, mlr_pipeops_nn_tanh, mlr_pipeops_nn_tanhshrink, mlr_pipeops_nn_threshold, mlr_pipeops_torch_ingress, mlr_pipeops_torch_ingress_categ, mlr_pipeops_torch_ingress_ltnsr, mlr_pipeops_torch_ingress_num, mlr_pipeops_torch_loss, mlr_pipeops_torch_model_classif, mlr_pipeops_torch_model_regr


[Package mlr3torch version 0.1.0 Index]