| nn_graph {mlr3torch} | R Documentation |
Graph Network
Description
Represents a neural network using a Graph that typically consists mostly of PipeOpModules.
Usage
nn_graph(graph, shapes_in, output_map = graph$output$name, list_output = FALSE)
Arguments
graph |
|
shapes_in |
(named |
output_map |
( |
list_output |
( |
Value
See Also
Other Graph Network:
ModelDescriptor(),
TorchIngressToken(),
mlr_learners_torch_model,
mlr_pipeops_module,
mlr_pipeops_torch,
mlr_pipeops_torch_ingress,
mlr_pipeops_torch_ingress_categ,
mlr_pipeops_torch_ingress_ltnsr,
mlr_pipeops_torch_ingress_num,
model_descriptor_to_learner(),
model_descriptor_to_module(),
model_descriptor_union()
Examples
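# Attach the packages that provide po(), nn_linear(), torch_randn() and nn_graph()
library(mlr3pipelines)
library(torch)
library(mlr3torch)

# Build a graph: linear layer (10 -> 20), ReLU, linear layer (20 -> 1)
graph = mlr3pipelines::Graph$new()
graph$add_pipeop(po("module_1", module = nn_linear(10, 20)), clone = FALSE)
graph$add_pipeop(po("module_2", module = nn_relu()), clone = FALSE)
graph$add_pipeop(po("module_3", module = nn_linear(20, 1)), clone = FALSE)
graph$add_edge("module_1", "module_2")
graph$add_edge("module_2", "module_3")
# Wrap the graph as a network; NA in shapes_in marks the variable batch dimension
network = nn_graph(graph, shapes_in = list(module_1.input = c(NA, 10)))
# Forward pass on a batch of 16 observations with 10 features each
x = torch_randn(16, 10)
network(module_1.input = x)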
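
# A hedged sketch of the remaining two arguments from Usage, reusing `graph`
# and `x` from above. It spells out output_map instead of relying on its
# default (graph$output$name) and sets list_output = TRUE.
# Assumption (not confirmed by this page): with list_output = TRUE the forward
# call returns a list with one tensor per entry of output_map.
# `network_list` and `out` are illustrative names only.
network_list = nn_graph(
  graph,
  shapes_in = list(module_1.input = c(NA, 10)),
  output_map = graph$output$name,
  list_output = TRUE
)
out = network_list(module_1.input = x)
# Inspect the container type and the shape of the first (and only) output
is.list(out)
out[[1]]$shape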
[Package mlr3torch version 0.1.0 Index]