tabnet_fit {tabnet} | R Documentation |
Tabnet model
Description
Fits the TabNet: Attentive Interpretable Tabular Learning model
Usage
tabnet_fit(x, ...)
## Default S3 method:
tabnet_fit(x, ...)
## S3 method for class 'data.frame'
tabnet_fit(
x,
y,
tabnet_model = NULL,
config = tabnet_config(),
...,
from_epoch = NULL,
weights = NULL
)
## S3 method for class 'formula'
tabnet_fit(
formula,
data,
tabnet_model = NULL,
config = tabnet_config(),
...,
from_epoch = NULL,
weights = NULL
)
## S3 method for class 'recipe'
tabnet_fit(
x,
data,
tabnet_model = NULL,
config = tabnet_config(),
...,
from_epoch = NULL,
weights = NULL
)
## S3 method for class 'Node'
tabnet_fit(
x,
tabnet_model = NULL,
config = tabnet_config(),
...,
from_epoch = NULL
)
Arguments
x |
Depending on the context:
The predictor data should be standardized (e.g. centered or scaled). The model handles categorical predictors internally, so you do not need to preprocess them. |
... |
Model hyperparameters.
Any hyperparameters set here will update those set by the config argument.
See tabnet_config() for the list of available hyperparameters. |
y |
When
|
tabnet_model |
A previously fitted TabNet model object to continue the fitting on.
if |
config |
A set of hyperparameters created using the tabnet_config() function. |
from_epoch |
When a |
weights |
Unused. |
formula |
A formula specifying the outcome terms on the left-hand side, and the predictor terms on the right-hand side. |
data |
When a recipe or formula is used,
|
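As a minimal sketch of how the config and ... arguments interact, the call below sets some hyperparameters in a config object and then overrides one of them directly. It assumes the tabnet and modeldata packages are installed; the object names config and fit are illustrative only.

```r
library(tabnet)

data("attrition", package = "modeldata")

# Hyperparameters collected once in a config object.
config <- tabnet_config(epochs = 5, valid_split = 0.2)

# `epochs = 1` passed through `...` updates the `epochs = 5` set in `config`;
# `valid_split` is not overridden, so it is still taken from `config`.
fit <- tabnet_fit(Attrition ~ ., data = attrition, config = config, epochs = 1)
```

Keeping shared settings in config and passing only the varying ones through ... avoids rebuilding the config for each fit.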
Value
A TabNet model object. It can be used for serialization, predictions, or further fitting.
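The sketch below illustrates two of these uses on the returned object: prediction and serialization. It assumes the tabnet and modeldata packages are installed, and that saveRDS() and readRDS() are a suitable way to serialize the model object; path and fit_restored are illustrative names.

```r
library(tabnet)

data("attrition", package = "modeldata")

fit <- tabnet_fit(Attrition ~ ., data = attrition, epochs = 1)

# Predictions: one row of output per row of new data.
preds <- predict(fit, attrition)
head(preds)

# Serialization (assumption: a plain RDS round trip is sufficient here).
path <- tempfile(fileext = ".rds")
saveRDS(fit, path)
fit_restored <- readRDS(path)
```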
Fitting a pre-trained model
When a parent tabnet_model is provided, model fitting resumes from that model's weights
at one of the following epochs:
the last fitted epoch, for a model already in torch context
the last model checkpoint epoch, for a model loaded from file
the epoch of the checkpoint matching or preceding the from_epoch value, if provided
The fitting metrics are appended to the parent model's metrics in the returned TabNet model.
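Resuming a fit can be sketched as follows, assuming the tabnet and modeldata packages are installed. The checkpoint_epochs hyperparameter is set to 1 here so that a checkpoint exists for every epoch; the names fit, fit_more and fit_from_1 are illustrative.

```r
library(tabnet)

data("attrition", package = "modeldata")

# Initial fit: 2 epochs, with a checkpoint saved at every epoch.
fit <- tabnet_fit(Attrition ~ ., data = attrition,
                  epochs = 2, checkpoint_epochs = 1)

# Continue fitting for 1 more epoch. The model is still in the torch
# context, so fitting resumes from its last fitted epoch.
fit_more <- tabnet_fit(Attrition ~ ., data = attrition,
                       tabnet_model = fit, epochs = 1)

# Resume instead from the checkpoint matching or preceding epoch 1.
fit_from_1 <- tabnet_fit(Attrition ~ ., data = attrition,
                         tabnet_model = fit, from_epoch = 1, epochs = 1)
```

In both resumed fits, the epochs value is the number of additional epochs to train, not the total.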
Multi-outcome
TabNet allows multi-outcome prediction, usually called multi-label classification or multi-output classification when the outcomes are categorical. Multi-outcome fitting currently expects the outcomes to be either all numeric or all categorical.
Threading
TabNet uses torch as its backend for computation, and torch uses all
available threads by default.
You can control the number of threads used by torch with:
torch::torch_set_num_threads(1)
torch::torch_set_num_interop_threads(1)
Examples
data("ames", package = "modeldata")
data("attrition", package = "modeldata")
ids <- sample(nrow(attrition), 256)
## Single-outcome regression using formula specification
fit <- tabnet_fit(Sale_Price ~ ., data = ames, epochs = 1)
## Single-outcome classification using data-frame specification
attrition_x <- attrition[,-which(names(attrition) == "Attrition")]
fit <- tabnet_fit(attrition_x, attrition$Attrition, epochs = 1, verbose = TRUE)
## Multi-outcome regression on `Sale_Price` and `Pool_Area` in `ames` dataset using formula
ames_fit <- tabnet_fit(Sale_Price + Pool_Area ~ ., data = ames[ids,], epochs = 2, valid_split = 0.2)
## Multi-label classification on `Attrition` and `JobSatisfaction` in
## `attrition` dataset using recipe
library(recipes)
rec <- recipe(Attrition + JobSatisfaction ~ ., data = attrition[ids,]) %>%
step_normalize(all_numeric(), -all_outcomes())
attrition_fit <- tabnet_fit(rec, data = attrition[ids,], epochs = 2, valid_split = 0.2)
## Hierarchical classification on `acme`
data(acme, package = "data.tree")
acme_fit <- tabnet_fit(acme, epochs = 2, verbose = TRUE)
# Note: the number of dataset rows and the number of model epochs should be
# increased for publication-level results.