Interpreter {distillML} | R Documentation
Interpreter class description
Description
A wrapper class based on a Predictor object for examining the predictions of the model with respect to one or two features. The two methods for interpreting a model through one or two features are partial dependence plots (PDP), which average the model's predictions over the marginal distribution of the remaining features, and accumulated local effects (ALE) functions, which average the model's predictions over the conditional distribution of the remaining features.
The only required argument is the Predictor object. The other arguments are optional, but it may be useful to specify the number of samples or the specific data points (data.points) if the training data is very large, since this can greatly reduce the computation time.
As output, the constructor returns an Interpreter object with two lists of functions: one for interpreting a single feature's role in the black-box model, and the other for interpreting a pair of features' role in the black-box model. These interpretability functions are built for each possible feature (or pair of features). Each function returns a vector of averaged predictions equal in length to the number of values (or number of rows) passed into it.
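For instance, a minimal sketch of how one of these functions might be queried (this assumes an Interpreter object named interpreter has already been built, and that the pdp.1d list is indexed by feature name):

# Three Petal.Width values in, three averaged predictions out
vals <- c(0.5, 1.5, 2.5)
avg.preds <- interpreter$pdp.1d[["Petal.Width"]](vals)
length(avg.preds)  # equal to length(vals)

The two-feature functions behave analogously, returning one averaged prediction per row of the values passed in.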
Public fields
predictor
The Predictor object that contains the model that the user wants to query. This is the only parameter that is required to initialize an Interpreter object.
features
An optional list of single features for which to create PDP functions. All entries must match column names from the 'data' parameter of the Predictor object.
features.2d
A two-column data frame containing pairs of feature names for which to create 2-D PDP functions. All entries in the data frame must match column names from the 'data' parameter of the Predictor object.
data.points
A vector of indices of observations in the training data frame to use when creating the PDP/ICE/ALE plots. When the training data is large, passing only a downsampled subset of the training data to the plot construction can greatly reduce the required computation. Alternatively, if one is only interested in understanding the model predictions for a specific subgroup, the indices of the observations in that subgroup can be passed here (a short example follows this list of fields).
pdp.1d
A list of functions giving single-feature PDP interpretations of the model.
pdp.2d
A list of functions giving two-feature PDP interpretations of the model.
feat.class
A vector that contains the class of each feature (categorical or continuous).
center.at
The value(s) at which to center the feature plots. A list with length equal to the number of features.
grid.points
A list of vectors containing the grid points to use for the predictions for PDP and ICE plots. For ALE plots, we use quantile-based methods that depend on the distribution of the training data.
grid.size
The number of grid points to plot for a continuous feature. This parameter sets the number of grid points for PDP, ICE, and ALE plots.
saved
A list that caches the previous calculations for the 1-D ICE plots, 1-D PDP plots, 2-D PDP plots, and grid points for building the distilled model. This saves the uncentered calculations.
ale.grid
A list that caches the saved predictions for the ALE plots.
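As a hedged sketch of the subgroup use of data.points described above (this assumes the forest_predictor object and the train_reg training data from the Examples section at the end of this page):

# Restrict the interpretability calculations to the virginica observations;
# the indices refer to rows of the Predictor's training data
subgroup.idx <- which(train_reg$Species == "virginica")
subgroup.interpret <- Interpreter$new(predictor = forest_predictor,
                                      data.points = subgroup.idx)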
Methods
Public methods
Method new()
Usage
Interpreter$new(
  predictor = NULL,
  samples = 1000,
  data.points = NULL,
  grid.size = 50
)
Arguments
predictor
The Predictor object that contains the model that the user wants to query. This is the only parameter that is required to initialize an Interpreter object.
samples
The number of observations used for the interpretability method. If no number is given, the default is the minimum of 1000 and the number of rows in the training data set. Rows with missing values are excluded from sampling.
data.points
The indices of the data points used for the PDP/ALE plots. When supplied, this overrides the "samples" parameter above.
grid.size
The number of grid points used for the PDP, ICE, and ALE plots of each feature (see the example below).
Returns
An 'Interpreter' object.
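A minimal sketch of the constructor with its optional arguments (the values are purely illustrative, and forest_predictor is the Predictor object built in the Examples section below):

# Downsample to 500 training observations and use a coarser grid of 25 points
interp <- Interpreter$new(predictor = forest_predictor,
                          samples = 500,
                          grid.size = 25)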
Method clone()
The objects of this class are cloneable with this method.
Usage
Interpreter$clone(deep = FALSE)
Arguments
deep
Whether to make a deep clone.
Note
This class wraps a Predictor object so that different interpretability methods can be applied to the underlying model. For usage examples, please refer to the README document.
Examples
library(distillML)
library(Rforestry)
set.seed(491)
data <- iris
# Hold out roughly 20% of the rows as a test set
test_ind <- sample(1:nrow(data), nrow(data) %/% 5)
train_reg <- data[-test_ind, ]
test_reg <- data[test_ind, ]
# Fit a random forest on the training split, with Sepal.Length (column 1)
# as the response
forest <- forestry(x = train_reg[, -1],
                   y = train_reg[, 1])
# Wrap the fitted model in a Predictor object
forest_predictor <- Predictor$new(model = forest, data = train_reg,
                                  y = "Sepal.Length", task = "regression")

# Build the Interpreter from the Predictor (all other arguments at defaults)
forest_interpret <- Interpreter$new(predictor = forest_predictor)