MCBoost {mcboost}    R Documentation

Multi-Calibration Boosting

Description

Implements Multi-Calibration Boosting by Hebert-Johnson et al. (2018) and Multi-Accuracy Boosting by Kim et al. (2019) for the multi-calibration of a machine learning model's predictions. Multi-calibration works best in scenarios where the underlying data and labels are unbiased, but a bias is introduced within the algorithm's fitting procedure. This is often the case, for example, when an algorithm fits the majority population well while ignoring or under-fitting minority populations.
Expects initial models that fit binary outcomes, or continuous outcomes with predictions in (or scaled to) the 0-1 range. The method defaults to Multi-Accuracy Boosting as described in Kim et al. (2019). To obtain the behaviour described in Hebert-Johnson et al. (2018), set multiplicative = FALSE and num_buckets = 10.
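
The settings named above map directly onto arguments of the constructor documented under Method new(). A minimal sketch: it only constructs the object (fitting is done by multicalibrate()), and it is guarded so that it runs only when mcboost is installed.

```r
if (requireNamespace("mcboost", quietly = TRUE)) {
  # Additive updates over 10 probability buckets, i.e. the
  # Hebert-Johnson et al. (2018) setting instead of the multi-accuracy default.
  mc = mcboost::MCBoost$new(multiplicative = FALSE, num_buckets = 10L)
}
```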

For additional details, please refer to the relevant publications: Hebert-Johnson et al. (2018) and Kim et al. (2019).

Public fields

max_iter

integer
The maximum number of iterations of the multi-calibration/multi-accuracy method.

alpha

numeric
Accuracy parameter that determines the stopping condition.

eta

numeric
Parameter for multiplicative weight update (step size).

num_buckets

integer
The number of buckets to split into in addition to using the whole sample.

bucket_strategy

character
Currently only "simple" is supported, which splits evenly along the probabilities. Only relevant for num_buckets > 1.

rebucket

logical
Should buckets be re-calculated at each iteration?

eval_fulldata

logical
Should the auditor be evaluated on the full data?

partition

logical
True/False flag for whether to split up predictions by their "partition" (e.g., predictions less than 0.5 and predictions greater than 0.5).

multiplicative

logical
Specifies the strategy for updating the weights (multiplicative vs. additive weight updates).

iter_sampling

character
Specifies the strategy to sample the validation data for each iteration.

auditor_fitter

AuditorFitter
Specifies the type of model used to fit the residuals.

predictor

function
Initial predictor function.

iter_models

list
Cumulative list of fitted models.

iter_partitions

list
Cumulative list of data partitions for models.

iter_corr

list
Auditor correlation in each iteration.

auditor_effects

list
Auditor effect in each iteration.

bucket_strategies

character
Possible bucket_strategies.

weight_degree

integer
Weighting degree for low-degree multi-calibration.

Methods

Public methods


Method new()

Initialize a multi-calibration instance.

Usage
MCBoost$new(
  max_iter = 5,
  alpha = 1e-04,
  eta = 1,
  num_buckets = 2,
  partition = ifelse(num_buckets > 1, TRUE, FALSE),
  bucket_strategy = "simple",
  rebucket = FALSE,
  eval_fulldata = FALSE,
  multiplicative = TRUE,
  auditor_fitter = NULL,
  subpops = NULL,
  default_model_class = ConstantPredictor,
  init_predictor = NULL,
  iter_sampling = "none",
  weight_degree = 1L
)
Arguments
max_iter

integer
The maximum number of iterations of the multi-calibration/multi-accuracy method. Default 5L.

alpha

numeric
Accuracy parameter that determines the stopping condition. Default 1e-4.
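
For orientation, Kim et al. (2019) state the stopping rule roughly as below, with h_t the auditor fitted in iteration t and f_t the current predictor. This is a paraphrase of the paper; exactly which quantity MCBoost compares with alpha (for example per bucket, or as an absolute value) is not specified on this page.

```latex
\text{stop at iteration } t \quad \text{if} \quad
\mathbb{E}_{(x,y)}\!\left[\, h_t(x)\,\bigl(f_t(x) - y\bigr) \right] \le \alpha
```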

eta

numeric
Parameter for multiplicative weight update (step size). Default 1.0.

num_buckets

integer
The number of buckets to split into in addition to using the whole sample. Default 2L.

partition

logical
True/False flag for whether to split up predictions by their "partition" (e.g., predictions less than 0.5 and predictions greater than 0.5). Defaults to TRUE when num_buckets > 1 (as in the default multi-accuracy boosting setting) and FALSE otherwise.

bucket_strategy

character
Currently only "simple" is supported, which splits evenly along the probabilities. Only taken into account for num_buckets > 1.
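
A minimal sketch of the "simple" strategy, assuming that "even split along probabilities" means equal-width intervals on [0, 1] and that, as described for num_buckets, the whole sample is kept as an additional group. The helper simple_buckets is hypothetical, not part of mcboost; with num_buckets = 2 the split point is 0.5, matching the partition example.

```r
# Hypothetical helper: row indices per bucket, plus the whole sample.
simple_buckets = function(probs, num_buckets) {
  # Equal-width break points on [0, 1] (an assumption about "simple")
  breaks = seq(0, 1, length.out = num_buckets + 1)
  # rightmost.closed = TRUE puts a probability of exactly 1 in the last bucket
  ids = findInterval(probs, breaks, rightmost.closed = TRUE)
  buckets = lapply(seq_len(num_buckets), function(b) which(ids == b))
  # The whole sample is used in addition to the buckets
  c(list(all = seq_along(probs)), buckets)
}

bk = simple_buckets(c(0.1, 0.4, 0.5, 0.9), num_buckets = 2)
bk[[2]]  # rows with probability below 0.5: 1 2
bk[[3]]  # rows with probability of at least 0.5: 3 4
```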

rebucket

logical
Should buckets be recalculated at each iteration? Default FALSE.

eval_fulldata

logical
Should the auditor be evaluated on the full data or on the respective bucket when determining the stopping criterion? Default FALSE: the auditor is only evaluated on the bucket. Setting eval_fulldata = TRUE keeps the implementation closer to the algorithm proposed in the corresponding multi-accuracy paper (Kim et al., 2019), where auditor effects are computed across the full sample.

multiplicative

logical
Specifies the strategy for updating the weights (multiplicative vs. additive weight updates). Defaults to TRUE (multi-accuracy boosting). Set to FALSE for multi-calibration.
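
The two update strategies can be contrasted in a few lines. This is a hypothetical sketch, not MCBoost's code: it assumes h is the auditor's estimate of the residual (prediction minus label), applies either p * exp(-eta * h) or p - eta * h, and clips the result to [0, 1]. The helper name update_probs is illustrative.

```r
# Hypothetical single update step for predictions p given auditor output h.
update_probs = function(p, h, eta = 1, multiplicative = TRUE) {
  new_p = if (multiplicative) {
    p * exp(-eta * h)  # multiplicative weight update
  } else {
    p - eta * h        # additive update
  }
  # Keep predictions inside the valid probability range
  pmin(pmax(new_p, 0), 1)
}

p = c(0.2, 0.6, 0.9)
h = c(0.1, -0.2, 0.3)
update_probs(p, h, multiplicative = FALSE)  # 0.1 0.8 0.6
update_probs(p, h)                          # p * exp(-h)
```

Clipping matters mostly for the additive rule, which can push predictions outside [0, 1]; the multiplicative rule keeps them non-negative by construction.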

auditor_fitter

AuditorFitter|character|mlr3::Learner
Specifies the type of model used to fit the residuals. The default is RidgeAuditorFitter. Can be a character string naming an AuditorFitter, an mlr3::Learner that is then auto-converted into a LearnerAuditorFitter, or a custom AuditorFitter.

subpops

list
Specifies a collection of characteristic attributes and the values they take to define subpopulations, e.g. list(age = c('20-29', '30-39', '40+'), nJobs = c(0, 1, 2, '3+'), ...).
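
One way to read subpops is that every (attribute, value) pair marks a subpopulation of rows. The sketch below builds those indicator columns with base R. The helper name subpop_indicators is hypothetical, and how MCBoost or its auditor fitters actually consume subpops is not described on this page.

```r
# Hypothetical helper: one 0/1 indicator column per (attribute, value) pair.
subpop_indicators = function(data, subpops) {
  out = list()
  for (nm in names(subpops)) {
    for (val in subpops[[nm]]) {
      out[[paste0(nm, "=", val)]] = as.integer(data[[nm]] == val)
    }
  }
  # check.names = FALSE keeps labels such as "age=20-29" intact
  data.frame(out, check.names = FALSE)
}

df = data.frame(age = c("20-29", "40+", "30-39"), nJobs = c("0", "3+", "0"))
ind = subpop_indicators(df, list(age = c("20-29", "40+"), nJobs = c("0", "3+")))
ind[["nJobs=0"]]  # 1 0 1
```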

default_model_class

Predictor
The class of the model that should be used as the init predictor model if init_predictor is not specified. Defaults to ConstantPredictor which predicts a constant value.

init_predictor

function|mlr3::Learner
The initial predictor function to use (e.g., if the user has a pretrained model). If an mlr3 Learner is passed, it will be auto-converted using mlr3_init_predictor. This requires the mlr3::Learner to be trained.
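
A pretrained model can be wrapped in a plain function. This is a minimal sketch, assuming the function receives the feature data as its first argument and returns one probability per row; the logistic regression on mtcars is purely illustrative.

```r
# Illustrative pretrained model for a binary outcome (am: manual transmission)
fit = glm(am ~ mpg, data = mtcars, family = binomial())

init_predictor = function(data) {
  # type = "response" returns probabilities in the 0-1 range
  unname(predict(fit, newdata = data, type = "response"))
}

p0 = init_predictor(mtcars[1:5, ])
```

Such a function could then be supplied as MCBoost$new(init_predictor = init_predictor).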

iter_sampling

character
How to sample the validation data for each iteration? Can be "bootstrap", "split" or "none".
"split" splits the data into max_iter parts and validates on a different part in each iteration.
"bootstrap" uses a new bootstrap sample in each iteration.
"none" uses the same dataset in each iteration.

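The three iter_sampling schemes can be sketched as a function returning the validation row indices for each iteration. This is a hypothetical base R helper (iter_rows); MCBoost's internal sampling may differ in details such as how rows are assigned to parts.

```r
# Hypothetical helper: a list with one vector of row indices per iteration.
iter_rows = function(n, max_iter, iter_sampling = c("none", "split", "bootstrap")) {
  iter_sampling = match.arg(iter_sampling)
  switch(iter_sampling,
    # Same rows in every iteration
    none = rep(list(seq_len(n)), max_iter),
    # Shuffle once, then cut into max_iter disjoint parts
    split = unname(split(sample(n), rep_len(seq_len(max_iter), n))),
    # A fresh bootstrap sample (drawn with replacement) per iteration
    bootstrap = lapply(seq_len(max_iter), function(i) sample(n, n, replace = TRUE))
  )
}

set.seed(1)
rows = iter_rows(100, max_iter = 5, iter_sampling = "split")
lengths(rows)  # 20 20 20 20 20
```
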
weight_degree

integer
Weighting degree for low-degree multi-calibration. Defaults to 1L, which applies constant weighting with 1.


Method multicalibrate()

Run multi-calibration.

Usage
MCBoost$multicalibrate(data, labels, predictor_args = NULL, audit = FALSE, ...)
Arguments
data

data.table
Features.

labels

numeric
One-hot encoded labels (of the same length as the number of rows in data).

predictor_args

any
Arguments passed on to init_predictor. Defaults to NULL.

audit

logical
Perform auditing? Default FALSE.

...

any
Params passed on to other methods.

Returns

NULL


Method predict_probs()

Predict a dataset with multi-calibrated predictions.

Usage
MCBoost$predict_probs(x, t = Inf, predictor_args = NULL, audit = FALSE, ...)
Arguments
x

data.table
Prediction data.

t

integer
Number of multi-calibration steps to predict. Default: Inf (all).
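
The t argument truncates the sequence of calibration steps. A hypothetical sketch, modelling each stored step as a function that maps current probabilities to updated ones (MCBoost itself stores fitted models and data partitions, see iter_models and iter_partitions); the helper predict_steps is illustrative only.

```r
# Hypothetical helper: apply the first t stored steps to initial predictions p0.
predict_steps = function(p0, steps, t = Inf) {
  k = min(t, length(steps))  # t = Inf means "all steps"
  Reduce(function(p, step) step(p), steps[seq_len(k)], init = p0)
}

steps = list(function(p) p + 0.1, function(p) p * 2)
predict_steps(0.2, steps, t = 1)  # 0.3
predict_steps(0.2, steps)         # all steps: 0.6
```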

predictor_args

any
Arguments passed on to init_predictor. Defaults to NULL.

audit

logical
Should audit weights be stored? Default FALSE.

...

any
Params passed on to the residual prediction model's predict method.

Returns

numeric
Numeric vector of multi-calibrated predictions.


Method auditor_effect()

Compute the auditor effect for each instance, i.e. the cumulative absolute predictions of the auditor. It indicates how much each observation was affected by multi-calibration, on average across iterations.
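
Assuming the auditor's predictions are collected in a matrix H with one row per observation and one column per iteration, the effect described above can be sketched as the row-wise mean of abs(H), while aggregate = FALSE returns the per-iteration absolute values. The function name and matrix layout are assumptions for illustration, not MCBoost internals.

```r
# Hypothetical helper: H has one row per observation, one column per iteration.
auditor_effect_sketch = function(H, aggregate = TRUE) {
  effects = abs(H)
  if (aggregate) rowMeans(effects) else effects
}

H = matrix(c(0.2, -0.4, 0.0, 0.1), nrow = 2)  # 2 observations, 2 iterations
auditor_effect_sketch(H)  # 0.10 0.25
```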

Usage
MCBoost$auditor_effect(
  x,
  aggregate = TRUE,
  t = Inf,
  predictor_args = NULL,
  ...
)
Arguments
x

data.table
Prediction data.

aggregate

logical
Should the auditor effect be aggregated across iterations? Defaults to TRUE.

t

integer
Number of multi-calibration steps to predict. Defaults to Inf (all).

predictor_args

any
Arguments passed on to init_predictor. Defaults to NULL.

...

any
Params passed on to the residual prediction model's predict method.

Returns

numeric
Numeric vector of auditor effects for each row in x.


Method print()

Prints information about multi-calibration.

Usage
MCBoost$print(...)
Arguments
...

any
Not used.


Method clone()

The objects of this class are cloneable with this method.

Usage
MCBoost$clone(deep = FALSE)
Arguments
deep

Whether to make a deep clone.

Examples

# See vignette for more examples.
## Not run: 
library(mcboost)
# Instantiate the object
mc = MCBoost$new()
# Run multi-calibration on the training data (random labels, hence set.seed).
set.seed(1)
mc$multicalibrate(iris[1:100, 1:4], factor(sample(c("A", "B"), 100, TRUE)))
# Predict on the test set
mc$predict_probs(iris[101:150, 1:4])
# Get the auditor effect
mc$auditor_effect(iris[101:150, 1:4])

## End(Not run)

[Package mcboost version 0.4.3 Index]