perm_importance {hstats}    R Documentation

Permutation Importance

Description

Calculates permutation importance for a set of features or a set of feature groups. By default, importance is calculated for all columns in X, except columns whose names are passed as the response y or as the case weights w.

Usage

perm_importance(object, ...)

## Default S3 method:
perm_importance(
  object,
  X,
  y,
  v = NULL,
  pred_fun = stats::predict,
  loss = "squared_error",
  m_rep = 4L,
  agg_cols = FALSE,
  normalize = FALSE,
  n_max = 10000L,
  w = NULL,
  verbose = TRUE,
  ...
)

## S3 method for class 'ranger'
perm_importance(
  object,
  X,
  y,
  v = NULL,
  pred_fun = function(m, X, ...) stats::predict(m, X, ...)$predictions,
  loss = "squared_error",
  m_rep = 4L,
  agg_cols = FALSE,
  normalize = FALSE,
  n_max = 10000L,
  w = NULL,
  verbose = TRUE,
  ...
)

## S3 method for class 'explainer'
perm_importance(
  object,
  X = object[["data"]],
  y = object[["y"]],
  v = NULL,
  pred_fun = object[["predict_function"]],
  loss = "squared_error",
  m_rep = 4L,
  agg_cols = FALSE,
  normalize = FALSE,
  n_max = 10000L,
  w = object[["weights"]],
  verbose = TRUE,
  ...
)

Arguments

object

Fitted model object.

...

Additional arguments passed to pred_fun(object, X, ...), for instance type = "response" in a glm() model, or reshape = TRUE in a multiclass XGBoost model.

X

A data.frame or matrix serving as the background dataset.

y

Vector/matrix of the response, or the corresponding column names in X.

v

Vector of feature names, or named list of feature groups. The default (NULL) will use all column names of X with the following exception: If y or w are passed as column names, they are dropped.
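When v contains a feature group, the natural reading is that the group's columns are shuffled together, so that relations inside the group are preserved while the link to the response is broken. The base-R sketch below illustrates that idea with one shared row permutation per group. The helper shuffle_group() is hypothetical and only assumes this behavior; it is not the hstats internals.

```r
# Hypothetical helper: shuffle all columns of a feature group with ONE shared
# row permutation, so the columns inside the group keep their joint distribution.
shuffle_group <- function(X, cols) {
  perm <- sample(nrow(X))                     # one permutation for the whole group
  X[, cols] <- X[perm, cols, drop = FALSE]
  X
}

set.seed(1)
v <- list(petal = c("Petal.Length", "Petal.Width"), species = "Species")
X_shuffled <- shuffle_group(iris, v$petal)

# The (Petal.Length, Petal.Width) pairs move together, row by row
orig_pairs <- paste(iris$Petal.Length, iris$Petal.Width)
new_pairs  <- paste(X_shuffled$Petal.Length, X_shuffled$Petal.Width)
```

Columns outside the group, such as Sepal.Length, stay untouched.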

pred_fun

Prediction function of the form function(object, X, ...), providing K ≥ 1 predictions per row. Its first argument represents the model object, its second argument a data structure like X. Additional arguments (such as type = "response" in a GLM, or reshape = TRUE in a multiclass XGBoost model) can be passed via the ... argument. The default, stats::predict(), will work in most cases.
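Instead of passing type = "response" through ..., the argument can also be fixed inside a custom pred_fun of the documented form function(object, X, ...). The sketch below does this for a logistic GLM so that predictions are probabilities. The binary response is_virginica is a made-up illustration derived from iris.

```r
# Hypothetical prediction wrapper in the form function(object, X, ...):
# returns probabilities instead of link-scale values of a logistic GLM.
pred_fun <- function(object, X, ...) {
  stats::predict(object, newdata = X, type = "response", ...)
}

# Illustrative binary response derived from iris
dat <- transform(iris, is_virginica = as.integer(Species == "virginica"))
fit <- glm(is_virginica ~ Sepal.Length + Sepal.Width, data = dat, family = binomial())

p <- pred_fun(fit, dat)   # one probability per row of dat
```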

loss

One of "squared_error", "logloss", "mlogloss", "poisson", "gamma", or "absolute_error". Alternatively, a loss function can be provided that turns observed and predicted values into a numeric vector or matrix of unit losses with one value (or row) per row of X. For "mlogloss", the response y can be either a dummy matrix or a discrete vector; the latter is handled via a fast version of model.matrix(~ as.factor(y) + 0). For "squared_error", the response can be a factor with levels in the column order of the predictions; in this case, squared error is evaluated for each one-hot-encoded column.
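The base-R sketch below shows how a discrete response becomes a dummy matrix with model.matrix(~ as.factor(y) + 0), and how a unit multi-log-loss per row follows from it. The probabilities are illustrative numbers, and the formula -sum_k y_k log(p_k) is the standard definition, assumed here rather than taken from the hstats source. The last line shows a custom loss in the documented shape (observed and predicted in, one unit loss per row out); my_loss is a hypothetical name.

```r
# Discrete response turned into a dummy (one-hot) matrix
y <- factor(c("a", "b", "c", "b"))
Y <- model.matrix(~ as.factor(y) + 0)   # 4 x 3 matrix of 0/1, columns a, b, c

# Illustrative predicted probabilities (rows sum to 1)
P <- rbind(
  c(0.7, 0.2, 0.1),
  c(0.1, 0.8, 0.1),
  c(0.2, 0.2, 0.6),
  c(0.3, 0.4, 0.3)
)

# Unit multi-log-loss per row: -sum_k y_k * log(p_k)
mlogloss_unit <- -rowSums(Y * log(P))

# Hypothetical custom loss: absolute error per row
my_loss <- function(actual, predicted) abs(actual - predicted)
```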

m_rep

Number of permutations (default 4).
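Repeating the shuffle m_rep times yields several importance values per feature, whose mean is reported and whose spread gives a standard error. The sketch below shows that arithmetic on made-up values. The formula SE = SD / sqrt(m_rep) is the usual standard error of a mean and is an assumption about how the reported standard errors relate to the standard deviations.

```r
# Made-up importance values from m_rep = 4 shuffles of one feature
imp_rep <- c(0.50, 0.54, 0.46, 0.50)
m_rep <- length(imp_rep)

imp_mean <- mean(imp_rep)              # reported importance
imp_sd   <- sd(imp_rep)                # spread across shuffles
imp_se   <- imp_sd / sqrt(m_rep)       # assumed standard-error formula
```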

agg_cols

Should multivariate losses be summed up? Default is FALSE. In combination with the squared error loss, agg_cols = TRUE gives the Brier score for (probabilistic) classification.
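To see why summing columns gives the Brier score, the sketch below computes squared errors between one-hot observed classes and predicted probabilities, once per class column (analogous to agg_cols = FALSE) and once summed over columns (analogous to agg_cols = TRUE). All numbers are illustrative.

```r
# One-hot observed classes: 3 observations, 3 classes
Y <- rbind(c(1, 0, 0), c(0, 1, 0), c(0, 0, 1))
# Illustrative predicted probabilities (rows sum to 1)
P <- rbind(c(0.8, 0.1, 0.1), c(0.2, 0.6, 0.2), c(0.3, 0.3, 0.4))

unit_sq   <- (Y - P)^2               # squared error per observation and class
per_class <- colMeans(unit_sq)       # one average loss per column
brier     <- mean(rowSums(unit_sq))  # summed over columns: Brier score, 0.84 / 3 = 0.28
```

Summing the per-class averages gives the same value as averaging the row sums.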

normalize

Should importance statistics be divided by average loss? Default is FALSE. If TRUE, an importance of 1 means that the average loss has been doubled by shuffling that feature's column.
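The arithmetic behind normalize can be checked with two illustrative average losses: dividing the increase by the unshuffled loss turns a doubling of the loss into an importance of 1.

```r
loss_base <- 2.0   # illustrative average loss on unshuffled data
loss_perm <- 4.0   # illustrative average loss after shuffling one feature

imp_raw  <- loss_perm - loss_base                 # normalize = FALSE
imp_norm <- (loss_perm - loss_base) / loss_base   # normalize = TRUE
```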

n_max

If X has more than n_max rows, a random sample of n_max rows is selected from X. In this case, set a random seed for reproducibility.
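A sketch of the subsampling step, assuming that one random row index is drawn and applied to X and, consistently, to the response (and to case weights, if any). sample_index() is a hypothetical helper; setting the seed first makes the selection reproducible.

```r
# Hypothetical helper: draw at most n_max row indices at random
sample_index <- function(n, n_max) {
  if (n > n_max) sample(n, n_max) else seq_len(n)
}

set.seed(42)
ix <- sample_index(nrow(iris), n_max = 50)
X_small <- iris[ix, ]
y_small <- iris$Sepal.Length[ix]   # the response must use the same index
```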

w

Optional vector of case weights. Can also be a column name of X.
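With case weights, the natural assumption is that the average loss becomes a weighted mean of the unit losses. The short sketch below contrasts both averages on illustrative numbers.

```r
unit_loss <- c(1, 4, 9)   # illustrative unit losses of three rows
w <- c(1, 1, 2)           # case weights

avg_unweighted <- mean(unit_loss)               # 14 / 3
avg_weighted   <- weighted.mean(unit_loss, w)   # (1 + 4 + 18) / 4 = 5.75
```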

verbose

Should a progress bar be shown? The default is TRUE.

Details

The permutation importance of a feature is defined as the increase in the average loss when shuffling the corresponding feature values before calculating predictions. By default, the process is repeated m_rep = 4 times, and the results are averaged. In most cases, importance values should be derived from an independent test data set. Set normalize = TRUE to get relative increases in average loss.
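The definition above translates almost line by line into base R. The sketch below is a simplified illustration, not the hstats implementation: it uses only the squared error, ignores case weights, normalization, and subsampling, and calls the function perm_imp_sketch(), a hypothetical name. For brevity it evaluates on the training data, although independent test data is preferable.

```r
# Simplified sketch of permutation importance (squared error only)
perm_imp_sketch <- function(fit, X, y, v, m_rep = 4L,
                            pred_fun = function(m, X) stats::predict(m, X)) {
  sq_loss <- function(y, p) mean((y - p)^2)   # average squared error
  base <- sq_loss(y, pred_fun(fit, X))        # loss before shuffling
  sapply(v, function(z) {
    shuffled <- replicate(m_rep, {
      X_perm <- X
      # Shuffle the feature (or all columns of a group) with one permutation
      X_perm[, z] <- X[sample(nrow(X)), z, drop = FALSE]
      sq_loss(y, pred_fun(fit, X_perm))
    })
    mean(shuffled) - base                     # increase in average loss
  })
}

set.seed(1)
fit <- lm(Sepal.Length ~ ., data = iris)
imp <- perm_imp_sketch(
  fit, X = iris, y = iris$Sepal.Length,
  v = c("Sepal.Width", "Petal.Length", "Petal.Width", "Species")
)
```

Dividing each value by base would give the normalized version.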

Value

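An object of class "hstats_matrix". Its elements include M, the matrix of importance values, and SE, the matrix of their standard errors (available thanks to the m_rep repeated shuffles); see s$M and s$SE in the examples.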

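Methods (by class)

perm_importance(default): Default method.

perm_importance(ranger): Method for "ranger" models. Its default pred_fun extracts the $predictions element returned by predict().

perm_importance(explainer): Method for "explainer" objects (e.g., from DALEX). X, y, pred_fun, and w default to the explainer's "data", "y", "predict_function", and "weights" elements.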

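Losses

The default loss is "squared_error". Other choices:

"absolute_error": Absolute error, the loss of median regression.

"poisson": Unit Poisson deviance, the loss of Poisson regression.

"gamma": Unit gamma deviance, the loss of Gamma regression.

"logloss": Log loss for binary responses and predicted probabilities of a "1".

"mlogloss": Multi-class log loss; y can be a dummy matrix or a discrete vector.

Alternatively, a function that turns observed and predicted values into unit losses.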
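For reference, the sketch below writes the built-in choices as unit-loss functions using their standard textbook definitions. These formulas are assumptions for illustration; hstats may differ in details such as scaling or input checks.

```r
# Illustrative unit-loss functions (standard definitions, assumed here)
unit_losses <- list(
  squared_error  = function(y, p) (y - p)^2,
  absolute_error = function(y, p) abs(y - p),
  # Unit Poisson deviance; y >= 0, p > 0, with y * log(y / p) set to 0 when y == 0
  poisson = function(y, p) 2 * (ifelse(y > 0, y * log(y / p), 0) - (y - p)),
  # Unit gamma deviance; y > 0, p > 0
  gamma = function(y, p) 2 * ((y - p) / p - log(y / p)),
  # Binary log loss; y in {0, 1}, p = predicted probability of a 1
  logloss = function(y, p) -(y * log(p) + (1 - y) * log(1 - p))
)
```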

References

Fisher, A., Rudin, C., and Dominici, F. (2018). All Models are Wrong but many are Useful: Variable Importance for Black-Box, Proprietary, or Misspecified Prediction Models, using Model Class Reliance. arXiv.

Examples

library(hstats)

# MODEL 1: Linear regression
fit <- lm(Sepal.Length ~ ., data = iris)
s <- perm_importance(fit, X = iris, y = "Sepal.Length")

s
s$M
s$SE  # Standard errors are available thanks to repeated shuffling
plot(s)
plot(s, err_type = "SD")  # Standard deviations instead of standard errors

# Groups of features can be passed as named list
v <- list(petal = c("Petal.Length", "Petal.Width"), species = "Species")
s <- perm_importance(fit, X = iris, y = "Sepal.Length", v = v, verbose = FALSE)
s
plot(s)

# MODEL 2: Multi-response linear regression
fit <- lm(as.matrix(iris[, 1:2]) ~ Petal.Length + Petal.Width + Species, data = iris)
s <- perm_importance(fit, X = iris[, 3:5], y = iris[, 1:2], normalize = TRUE)
s
plot(s)
plot(s, swap_dim = TRUE, top_m = 2)

[Package hstats version 1.2.0 Index]