tabnet_explain {tabnet}    R Documentation

Interpretation metrics from a TabNet model

Description

Computes interpretation metrics (feature-importance masks) for the observations in new_data from a fitted TabNet model.

Usage

tabnet_explain(object, new_data)

## Default S3 method:
tabnet_explain(object, new_data)

## S3 method for class 'tabnet_fit'
tabnet_explain(object, new_data)

## S3 method for class 'tabnet_pretrain'
tabnet_explain(object, new_data)

## S3 method for class 'model_fit'
tabnet_explain(object, new_data)

Arguments

object

a fitted TabNet model: an object of class tabnet_fit or tabnet_pretrain, or a parsnip model_fit object that wraps one.

new_data

a data.frame of observations for which to compute the interpretation metrics.

Value

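Returns a list with

M_explain

the aggregated feature-importance masks, as detailed in the TabNet paper.

masks

a list containing the masks for each step.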
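
For reference, the TabNet paper (Arik and Pfister, "TabNet: Attentive Interpretable Tabular Learning") aggregates the per-step masks by weighting step i with eta_b[i] = sum over decision dimensions c of ReLU(d_b,c[i]) and then normalizing over features, so that M_agg[b, j] = sum_i eta_b[i] M_b,j[i] / sum_j sum_i eta_b[i] M_b,j[i]. The base R sketch below applies that formula to toy inputs. The function aggregate_masks() and its arguments are hypothetical and not part of tabnet, and the package's own computation of M_explain may differ, for example in whether rows are normalized.

```r
# Hypothetical helper (not part of tabnet): aggregates per-step feature masks
# into one importance matrix, following the TabNet paper's definition.
#   masks: list of n x p numeric matrices, one per decision step
#   eta:   n x n_steps matrix of step weights, eta[b, i] = sum_c ReLU(d[b, c, i])
aggregate_masks <- function(masks, eta) {
  stopifnot(length(masks) == ncol(eta))
  # Weighted sum over steps: row b of step i's mask is scaled by eta[b, i]
  weighted <- Reduce(`+`, Map(function(m, i) m * eta[, i], masks, seq_along(masks)))
  # Normalize each row (observation) so its importances sum to 1;
  # a row whose weighted masks are all zero would yield NaN
  weighted / rowSums(weighted)
}

# Toy inputs: 2 observations, 3 features, 2 decision steps (made-up values)
feats <- c("x", "y", "z")
masks <- list(
  matrix(c(1.0, 0.0, 0,
           0.5, 0.5, 0), nrow = 2, byrow = TRUE, dimnames = list(NULL, feats)),
  matrix(c(0, 0, 1,
           0, 0, 1), nrow = 2, byrow = TRUE, dimnames = list(NULL, feats))
)
eta <- matrix(c(3, 1,
                1, 1), nrow = 2, byrow = TRUE)

agg <- aggregate_masks(masks, eta)
print(agg)
# row 1: 0.75 0.00 0.25; row 2: 0.25 0.25 0.50
```

Here the first observation leans on feature x because its first step, which selects only x, carries three times the weight of its second step.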

Examples
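
# Training requires an installed torch backend
library(tabnet)
set.seed(2021)

# Simulated data: three independent standard-normal predictors
n <- 1000
x <- data.frame(
  x = rnorm(n),
  y = rnorm(n),
  z = rnorm(n)
)

# The outcome is the predictor column x itself, so one would expect
# most of the attributed importance to fall on column x
y <- x$x

# A deliberately small TabNet so the example trains quickly
fit <- tabnet_fit(x, y, epochs = 20,
                  num_steps = 1,
                  batch_size = 512,
                  attention_width = 1,
                  num_shared = 1,
                  num_independent = 1)

# Interpretation metrics for the training data
ex <- tabnet_explain(fit, x)
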
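A common next step after the example above is to average the per-observation rows of M_explain into one global importance score per feature. The helper below is a hypothetical sketch, not a tabnet function; it only assumes that M_explain and each element of masks are numeric matrices or data frames with one column per predictor. It is demonstrated on a small stand-in list with made-up values so that it runs without torch.

```r
# Hypothetical helper (not part of tabnet): summarizes an explanation list
# with elements M_explain (n x p) and masks (list of n x p, one per step).
global_importance <- function(ex) {
  agg <- colMeans(as.matrix(ex$M_explain))   # mean importance per feature
  agg <- agg / sum(agg)                      # rescale so scores sum to 1
  # One column per step: mean mask value of each feature at that step
  per_step <- sapply(ex$masks, function(m) colMeans(as.matrix(m)))
  list(overall = sort(agg, decreasing = TRUE), per_step = per_step)
}

# Stand-in object shaped like a tabnet_explain() result (toy values, not model output)
toy <- list(
  M_explain = data.frame(x = c(0.9, 0.5), y = c(0.1, 0.1), z = c(0.0, 0.4)),
  masks = list(
    data.frame(x = c(1.0, 0.6), y = c(0.0, 0.4), z = c(0.0, 0.0)),
    data.frame(x = c(0.5, 0.5), y = c(0.0, 0.0), z = c(0.5, 0.5))
  )
)

imp <- global_importance(toy)
imp$overall    # x, then z, then y
imp$per_step   # 3 features x 2 steps
```

With the fitted model from the example, the same summary would be global_importance(ex).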
[Package tabnet version 0.6.0]