ALE {cito}    R Documentation
Accumulated Local Effect Plot (ALE)
Description
Computes an Accumulated Local Effect (ALE) plot for one or more features.
Usage
ALE(
model,
variable = NULL,
data = NULL,
K = 10,
ALE_type = c("equidistant", "quantile"),
plot = TRUE,
parallel = FALSE,
...
)
## S3 method for class 'citodnn'
ALE(
model,
variable = NULL,
data = NULL,
K = 10,
ALE_type = c("equidistant", "quantile"),
plot = TRUE,
parallel = FALSE,
...
)
## S3 method for class 'citodnnBootstrap'
ALE(
model,
variable = NULL,
data = NULL,
K = 10,
ALE_type = c("equidistant", "quantile"),
plot = TRUE,
parallel = FALSE,
...
)
Arguments
model | a model created by
variable | name of the variable (as a string) for which the ALE should be computed
data | data on which the ALE is computed; if NULL, the training data will be used.
K | number of neighborhoods into which the original feature space is divided
ALE_type | method used to divide the feature space into neighborhoods: "equidistant" (neighborhoods of equal width) or "quantile" (neighborhoods containing roughly equal numbers of observations).
plot | whether to plot the ALE
parallel | whether to parallelize over the bootstrap models
... | arguments passed to
Value
A list of plots made with 'ggplot2' consisting of an individual plot for each defined variable.
Explanation
Accumulated Local Effect (ALE) plots quantify how the model's predictions change when a feature changes. They are similar to partial dependence plots (PDPs) but are more robust to correlated (collinear) features.
Mathematical details
If the defined variable is a numeric feature, the ALE is computed. Here, the uncentered effect of feature j, with the feature range divided into K neighborhoods, is defined as:
\hat{\tilde{f}}_{j,ALE}(x)=\sum_{k=1}^{k_j(x)}\frac{1}{n_j(k)}\sum_{i:x_{j}^{(i)}\in{}N_j(k)}\left[\hat{f}(z_{k,j},x^{(i)}_{\setminus{}j})-\hat{f}(z_{k-1,j},x^{(i)}_{\setminus{}j})\right]
where N_j(k) is the k-th neighborhood, n_j(k) is the number of observations in the k-th neighborhood, and k_j(x) is the index of the neighborhood that contains x.
The last part of the equation,
\left[\hat{f}(z_{k,j},x^{(i)}_{\setminus{}j})-\hat{f}(z_{k-1,j},x^{(i)}_{\setminus{}j})\right]
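represents the difference in the model prediction when the value of feature j for observation i is replaced by the upper border z_{k,j} and by the lower border z_{k-1,j} of the current neighborhood, while all other features x^{(i)}_{\setminus j} keep their observed values.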
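The uncentered formula can be illustrated with a minimal base-R sketch. This is not cito's internal implementation: the function name `ale_uncentered()`, the argument `predict_fun` (any function mapping a data frame to a numeric vector of predictions), the assignment of observations to half-open neighborhoods (z_{k-1}, z_k], and giving empty neighborhoods a local effect of 0 are all illustrative assumptions. It returns the accumulated effect evaluated at each neighborhood border.

```r
# Minimal sketch of the uncentered ALE for one numeric feature.
# ale_uncentered() and predict_fun are hypothetical names used for illustration;
# this is an assumption-based sketch, not cito's implementation.
ale_uncentered <- function(predict_fun, data, feature, K = 10,
                           type = c("equidistant", "quantile")) {
  type <- match.arg(type)
  x <- data[[feature]]

  # Neighborhood borders z_0 < z_1 < ... < z_K
  z <- if (type == "equidistant") {
    seq(min(x), max(x), length.out = K + 1)
  } else {
    unique(quantile(x, probs = seq(0, 1, length.out = K + 1), names = FALSE))
  }
  K <- length(z) - 1  # quantile borders may collapse when values are tied

  # Index k of the neighborhood (z_{k-1}, z_k] containing each observation;
  # all.inside = TRUE puts the minimum into the first neighborhood
  k_idx <- findInterval(x, z, left.open = TRUE, all.inside = TRUE)

  # Local effect per observation: f(z_k, x_\j) - f(z_{k-1}, x_\j)
  upper <- data
  upper[[feature]] <- z[k_idx + 1]
  lower <- data
  lower[[feature]] <- z[k_idx]
  diffs <- predict_fun(upper) - predict_fun(lower)

  # Average within each neighborhood: (1 / n_j(k)) * sum over i in N_j(k)
  n_k <- tabulate(k_idx, nbins = K)
  sum_k <- vapply(seq_len(K), function(k) sum(diffs[k_idx == k]), numeric(1))
  mean_k <- ifelse(n_k > 0, sum_k / n_k, 0)  # assumption: empty neighborhood -> 0

  # Accumulate over k = 1, ..., k_j(x); the effect at the lowest border is 0
  data.frame(border = z, ale = c(0, cumsum(mean_k)))
}
```

For a model that is linear in the feature, for example f(x) = 2 * x1 + x2, every local difference equals 2 times the neighborhood width, so the accumulated effect at each border reduces to 2 * (border - min(x1)), independent of x2.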
Examples
if (torch::torch_is_installed()) {
  library(cito)

  # Build and train a network predicting Sepal.Length from all other iris columns
  nn.fit <- dnn(Sepal.Length ~ ., data = datasets::iris)

  # ALE for a single numeric feature
  ALE(nn.fit, variable = "Petal.Length")
}