task_dataset {mlr3torch}    R Documentation

Create a Dataset from a Task

Description

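Creates a torch dataset from an mlr3 Task. The resulting dataset's $.getbatch() method returns a list with elements x, y and index:

x: a named list of tensors, one for each ingress token in feature_ingress_tokens.

y: the target tensor created by target_batchgetter.

index: the row indices that were passed to $.getbatch().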

The data is returned on the device specified by the parameter device.
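A minimal sketch of inspecting one batch, assuming the mlr3, torch and mlr3torch packages are installed. It wraps all four iris features in a single ingress token under the hypothetical name features, encodes the class labels as a long tensor, and selects "cuda" when torch can see a GPU, falling back to "cpu" otherwise.

```r
library(mlr3)
library(torch)
library(mlr3torch)

task = tsk("iris")

# One ingress token covering all four numeric features; the list name
# "features" (chosen for illustration) becomes the name of the entry in batch$x.
features_ingress = TorchIngressToken(
  features = task$feature_names,
  batchgetter = batchgetter_num,
  shape = c(NA, 4)
)

# Class labels as integer codes in a long tensor on the requested device.
target_batchgetter = function(data, device) {
  torch_tensor(as.integer(data[[1L]]), dtype = torch_long(), device = device)
}

# Use the GPU when one is available, otherwise the CPU.
device = if (cuda_is_available()) "cuda" else "cpu"

dataset = task_dataset(task, list(features = features_ingress), target_batchgetter, device)
batch = dataset$.getbatch(1:3)

names(batch$x)          # one entry per ingress token
batch$x$features$shape  # rows in the batch x number of features
batch$y$device          # the device chosen above
```

Because every tensor is created by a batchgetter that receives device, the whole batch lives on that device and can be fed to a network placed there without further copying.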

Usage

task_dataset(task, feature_ingress_tokens, target_batchgetter = NULL, device)

Arguments

task

(Task)
The task for which to build the dataset.

feature_ingress_tokens

(named list() of TorchIngressToken)
Each ingress token defines one element of the $x value of a batch, stored under the token's name in the list.
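A minimal sketch of what a single ingress token contributes, assuming mlr3, torch and mlr3torch are installed. It calls batchgetter_num directly on the two sepal columns of iris; the variable names sepal_data and sepal_tensor are illustrative only.

```r
library(mlr3)
library(torch)
library(mlr3torch)

task = tsk("iris")

# The columns that an ingress token named "sepal" would cover, for 5 rows.
sepal_data = task$data(rows = 1:5, cols = c("Sepal.Length", "Sepal.Width"))

# batchgetter_num converts the numeric columns into a float tensor with one
# row per observation and one column per feature.
sepal_tensor = batchgetter_num(sepal_data, device = "cpu")
sepal_tensor$shape
```

A tensor of this form is what ends up in batch$x under the token's name, which is why the shape argument of TorchIngressToken is c(NA, 2): NA stands for the batch dimension.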

target_batchgetter

(function(data, device))
A function with the arguments data, a data.table containing only the target variable, and device. It must return the target as a torch tensor on the given device.
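A minimal sketch of a target batchgetter for a regression task, assuming mlr3 and torch are installed. The function name regr_target_getter is hypothetical. It returns the numeric target as a float tensor of shape (n, 1), which matches a network with a single output unit; for classification, integer class codes in a long tensor are the usual choice instead.

```r
library(mlr3)
library(torch)

# A regression task: the target mpg is numeric.
task = tsk("mtcars")

# Hypothetical target batchgetter for regression: float values reshaped
# from (n) to (n, 1) on the requested device.
regr_target_getter = function(data, device) {
  torch_tensor(data[[1L]], dtype = torch_float32(), device = device)$unsqueeze(2)
}

# data is a data.table holding only the target column.
target_data = task$data(rows = 1:5, cols = task$target_names)
y = regr_target_getter(target_data, device = "cpu")
y$shape
```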

device

(character())
The device, e.g. "cuda" or "cpu".

Value

torch::dataset
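Since the return value is a torch::dataset, it can be wrapped in a torch dataloader to iterate over mini-batches. A minimal sketch, assuming mlr3, torch and mlr3torch are installed; the token name features and the batch size of 16 are illustrative choices.

```r
library(mlr3)
library(torch)
library(mlr3torch)

task = tsk("iris")
ingress = TorchIngressToken(
  features = task$feature_names,
  batchgetter = batchgetter_num,
  shape = c(NA, 4)
)
# Class labels as integer codes in a long tensor.
target_batchgetter = function(data, device) {
  torch_tensor(as.integer(data[[1L]]), dtype = torch_long(), device = device)
}
dataset = task_dataset(task, list(features = ingress), target_batchgetter, "cpu")

# The dataset has one element per row of the task.
length(dataset)

# Mini-batches of 16 rows in random order; 150 rows give 10 batches,
# the last one holding the remaining 6 rows.
dl = dataloader(dataset, batch_size = 16, shuffle = TRUE)
length(dl)
```

To walk through the batches, the coro::loop(for (b in dl) { ... }) idiom from the torch documentation can be used.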

Examples


library(mlr3torch)

task = tsk("iris")

# Two ingress tokens: each turns its numeric columns into one float tensor
# of shape (batch size, 2), stored under its list name in batch$x.
sepal_ingress = TorchIngressToken(
  features = c("Sepal.Length", "Sepal.Width"),
  batchgetter = batchgetter_num,
  shape = c(NA, 2)
)
petal_ingress = TorchIngressToken(
  features = c("Petal.Length", "Petal.Width"),
  batchgetter = batchgetter_num,
  shape = c(NA, 2)
)
ingress_tokens = list(sepal = sepal_ingress, petal = petal_ingress)

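# iris is a classification task: pass the class labels (a factor) as
# integer codes in a long tensor on the requested device.
target_batchgetter = function(data, device) {
  torch_tensor(data = as.integer(data[[1L]]), dtype = torch_long(), device = device)
}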
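dataset = task_dataset(task, ingress_tokens, target_batchgetter, device = "cpu")
# Load rows 1 to 10: batch$x$sepal and batch$x$petal hold the features,
# batch$y holds the targets.
batch = dataset$.getbatch(1:10)
batch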


[Package mlr3torch version 0.1.0 Index]