ale {ale}    R Documentation
Create and return ALE data, statistics, and plots
Description
ale() is the central function that manages the creation of ALE data and plots for one-way ALE. For two-way interactions, see ale_ixn(). This function calls ale_core (a non-exported function), which manages the ALE data and plot creation in detail. For details, see the introductory vignette for this package or the details and examples below.
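To make concrete what the ALE data returned by ale() represent, the following base R sketch computes a one-way ALE curve for a single numeric column using the general accumulated local effects recipe: split x into quantile-based intervals, average the change in predictions across each interval, accumulate those averages, and centre the result. It is an illustration only, not the {ale} package's implementation: ale_sketch() and its arguments are hypothetical, it handles only numeric x columns without bootstrapping, and it centres on the row-weighted mean (as in Apley and Zhu's original formulation) instead of applying relative_y.

```r
# Hypothetical one-way ALE sketch for one numeric column (not the {ale} internals).
ale_sketch <- function(data, model, x_col, n_intervals = 10,
                       pred_fun = function(object, newdata) {
                         stats::predict(object, newdata = newdata)
                       }) {
  x <- data[[x_col]]
  # Interval boundaries at quantiles of x, so intervals hold similar numbers of rows
  breaks <- unique(stats::quantile(x, probs = seq(0, 1, length.out = n_intervals + 1),
                                   names = FALSE))
  k <- length(breaks) - 1
  # Interval index for each row; include.lowest puts min(x) in interval 1
  bin <- cut(x, breaks = breaks, include.lowest = TRUE, labels = FALSE)
  # Copies of the data with x moved to the lower and upper edge of each row's interval
  lo_data <- data
  hi_data <- data
  lo_data[[x_col]] <- breaks[bin]
  hi_data[[x_col]] <- breaks[bin + 1]
  # Local effect of each row: prediction change across its interval, other columns fixed
  diffs <- pred_fun(model, hi_data) - pred_fun(model, lo_data)
  n <- tabulate(bin, nbins = k)
  mean_diff <- as.numeric(tapply(diffs, factor(bin, levels = seq_len(k)), mean))
  mean_diff[is.na(mean_diff)] <- 0  # empty intervals contribute no effect
  # Accumulate local effects across intervals (uncentred ALE at each boundary)
  ale_y <- c(0, cumsum(mean_diff))
  # Centre so that the row-weighted average effect is zero
  interval_mid <- (ale_y[-1] + ale_y[-(k + 1)]) / 2
  ale_y <- ale_y - sum(n * interval_mid) / sum(n)
  data.frame(ale_x = breaks, ale_y = ale_y)
}
```

For a linear model the sketch yields a straight line whose slope equals the coefficient, which makes a quick sanity check, e.g. ale_sketch(mtcars, lm(mpg ~ wt + hp, data = mtcars), "wt").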
Usage
ale(
  data,
  model,
  x_cols = NULL,
  y_col = NULL,
  ...,
  parallel = parallel::detectCores(logical = FALSE) - 1,
  model_packages = as.character(NA),
  output = c("plots", "data", "stats", "conf_regions"),
  pred_fun = function(object, newdata, type = pred_type) {
    stats::predict(object = object, newdata = newdata, type = type)
  },
  pred_type = "response",
  p_values = NULL,
  p_alpha = c(0.01, 0.05),
  x_intervals = 100,
  boot_it = 0,
  seed = 0,
  boot_alpha = 0.05,
  boot_centre = "mean",
  relative_y = "median",
  y_type = NULL,
  median_band_pct = c(0.05, 0.5),
  rug_sample_size = 500,
  min_rug_per_interval = 1,
  ale_xs = NULL,
  ale_ns = NULL,
  compact_plots = FALSE,
  silent = FALSE
)
Arguments
data |
dataframe. Dataset from which to create predictions for the ALE. |
model |
model object. Model for which ALE should be calculated. May be any kind of R object that can make predictions from data. |
x_cols |
character. Vector of column names from |
y_col |
character length 1. Name of the outcome target label (y) variable. If not provided, |
... |
not used. Inserted to require explicit naming of subsequent arguments. |
parallel |
non-negative integer length 1. Number of parallel threads (workers or tasks) for parallel execution of the function. See the Parallel processing section below. |
model_packages |
character. Character vector of names of packages that |
output |
character in c('plots', 'data', 'stats', 'conf_regions'). Vector of types of results to return. 'plots' will return ALE plots; 'data' will return the source ALE data; 'stats' will return ALE statistics; 'conf_regions' will return summaries of the confidence regions of the ALE statistics. Each option must be listed to return the specified component. By default, all are returned. |
pred_fun, pred_type |
function, character length 1. See the Custom predict function section below. |
p_values |
instructions for calculating p-values and for determining the median band. If |
p_alpha |
numeric length 2 from 0 to 1. Alpha for "confidence interval" ranges for printing bands around the median for single-variable plots. These are the default values used if |
x_intervals |
positive integer length 1. Maximum number of intervals on the x-axis for the ALE data for each column in |
boot_it |
non-negative integer length 1. Number of bootstrap iterations for the ALE values. If |
seed |
integer length 1. Random seed. Supply the same seed between runs to ensure that identical random ALE data are generated each time. |
boot_alpha |
numeric length 1 from 0 to 1. Alpha for the percentile-based confidence interval range for the bootstrap intervals; the bootstrap confidence intervals will be the lowest and highest |
boot_centre |
character length 1 in c('mean', 'median'). When bootstrapping, the main estimate for |
relative_y |
character length 1 in c('median', 'mean', 'zero'). The ale_y values will be adjusted relative to this value. 'median' is the default. 'zero' will maintain the default of |
y_type |
character length 1. Datatype of the y (outcome) variable. Must be one of c('binary', 'numeric', 'multinomial', 'ordinal'). Normally determined automatically; only provide for complex non-standard models that require it. |
median_band_pct |
numeric length 2 from 0 to 1. Alpha for "confidence interval" ranges for printing bands around the median for single-variable plots. These are the default values used if |
rug_sample_size, min_rug_per_interval |
non-negative integer length 1. Rug plots are normally down-sampled; otherwise they are too slow. |
ale_xs, ale_ns |
list of ale_x and ale_n vectors. If provided, these vectors will be used to set the intervals of the ALE x axis for each variable. By default (NULL), the function automatically calculates the ale_x intervals. |
compact_plots |
logical length 1, default |
silent |
logical length 1, default |
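The boot_it, boot_alpha, boot_centre, and seed arguments describe a percentile bootstrap. The following generic base R sketch shows that idea for a single statistic: resample rows with replacement boot_it times, summarize the bootstrap estimates by their mean or median, and take percentile bounds. It is not the {ale} package's code; boot_summary() is a hypothetical helper, and the alpha/2 and 1 - alpha/2 percentile bounds are an assumption, since the boot_alpha description above is incomplete.

```r
# Generic percentile-bootstrap sketch (hypothetical helper, not the {ale} API).
# Assumes the interval runs from the alpha/2 to the 1 - alpha/2 percentile.
boot_summary <- function(data, stat_fun, boot_it = 100, boot_alpha = 0.05,
                         boot_centre = c("mean", "median"), seed = 0) {
  boot_centre <- match.arg(boot_centre)
  set.seed(seed)  # the same seed gives identical resamples between runs
  estimates <- vapply(seq_len(boot_it), function(i) {
    rows <- sample(nrow(data), replace = TRUE)  # resample rows with replacement
    stat_fun(data[rows, , drop = FALSE])
  }, numeric(1))
  # Main estimate: mean or median of the bootstrap distribution
  centre <- if (boot_centre == "mean") mean(estimates) else stats::median(estimates)
  bounds <- stats::quantile(estimates, probs = c(boot_alpha / 2, 1 - boot_alpha / 2),
                            names = FALSE)
  c(estimate = centre, conf.low = bounds[1], conf.high = bounds[2])
}
```

For example, boot_summary(mtcars, function(d) mean(d$mpg), boot_it = 200, boot_centre = "median") returns a named vector whose names mirror the estimate, conf.low, and conf.high columns described under Value.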
Details
The sections after Value describe custom prediction functions, ALE statistics, parallel processing, and progress bars.
Value
list with the following elements:
- data: a list whose elements, named by each requested x variable, are each a tibble with the following columns:
  - ale_x: the values of each of the ALE x intervals or categories.
  - ale_n: the number of rows of data in each ale_x interval or category.
  - ale_y: the ALE function value calculated for that interval or category. For bootstrapped ALE, this is the same as ale_y_mean by default or ale_y_median if the boot_centre = 'median' argument is specified. Regardless, both ale_y_mean and ale_y_median are returned as columns here.
  - ale_y_lo, ale_y_hi: the lower and upper bounds, respectively, of the confidence interval for the bootstrapped ale_y value.
  Note: regardless of what options are requested in the output argument, this data element is always returned.
- stats: if stats are requested in the output argument (as is the default), returns a list. If not requested, returns NULL. The returned list provides ALE statistics of the data element duplicated and presented from various perspectives in the following elements:
  - by_term: a list named by each requested x variable, each of whose elements is a tibble with the following columns:
    - statistic: the ALE statistic specified in the row (see the by_statistic element below).
    - estimate: the bootstrapped mean or median of the statistic, depending on the boot_centre argument to the ale() function. Regardless, both mean and median are returned as columns here.
    - conf.low, conf.high: the lower and upper bounds, respectively, of the confidence interval for the bootstrapped estimate.
  - by_statistic: a list named by each of the following ALE statistics: aled, aler_min, aler_max, naled, naler_min, naler_max. See vignette('ale-statistics') for details.
  - estimate: a tibble whose data consists of the estimate values from the by_term element above. The columns are term (the variable name) and the statistic for which the estimate is given: aled, aler_min, aler_max, naled, naler_min, naler_max.
  - effects_plot: a ggplot object which is the ALE effects plot for all the x variables.
- plots: if plots are requested in the output argument (as is the default), returns a list whose elements, named by each requested x variable, are each a ggplot object of the ALE y values plotted against the x variable intervals. If plots is not included in output, this element is NULL.
- conf_regions: if conf_regions are requested in the output argument (as is the default), returns a list. If not requested, returns NULL. The returned list provides summaries of the confidence regions of the relevant ALE statistics of the data element. The list has the following elements:
  - by_term: a list named by each requested x variable, each of whose elements is a tibble with the relevant data for the confidence regions. (See vignette('ale-statistics') for details about confidence regions.)
  - significant: a tibble that summarizes by_term to show only confidence regions that are statistically significant. Its columns are those from by_term plus a term column to specify which x variable is indicated by the respective row.
  - sig_criterion: a length-one character vector that reports which values were used to determine statistical significance: if p_values was provided to the ale() function, it will be used; otherwise, median_band_pct will be used.
- Various values echoed from the original call to the ale() function, provided to document the key elements used to calculate the ALE data, statistics, and plots: y_col, x_cols, boot_it, seed, boot_alpha, boot_centre, relative_y, y_type, median_band_pct, rug_sample_size. These are either the values provided by the user or used by default if the user did not change them.
- y_summary: summary statistics of the y values used for the ALE calculation. These statistics are based on the actual values of y_col unless y_type is a probability or other value that is constrained to the [0, 1] range. In that case, y_summary is based on the predicted values of y_col obtained by applying model to the data. y_summary is a named numeric vector. Most of the elements are percentiles of the y values; e.g., the '5%' element is the 5th percentile of y values. The following elements have special meanings:
  - The first element is named either p or q and its value is always 0. The value is not used; only the name of the element is meaningful. p means that the following special y_summary elements are based on the provided p_values object. q means that quantiles were calculated based on median_band_pct because p_values was not provided.
  - min, mean, max: the minimum, mean, and maximum y values, respectively. Note that the median is 50%, the 50th percentile.
  - med_lo_2, med_lo, med_hi, med_hi_2: med_lo and med_hi are the inner lower and upper confidence intervals of y values with respect to the median (50%); med_lo_2 and med_hi_2 are the outer confidence intervals. See the documentation for the p_alpha and median_band_pct arguments to understand how these are determined.
Custom predict function
The calculation of ALE requires modifying several values of the original data. Thus, ale() needs direct access to a predict function that works on model. By default, ale() uses a generic default predict function of the form predict(object, newdata, type) with the default prediction type of 'response'.
If, however, the desired prediction values are not generated with that format, the user must specify what they want. Most of the time, the only modification needed is to change the prediction type to some other value by setting the pred_type argument (e.g., to 'prob' to generate classification probabilities). But if the desired predictions need a different function signature, then the user must create a custom prediction function and pass it to pred_fun. The requirements for this custom function are:
- It must take three required arguments and nothing else:
  - object: a model
  - newdata: a dataframe or compatible table type
  - type: a string; it should usually be specified as type = pred_type
  These argument names follow the R convention for the generic stats::predict function.
- It must return a vector of numeric values as the prediction.
See the Examples section below for an example of a custom prediction function.
Note: survival models probably do not need a custom prediction function, but y_col must be set to the name of the binary event column and pred_type must be set to the desired prediction type.
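Before passing a custom function to pred_fun, it can help to check it against the requirements listed above. The base R sketch below does that with a hypothetical helper, check_pred_fun(), which is not part of {ale}. It then wraps predict.lm(), which returns a matrix (fit, lwr, upr) when interval is set, so the custom function keeps only the fit column as a numeric vector. For self-containment the sketch defaults type to 'response'; inside ale() the documentation recommends type = pred_type.

```r
# Hypothetical helper (not part of {ale}): check that a candidate pred_fun meets
# the requirements above before passing it to ale().
check_pred_fun <- function(pred_fun, model, data, pred_type = "response") {
  # Exactly three arguments, named as in stats::predict()
  stopifnot(identical(names(formals(pred_fun)), c("object", "newdata", "type")))
  preds <- pred_fun(object = model, newdata = data, type = pred_type)
  # A plain numeric vector with one prediction per row (a matrix fails is.null(dim()))
  stopifnot(is.numeric(preds), is.null(dim(preds)), length(preds) == nrow(data))
  invisible(preds)
}

# predict.lm() returns a matrix when interval is set, so keep only the "fit" column
fit <- lm(mpg ~ wt + hp, data = mtcars)
custom_pred <- function(object, newdata, type = "response") {
  pred <- stats::predict(object, newdata = newdata, type = type, interval = "confidence")
  as.numeric(pred[, "fit"])
}
preds <- check_pred_fun(custom_pred, fit, mtcars)
```

Passing the raw matrix-returning call instead of custom_pred would stop with an error at the is.null(dim(preds)) check.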
ALE statistics
For details about the ALE-based statistics (ALED, ALER, NALED, and NALER), see vignette('ale-statistics').
Parallel processing
Parallel processing using the {furrr} package is enabled by default. By default, it will use all the available physical CPU cores (minus the core being used for the current R session) with the setting parallel = parallel::detectCores(logical = FALSE) - 1.
Note that only physical cores are used (not logical cores or "hyperthreading") because machine learning can only take advantage of the floating point processors on physical cores, which are absent from logical cores. Trying to use logical cores will not speed up processing and might actually slow it down with useless data transfer.
If you will dedicate the entire computer to running this function (and you don't mind everything else becoming very slow while it runs), you may use all cores by setting parallel = parallel::detectCores(logical = FALSE). To disable parallel processing, set parallel = 0.
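The default worker count described above can be reproduced in base R. One caveat worth handling: parallel::detectCores() returns NA when the core count cannot be determined. The sketch below uses a hypothetical helper, default_workers(), that is not part of {ale}; falling back to 0 (sequential processing) in that case is an assumption of this sketch, not documented {ale} behaviour.

```r
# Hypothetical helper (not part of {ale}): worker count matching the documented
# default, with a fallback because detectCores() can return NA.
default_workers <- function(use_all = FALSE) {
  cores <- parallel::detectCores(logical = FALSE)  # physical cores only
  if (is.na(cores)) return(0L)  # unknown core count: fall back to sequential
  # Leave one core for the current R session unless use_all = TRUE
  workers <- if (use_all) cores else cores - 1L
  as.integer(max(workers, 0L))
}
```

The result could then be supplied as parallel = default_workers(), or parallel = default_workers(use_all = TRUE) on a dedicated machine.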
Progress bars
Progress bars are implemented with the {progressr} package, which lets the user fully control progress bars. To disable progress bars, set silent = TRUE.
The first time a function in the {ale} package that requires progress bars is called, it checks whether the user has activated the necessary {progressr} settings. If not, the {ale} package automatically enables {progressr} progress bars with the cli handler and prints a message notifying the user.
If you like the default progress bars and want to make them permanent, add the following lines of code to your .Rprofile configuration file; they will become your defaults for every R session, and you will not see the message again:
progressr::handlers(global = TRUE)
progressr::handlers('cli')
For more details on formatting progress bars to your liking, see the introduction to the {progressr} package.
References
Okoli, Chitu. 2023. “Statistical Inference Using Machine Learning and Classical Techniques Based on Accumulated Local Effects (ALE).” arXiv. https://arxiv.org/abs/2310.09877.
Examples
set.seed(0)
diamonds_sample <- ggplot2::diamonds[sample(nrow(ggplot2::diamonds), 1000), ]

# Create a GAM model with flexible curves to predict diamond price
# Smooth all numeric variables and include all other variables
gam_diamonds <- mgcv::gam(
  price ~ s(carat) + s(depth) + s(table) + s(x) + s(y) + s(z) +
    cut + color + clarity,
  data = diamonds_sample
)
summary(gam_diamonds)

# Simple ALE without bootstrapping
ale_gam_diamonds <- ale(
  diamonds_sample, gam_diamonds,
  parallel = 2  # CRAN limit (delete this line on your own computer)
)

# Plot the ALE data
ale_gam_diamonds$plots |>
  patchwork::wrap_plots()

# Bootstrapped ALE
# This can be slow, since bootstrapping runs the algorithm boot_it times

# Create ALE with 100 bootstrap samples
ale_gam_diamonds_boot <- ale(
  diamonds_sample, gam_diamonds, boot_it = 100,
  parallel = 2  # CRAN limit (delete this line on your own computer)
)

# Bootstrapped ALEs print with confidence intervals
ale_gam_diamonds_boot$plots |>
  patchwork::wrap_plots()

# If the predict function you want is non-standard, you may define a
# custom predict function. It must return a single numeric vector.
custom_predict <- function(object, newdata, type = pred_type) {
  predict(object, newdata, type = type, se.fit = TRUE)$fit
}

ale_gam_diamonds_custom <- ale(
  diamonds_sample, gam_diamonds,
  pred_fun = custom_predict, pred_type = 'link',
  parallel = 2  # CRAN limit (delete this line on your own computer)
)

# Plot the ALE data
ale_gam_diamonds_custom$plots |>
  patchwork::wrap_plots()