mlr_learners.tab_resnet {mlr3torch}    R Documentation
Tabular ResNet
Description
Tabular ResNet, the residual network architecture for tabular data described by Gorishniy et al. (2021).
Dictionary
This Learner can be instantiated using the sugar function lrn():
lrn("classif.tab_resnet", ...)
lrn("regr.tab_resnet", ...)
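A minimal sketch of the regression variant. lrn() also accepts parameter values through `...`, so they can be set at construction instead of via $param_set$set_values(). The task "mtcars" and the parameter values are illustrative choices, not recommendations.

```r
library(mlr3)
library(mlr3torch)

# Regression variant; parameter values are passed straight to lrn()
learner = lrn("regr.tab_resnet",
  epochs = 1, batch_size = 16, device = "cpu",
  n_blocks = 2, d_block = 10, d_hidden = 20, dropout1 = 0.3, dropout2 = 0.3
)

# mtcars is a small regression task with numeric features only
task = tsk("mtcars")
learner$train(task)
prediction = learner$predict(task)
prediction$score(msr("regr.mse"))
```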
Properties
Supported task types: 'classif', 'regr'
Predict Types:
classif: 'response', 'prob'
regr: 'response'
Feature Types: “integer”, “numeric”
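Because only integer and numeric features are supported, a task with factor features needs preprocessing first. A minimal sketch, assuming mlr3pipelines is installed: one-hot encode the factors with po("encode") and wrap the resulting pipeline as a learner. The task "german_credit" and the parameter values are illustrative choices.

```r
library(mlr3)
library(mlr3pipelines)
library(mlr3torch)

# german_credit contains factor features, which tab_resnet cannot use directly
task = tsk("german_credit")

# One-hot encode the factors, then fit the network on the numeric result
graph = po("encode", method = "one-hot") %>>%
  lrn("classif.tab_resnet",
    epochs = 1, batch_size = 32, device = "cpu",
    n_blocks = 2, d_block = 16, d_hidden = 32, dropout1 = 0.1, dropout2 = 0.1
  )
learner = as_learner(graph)

ids = partition(task)
learner$train(task, row_ids = ids$train)
prediction = learner$predict(task, row_ids = ids$test)
prediction$score()
```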
Parameters
Parameters from LearnerTorch, as well as:
- n_blocks :: integer(1)
  The number of blocks.
- d_block :: integer(1)
  The input and output dimension of a block.
- d_hidden :: integer(1)
  The latent dimension of a block.
- d_hidden_multiplier :: integer(1)
  Alternative to d_hidden: the latent dimension is set to d_block * d_hidden_multiplier.
- dropout1 :: numeric(1)
  First dropout ratio.
- dropout2 :: numeric(1)
  Second dropout ratio.
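To make the roles of d_block, d_hidden, d_hidden_multiplier, dropout1 and dropout2 concrete, the base R sketch below runs a forward pass through residual blocks shaped like those in Gorishniy et al. (2021), where each block maps x to x + Dropout(Linear(Dropout(ReLU(Linear(BatchNorm(x)))))). It assumes dropout1 is the first (hidden) dropout and dropout2 the second, and it simplifies heavily: batch normalization has no learned scale or shift, weights are random, and the helpers make_block() and resnet_block() are hypothetical. It is an illustration, not mlr3torch's implementation, which builds the network from torch modules.

```r
# Hypothetical helpers illustrating the block structure; not part of mlr3torch.

# Batch normalization using batch statistics only (no learned scale/shift)
batch_norm = function(x, eps = 1e-5) {
  centered = sweep(x, 2, colMeans(x), "-")
  sweep(centered, 2, sqrt(colMeans(centered^2) + eps), "/")
}

# Inverted dropout: zero entries with probability p and rescale the rest
dropout = function(x, p, training) {
  if (!training || p == 0) return(x)
  keep = matrix(runif(length(x)) > p, nrow(x), ncol(x))
  x * keep / (1 - p)
}

relu = function(x) pmax(x, 0)  # pmax keeps the matrix dimensions

# Random weights for one block: d_block -> d_hidden -> d_block
make_block = function(d_block, d_hidden) {
  list(
    W1 = matrix(rnorm(d_block * d_hidden, sd = 0.1), d_block, d_hidden),
    b1 = numeric(d_hidden),
    W2 = matrix(rnorm(d_hidden * d_block, sd = 0.1), d_hidden, d_block),
    b2 = numeric(d_block)
  )
}

# One block: x + Dropout2(Linear2(Dropout1(ReLU(Linear1(BatchNorm(x))))))
resnet_block = function(x, block, dropout1, dropout2, training = FALSE) {
  h = batch_norm(x)
  h = relu(sweep(h %*% block$W1, 2, block$b1, "+"))
  h = dropout(h, dropout1, training)
  h = sweep(h %*% block$W2, 2, block$b2, "+")
  h = dropout(h, dropout2, training)
  x + h  # the residual sum needs input and output width d_block
}

set.seed(1)
n_blocks = 2
d_block = 10
d_hidden_multiplier = 2
d_hidden = d_block * d_hidden_multiplier  # 20, same as d_hidden = 20 in the Examples

# In the paper a linear layer first maps the features to width d_block;
# here the input is simply drawn with d_block columns (16 rows = one batch).
x = matrix(rnorm(16 * d_block), 16, d_block)
blocks = lapply(seq_len(n_blocks), function(i) make_block(d_block, d_hidden))

out = x
for (block in blocks) {
  out = resnet_block(out, block, dropout1 = 0.3, dropout2 = 0.3, training = TRUE)
}
dim(out)  # 16 10
```

Every block keeps the width d_block, which is why d_block is described as both the input and the output dimension of a block, while d_hidden only sets the width inside it.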
Super classes
mlr3::Learner -> mlr3torch::LearnerTorch -> LearnerTorchTabResNet
Methods
Public methods
LearnerTorchTabResNet$new()
LearnerTorchTabResNet$clone()
Inherited methods
mlr3::Learner$base_learner()
mlr3::Learner$help()
mlr3::Learner$predict()
mlr3::Learner$predict_newdata()
mlr3::Learner$reset()
mlr3::Learner$train()
mlr3torch::LearnerTorch$dataset()
mlr3torch::LearnerTorch$format()
mlr3torch::LearnerTorch$marshal()
mlr3torch::LearnerTorch$print()
mlr3torch::LearnerTorch$unmarshal()
Method new()
Creates a new instance of this R6 class.
Usage
LearnerTorchTabResNet$new( task_type, optimizer = NULL, loss = NULL, callbacks = list() )
Arguments
task_type (character(1))
  The task type, either "classif" or "regr".
optimizer (TorchOptimizer)
  The optimizer to use for training. By default, adam is used.
loss (TorchLoss)
  The loss used to train the network. By default, mse is used for regression and cross_entropy for classification.
callbacks (list() of TorchCallbacks)
  The callbacks. Must have unique ids.
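A sketch of supplying the constructor arguments. Instead of calling LearnerTorchTabResNet$new() directly, it passes optimizer, loss and callbacks through lrn(), which forwards arguments matching the constructor to $new() and sets the remaining ones as parameter values or public fields. t_opt(), t_loss() and t_clbk() are mlr3torch's sugar functions for TorchOptimizer, TorchLoss and TorchCallback objects; the learning rate and the "history" callback are illustrative choices.

```r
library(mlr3)
library(mlr3torch)

learner = lrn("classif.tab_resnet",
  # constructor arguments
  optimizer = t_opt("adam", lr = 0.01),
  loss = t_loss("cross_entropy"),
  callbacks = list(t_clbk("history")),
  # parameter values
  epochs = 2, batch_size = 16, device = "cpu",
  n_blocks = 2, d_block = 10, d_hidden = 20, dropout1 = 0.3, dropout2 = 0.3,
  # public field: request class probabilities
  predict_type = "prob"
)

task = tsk("iris")
learner$train(task)
prediction = learner$predict(task)
head(prediction$prob)
```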
Method clone()
The objects of this class are cloneable with this method.
Usage
LearnerTorchTabResNet$clone(deep = FALSE)
Arguments
deep
Whether to make a deep clone.
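A short sketch of deep cloning: the clone gets its own copy of the parameter set, so changing a value on the clone leaves the original learner unchanged. The parameter values are illustrative.

```r
library(mlr3)
library(mlr3torch)

learner = lrn("classif.tab_resnet",
  epochs = 1, batch_size = 16, device = "cpu",
  n_blocks = 2, d_block = 10, d_hidden = 20, dropout1 = 0.3, dropout2 = 0.3
)

# Deep clone, then modify only the copy
learner2 = learner$clone(deep = TRUE)
learner2$param_set$set_values(n_blocks = 4)

learner$param_set$values$n_blocks   # unchanged: 2
learner2$param_set$values$n_blocks  # 4
```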
References
Gorishniy Y, Rubachev I, Khrulkov V, Babenko A (2021). “Revisiting Deep Learning for Tabular Data.” arXiv, 2106.11959.
See Also
Other Learner: mlr_learners.mlp, mlr_learners.torch_featureless, mlr_learners_torch, mlr_learners_torch_image, mlr_learners_torch_model
Examples
library(mlr3)
library(mlr3torch)

# Define the Learner and set parameter values
learner = lrn("classif.tab_resnet")
learner$param_set$set_values(
epochs = 1, batch_size = 16, device = "cpu",
n_blocks = 2, d_block = 10, d_hidden = 20, dropout1 = 0.3, dropout2 = 0.3
)
# Define a Task
task = tsk("iris")
# Create train and test set
ids = partition(task)
# Train the learner on the training ids
learner$train(task, row_ids = ids$train)
# Make predictions for the test rows
predictions = learner$predict(task, row_ids = ids$test)
# Score the predictions
predictions$score()