TorchOptimizer {mlr3torch}    R Documentation

Torch Optimizer

Description

This wraps a torch::torch_optimizer_generator and annotates it with metadata, most importantly a ParamSet. The optimizer is created for the given parameter values by calling the $generate() method.

This class is usually used to configure the optimizer of a torch learner, e.g. when constructing a learner or in a ModelDescriptor.

For a list of available optimizers, see mlr3torch_optimizers. Items from this dictionary can be retrieved using t_opt().

Parameters

Defined by the constructor argument param_set. If no parameter set is provided during construction, the parameter set is constructed by creating a parameter for each argument of the wrapped optimizer, where the parameters are then of type ParamUty.
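The following sketch contrasts the inferred parameter set with an explicitly supplied one. The parameter set shown here is hand-written for illustration: its bounds, the "train" tags, and the id and label values are assumptions, not the definition used by the package's own dictionary entry for SGD.

```r
library(mlr3torch)
library(paradox)

# Without a param_set, each argument of optim_sgd (except the network
# parameters) becomes an untyped ParamUty.
opt_inferred = TorchOptimizer$new(torch_optimizer = optim_sgd)
opt_inferred$param_set

# Hypothetical typed parameter set covering only two of the arguments.
# The bounds and the "train" tag are illustrative assumptions.
param_set = ps(
  lr = p_dbl(lower = 0, tags = "train"),
  momentum = p_dbl(lower = 0, upper = 1, tags = "train")
)

opt_typed = TorchOptimizer$new(
  torch_optimizer = optim_sgd,
  param_set = param_set,
  id = "sgd_typed",
  label = "SGD with typed parameters"
)

# Values are now checked against the declared types and bounds.
opt_typed$param_set$set_values(lr = 0.01, momentum = 0.9)

# The configured values are passed on when the optimizer is instantiated.
net = nn_linear(10, 1)
opt = opt_typed$generate(net$parameters)
```

Supplying a typed parameter set is mainly useful when the values should be validated or tuned; the inferred ParamUty parameters accept any value.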

Super class

mlr3torch::TorchDescriptor -> TorchOptimizer

Methods

Public methods

Inherited methods

Method new()

Creates a new instance of this R6 class.

Usage
TorchOptimizer$new(
  torch_optimizer,
  param_set = NULL,
  id = NULL,
  label = NULL,
  packages = NULL,
  man = NULL
)
Arguments
torch_optimizer

(torch_optimizer_generator)
The torch optimizer.

param_set

(ParamSet or NULL)
The parameter set. If NULL (default) it is inferred from torch_optimizer.

id

(character(1))
The id of the new object.

label

(character(1))
Label for the new instance.

packages

(character())
The R packages this object depends on.

man

(character(1))
String in the format [pkg]::[topic] pointing to a manual page for this object. The referenced help page can be opened via method $help().


Method generate()

Instantiates the optimizer.

Usage
TorchOptimizer$generate(params)
Arguments
params

(named list() of torch_tensors)
The parameters of the network.

Returns

torch_optimizer
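As a minimal sketch of how the returned object might be used, the example below generates an optimizer for a small network and performs a single update step. The random data, the network size, and the choice of nnf_mse_loss() are assumptions made for illustration; in a torch learner this loop is handled internally.

```r
library(mlr3torch)

# Descriptor for SGD with a fixed learning rate
torch_opt = t_opt("sgd", lr = 0.1)

# Instantiate the optimizer for the parameters of a small network
net = nn_linear(10, 1)
opt = torch_opt$generate(net$parameters)

# Toy data (random, for illustration only)
x = torch_randn(16, 10)
y = torch_randn(16, 1)

# Keep a copy of the weights to compare after the update
w_before = net$weight$detach()$clone()

# One optimization step: reset gradients, compute loss, backpropagate, update
opt$zero_grad()
loss = nnf_mse_loss(net(x), y)
loss$backward()
opt$step()

# FALSE if the step changed the weights
torch_equal(w_before, net$weight$detach())
```

The descriptor itself holds no network state, so the same TorchOptimizer can be used to generate optimizers for several networks.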


Method clone()

The objects of this class are cloneable with this method.

Usage
TorchOptimizer$clone(deep = FALSE)
Arguments
deep

Whether to make a deep clone.

See Also

Other Torch Descriptor: TorchCallback, TorchDescriptor, TorchLoss, as_torch_callbacks(), as_torch_loss(), as_torch_optimizer(), mlr3torch_losses, mlr3torch_optimizers, t_clbk(), t_loss(), t_opt()

Examples


# Create a new torch optimizer
torch_opt = TorchOptimizer$new(optim_adam, label = "adam")
torch_opt
# If the param set is not specified, parameters are inferred but are of class ParamUty
torch_opt$param_set

# open the help page of the wrapped optimizer
# torch_opt$help()

# Retrieve an optimizer from the dictionary
torch_opt = t_opt("sgd", lr = 0.1)
torch_opt
torch_opt$param_set
torch_opt$label
torch_opt$id

# Create the optimizer for a network
net = nn_linear(10, 1)
opt = torch_opt$generate(net$parameters)

# is the same as
optim_sgd(net$parameters, lr = 0.1)

# Use in a learner
learner = lrn("regr.mlp", optimizer = t_opt("sgd"))
# The parameters of the optimizer are added to the learner's parameter set
learner$param_set


[Package mlr3torch version 0.1.0 Index]