task_dataset {mlr3torch} | R Documentation
Create a Dataset from a Task
Description
Creates a torch dataset from an mlr3 Task.
The resulting dataset's $.getbatch() method returns a list with elements x, y and .index:
- x is a list of tensors whose content is defined by the parameter feature_ingress_tokens.
- y is the target variable; its content is defined by the parameter target_batchgetter.
- .index is the index of the batch in the task's data.
The data is returned on the device specified by the parameter device.
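To make the role of feature_ingress_tokens more concrete, the sketch below shows what a single numeric ingress token contributes to x: the token's feature columns are turned into one float tensor with one row per observation, matching a shape of c(NA, 2) for two features. The function my_batchgetter_num is hypothetical and only written in the spirit of batchgetter_num; the actual mlr3torch implementation may differ. A plain data.frame stands in for the batch data the dataset would pass in.

```r
library(torch)

# Hypothetical feature batchgetter (not part of mlr3torch): converts the
# selected numeric columns into a float tensor of shape (rows, columns).
my_batchgetter_num = function(data, device, ...) {
  torch_tensor(as.matrix(data), dtype = torch_float32(), device = device)
}

# Stand-in for the data of a batch: three rows of the two sepal features.
sepal_data = iris[1:3, c("Sepal.Length", "Sepal.Width")]
x_sepal = my_batchgetter_num(sepal_data, device = "cpu")
print(x_sepal$shape)
```

Each ingress token yields one such tensor, so a task with two tokens produces an x list with two entries.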
Usage
task_dataset(task, feature_ingress_tokens, target_batchgetter = NULL, device)
Arguments
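task | (Task) The mlr3 task from which the dataset is built.
feature_ingress_tokens | (named list() of TorchIngressToken) Each token defines one tensor of x: the features it uses and the batchgetter that converts them into a tensor.
target_batchgetter | (function(data, device)) Converts the task's target column(s), passed as data, into the tensor y on device. Defaults to NULL.
device | (character(1)) The device on which the tensors are returned, e.g. "cpu".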
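The example below uses a target_batchgetter that returns a float tensor of shape (batch size, 1). For a classification target such as the Species factor of iris, an integer-coded target can be a more natural choice. The following is a minimal sketch under that assumption; target_batchgetter_classif is a hypothetical name, and whether a long tensor of class codes suits a given model or loss depends on how that model is set up. A data.frame stands in for the target data the dataset would pass in.

```r
library(torch)

# Hypothetical target batchgetter (not part of mlr3torch): encodes a factor
# target by its integer level codes (1, 2, ...) as a 1-D long tensor.
target_batchgetter_classif = function(data, device) {
  torch_tensor(as.integer(data[[1L]]), dtype = torch_long(), device = device)
}

# Stand-in for the target data of a batch: one row of each iris species.
target_data = data.frame(Species = iris$Species[c(1, 51, 101)])
y = target_batchgetter_classif(target_data, device = "cpu")
print(y)
```

Such a function can be passed as target_batchgetter in place of the float version used in the example.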
Value
A torch dataset.
Examples
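library(mlr3)
library(torch)
library(mlr3torch)

task = tsk("iris")

# One ingress token per input tensor: the sepal and the petal measurements
# each become a float tensor of shape (batch size, 2).
sepal_ingress = TorchIngressToken(
  features = c("Sepal.Length", "Sepal.Width"),
  batchgetter = batchgetter_num,
  shape = c(NA, 2)
)
petal_ingress = TorchIngressToken(
  features = c("Petal.Length", "Petal.Width"),
  batchgetter = batchgetter_num,
  shape = c(NA, 2)
)
# Named list of tokens: batch$x gets one tensor per token.
ingress_tokens = list(sepal = sepal_ingress, petal = petal_ingress)

# Converts the target (the Species factor, via its level codes) to a float
# tensor of shape (batch size, 1).
target_batchgetter = function(data, device) {
  torch_tensor(data = as.integer(data[[1L]]), dtype = torch_float32(), device = device)$unsqueeze(2)
}

dataset = task_dataset(task, ingress_tokens, target_batchgetter, "cpu")
batch = dataset$.getbatch(1:10)  # the first 10 rows of the task
batch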