mlr_learners_torch_model {mlr3torch} | R Documentation
Learner Torch Model
Description
Create a torch learner from an instantiated nn_module().
For classification, the output of the network must be the scores (before the softmax).
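As a minimal sketch of what "scores before the softmax" means, the network below ends in a linear layer and returns unnormalized scores. It uses only the torch package; the layer sizes and the names net, x, scores and probs are illustrative assumptions, not part of this class.

```r
library(torch)

# A small classifier for 4 features and 3 classes.
# The last layer is linear, so the module returns raw scores (logits).
net = nn_sequential(
  nn_linear(4, 16),
  nn_relu(),
  nn_linear(16, 3)
)

x = torch_randn(5, 4)   # a batch of 5 observations
scores = net(x)         # shape (5, 3), unnormalized

# Softmax is applied here only for inspection; it is not part of the network.
probs = nnf_softmax(scores, dim = 2)

print(scores$shape)
print(probs$sum(dim = 2))
```

A network that already ends in a softmax layer would not meet this requirement.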
Parameters
See LearnerTorch
Super classes
mlr3::Learner -> mlr3torch::LearnerTorch -> LearnerTorchModel
Active bindings
network_stored
(nn_module or NULL)
The network that will be trained. After calling $train(), this is NULL.
ingress_tokens
(named list() with TorchIngressToken or NULL)
The ingress tokens. Must be non-NULL when calling $train().
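The following sketch inspects both bindings before and after training. It assumes mlr3torch is installed and reuses the iris setup from the Examples section; the variables before and after are illustrative. The behaviour follows the description above and may differ between package versions, so no output is shown.

```r
library(mlr3torch)

task = tsk("iris")

learner = lrn("classif.torch_model",
  network = nn_linear(4, 3),
  ingress_tokens = list(
    input = TorchIngressToken(task$feature_names, batchgetter_num, c(NA, 4))
  ),
  batch_size = 16,
  epochs = 1,
  device = "cpu"
)

# Before training, the network is held in network_stored
# and the ingress tokens are set.
before = is.null(learner$network_stored)
print(names(learner$ingress_tokens))

learner$train(task)

# After training, network_stored is documented to be NULL.
after = is.null(learner$network_stored)
print(c(before = before, after = after))
```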
Methods
Public methods
Inherited methods
mlr3::Learner$base_learner()
mlr3::Learner$help()
mlr3::Learner$predict()
mlr3::Learner$predict_newdata()
mlr3::Learner$reset()
mlr3::Learner$train()
mlr3torch::LearnerTorch$dataset()
mlr3torch::LearnerTorch$format()
mlr3torch::LearnerTorch$marshal()
mlr3torch::LearnerTorch$print()
mlr3torch::LearnerTorch$unmarshal()
Method new()
Creates a new instance of this R6 class.
Usage
LearnerTorchModel$new(
  network = NULL,
  ingress_tokens = NULL,
  task_type,
  properties = NULL,
  optimizer = NULL,
  loss = NULL,
  callbacks = list(),
  packages = character(0),
  feature_types = NULL
)
Arguments
network
(nn_module)
An instantiated nn_module. It is not cloned during construction. For classification, outputs must be the scores (before the softmax).
ingress_tokens
(list of TorchIngressToken())
A list of ingress tokens that defines how the dataloader is constructed.
task_type
(character(1))
The task type.
properties
(NULL or character())
The properties of the learner. Defaults to all available properties for the given task type.
optimizer
(TorchOptimizer)
The torch optimizer.
loss
(TorchLoss)
The loss to use for training.
callbacks
(list() of TorchCallbacks)
The callbacks used during training. Must have unique ids. They are executed in the order in which they are provided.
packages
(character())
The R packages this object depends on.
feature_types
(NULL or character())
The feature types. Defaults to all available feature types.
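A minimal sketch of calling the constructor directly with an explicit optimizer, loss and callback. It assumes that LearnerTorchModel is exported and uses the mlr3torch helpers t_opt(), t_loss() and t_clbk(); the chosen optimizer, learning rate and callback are illustrative, not defaults stated on this page.

```r
library(mlr3torch)

task = tsk("iris")

learner = LearnerTorchModel$new(
  network = nn_linear(4, 3),
  ingress_tokens = list(
    input = TorchIngressToken(task$feature_names, batchgetter_num, c(NA, 4))
  ),
  task_type = "classif",
  optimizer = t_opt("adam", lr = 0.01),
  loss = t_loss("cross_entropy"),
  callbacks = list(t_clbk("history"))
)

# Training parameters are set through the parameter set.
learner$param_set$set_values(batch_size = 16, epochs = 1, device = "cpu")

print(learner)
```

In everyday use, lrn("classif.torch_model", ...) as in the Examples section is the shorter route.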
Method clone()
The objects of this class are cloneable with this method.
Usage
LearnerTorchModel$clone(deep = FALSE)
Arguments
deep
Whether to make a deep clone.
See Also
Other Learner: mlr_learners.mlp, mlr_learners.tab_resnet, mlr_learners.torch_featureless, mlr_learners_torch, mlr_learners_torch_image
Other Graph Network: ModelDescriptor(), TorchIngressToken(), mlr_pipeops_module, mlr_pipeops_torch, mlr_pipeops_torch_ingress, mlr_pipeops_torch_ingress_categ, mlr_pipeops_torch_ingress_ltnsr, mlr_pipeops_torch_ingress_num, model_descriptor_to_learner(), model_descriptor_to_module(), model_descriptor_union(), nn_graph()
Examples
library(mlr3torch)

# We show the learner using a classification task
# The iris task has 4 features and 3 classes
network = nn_linear(4, 3)
task = tsk("iris")
# This defines the dataloader.
# It loads all 4 features, all of which are numeric.
# The shape is (NA, 4) because the batch dimension is generally NA
ingress_tokens = list(
  input = TorchIngressToken(task$feature_names, batchgetter_num, c(NA, 4))
)
# Creating the learner and setting required parameters
learner = lrn("classif.torch_model",
  network = network,
  ingress_tokens = ingress_tokens,
  batch_size = 16,
  epochs = 1,
  device = "cpu"
)
# A simple train-predict
ids = partition(task)
learner$train(task, ids$train)
learner$predict(task, ids$test)
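The train-predict step above can be extended to probability predictions and scoring. This is a sketch under the assumption that the learner supports predict_type = "prob", as mlr3 classification learners commonly do; the measures are standard mlr3 measures chosen for illustration, and results vary between runs, so no output is shown.

```r
library(mlr3torch)

task = tsk("iris")

learner = lrn("classif.torch_model",
  network = nn_linear(4, 3),
  ingress_tokens = list(
    input = TorchIngressToken(task$feature_names, batchgetter_num, c(NA, 4))
  ),
  batch_size = 16,
  epochs = 1,
  device = "cpu"
)

# Request class probabilities instead of only the predicted label.
learner$predict_type = "prob"

ids = partition(task)
learner$train(task, ids$train)
pred = learner$predict(task, ids$test)

# One row per test observation, one column per class.
print(head(pred$prob))
print(pred$score(msrs(c("classif.acc", "classif.logloss"))))
```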