| ModelDescriptor {mlr3torch} | R Documentation |
Represent a Model with Meta-Info
Description
Represents a model; possibly a complete model, possibly one in the process of being built up.
This model takes input tensors of shapes shapes_in and
pipes them through graph. Input shapes are mapped to the input channels of graph.
Output shapes are named after the output channels of graph. No-ops on tensors can also be
represented, in which case the names of input and output should be identical.
ModelDescriptor objects typically represent partial models that are still being built up. In that case, the pointer slot
identifies the point in graph (a PipeOp id together with one of its output channels) that produces a tensor of shape
pointer_shape, and at which the graph should be extended.
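The sketch below shows how pointer and pointer_shape move as a model is built up. It assumes the mlr3torch PipeOp keys "torch_ingress_num" and "nn_linear" and the "iris" task shipped with mlr3. The comments describe what we expect to see rather than verbatim printed output.

```r
library(mlr3)
library(mlr3pipelines)
library(mlr3torch)

task = tsk("iris")  # 4 numeric features, 3 classes

# Training the ingress operator alone yields a ModelDescriptor whose pointer
# refers to the ingress operator's output channel
md_in = po("torch_ingress_num")$train(list(task))[[1L]]
print(md_in$pointer)        # PipeOp id followed by an output channel name
print(md_in$pointer_shape)  # NA (batch dimension), then the number of features

# Appending a layer extends the graph at the pointer and moves the pointer
# to the output of the new layer
graph = po("torch_ingress_num") %>>% po("nn_linear", out_features = 3)
md_lin = graph$train(task)[[1L]]
print(md_lin$pointer)
print(md_lin$pointer_shape)  # last dimension now equals out_features
```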
The graph in this structure may be modified by reference in different parts of the code.
However, such modifications must never add edges whose destination is an element already in the Graph. In particular, no
element of graph$input may be removed by reference, e.g. by adding an edge whose destination is the input
channel of a PipeOp that previously had no parent.
In most cases, it is better to create a ModelDescriptor by training a Graph that consists (mostly) of the
operators PipeOpTorchIngress, PipeOpTorch, PipeOpTorchLoss, PipeOpTorchOptimizer, and
PipeOpTorchCallbacks.
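Here is a minimal sketch of this workflow. It assumes the mlr3torch PipeOp keys "torch_loss" and "torch_optimizer" together with the t_loss() and t_opt() helpers. The loss and optimizer operators add no layers to graph. They only fill the corresponding slots of the ModelDescriptor that passes through them.

```r
library(mlr3)
library(mlr3pipelines)
library(mlr3torch)

graph = po("torch_ingress_num") %>>%
  po("nn_linear", out_features = 3) %>>%
  # The next two operators do not extend the network; they store the loss
  # and optimizer configuration in the ModelDescriptor
  po("torch_loss", loss = t_loss("cross_entropy")) %>>%
  po("torch_optimizer", optimizer = t_opt("adam"))

md = graph$train(tsk("iris"))[[1L]]

md$loss       # loss configuration set by PipeOpTorchLoss
md$optimizer  # optimizer configuration set by PipeOpTorchOptimizer
md$pointer    # still points at the output of the linear layer
```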
A ModelDescriptor can be converted to an nn_graph via model_descriptor_to_module().
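The following sketch converts a partial ModelDescriptor into a torch module. It assumes that model_descriptor_to_module() accepts the descriptor plus a list of output pointers via an output_pointers argument. Here that list is the descriptor's own pointer, so the resulting nn_graph outputs the tensor produced by the linear layer.

```r
library(mlr3)
library(mlr3pipelines)
library(mlr3torch)

graph = po("torch_ingress_num") %>>% po("nn_linear", out_features = 3)
md = graph$train(tsk("iris"))[[1L]]

# Build an nn_graph whose single output is the tensor at md$pointer
net = model_descriptor_to_module(md, output_pointers = list(md$pointer))
class(net)
```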
Usage
ModelDescriptor(
graph,
ingress,
task,
optimizer = NULL,
loss = NULL,
callbacks = NULL,
pointer = NULL,
pointer_shape = NULL
)
Arguments
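graph | (Graph) Graph of PipeOpModule and PipeOpNOP operators through which the input tensors are piped.
ingress | (uniquely named list of TorchIngressToken) Inputs that go into graph; the names correspond to input channels of graph.
task | (Task) The (training) task for which the model is being built.
optimizer | (TorchOptimizer or NULL) Optional optimizer configuration, typically set by PipeOpTorchOptimizer.
loss | (TorchLoss or NULL) Optional loss configuration, typically set by PipeOpTorchLoss.
callbacks | (A list of callbacks or NULL) Optional callback configuration, typically set by PipeOpTorchCallbacks.
pointer | (character(2) or NULL) Output channel of graph at which the model is extended: element 1 is the PipeOp id, element 2 is the name of its output channel.
pointer_shape | (integer or NULL) Shape of the tensor produced at pointer.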
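To show how these arguments are populated in practice, the sketch below inspects the ingress and task slots of a ModelDescriptor produced by po("torch_ingress_num"). It assumes that a TorchIngressToken stores its inputs in fields named features and shape. The comments describe expected content, not verbatim output.

```r
library(mlr3)
library(mlr3pipelines)
library(mlr3torch)

md = po("torch_ingress_num")$train(list(tsk("iris")))[[1L]]

# ingress: a named list of TorchIngressToken objects, one per graph input
names(md$ingress)            # names of input channels of md$graph
md$ingress[[1L]]$features    # feature names fed into this input
md$ingress[[1L]]$shape       # input shape, NA for the batch dimension

# task: the task the descriptor was built for
md$task$id
```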
Value
(ModelDescriptor)
See Also
Other Model Configuration:
mlr_pipeops_torch_callbacks,
mlr_pipeops_torch_loss,
mlr_pipeops_torch_optimizer,
model_descriptor_union()
Other Graph Network:
TorchIngressToken(),
mlr_learners_torch_model,
mlr_pipeops_module,
mlr_pipeops_torch,
mlr_pipeops_torch_ingress,
mlr_pipeops_torch_ingress_categ,
mlr_pipeops_torch_ingress_ltnsr,
mlr_pipeops_torch_ingress_num,
model_descriptor_to_learner(),
model_descriptor_to_module(),
model_descriptor_union(),
nn_graph()