fairness {R6causal} | R Documentation |
Checking fairness of a prediction via counterfactual simulation
Description
Checking fairness of a prediction via counterfactual simulation
Usage
fairness(
  modellist,
  scm,
  sensitive,
  condition,
  condition_type,
  parents,
  n,
  sens_values = NULL,
  modeltype = "predict",
  method,
  control = NULL,
  ...
)
Arguments
modellist |
A list of model objects that have a predict method or a list of functions that return predictions |
scm |
An SCM object |
sensitive |
A character vector of the names of sensitive variables |
condition |
A data.table consisting of the valid rows (e.g., data.table::data.table(x = 0, y = 0)) |
condition_type |
(required only if method == "u_find") A character vector giving the type ("continuous" or "discrete") of every variable in condition |
parents |
A character vector of the names of variables that remain fixed |
n |
The number of rows in the simulated data |
sens_values |
A data.table specifying the combinations of values of the sensitive variables to be considered (the default NULL means all possible combinations of the values of the sensitive variables) |
modeltype |
"predict" (default) or "function", depending on the type of the elements of modellist |
method |
The simulation method, "u_find", "rejection" or "analytic_linear_gaussian" |
control |
List of parameters to be passed to the simulation method, see |
... |
Other arguments passed to |
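When only some combinations of the sensitive values are of interest, sens_values can be built explicitly instead of relying on the NULL default. The sketch below uses base R only; the variable names sex and agegroup and their levels are hypothetical and do not come from the package. Because sens_values is documented as a data.table, the resulting data frame would still need converting, for example with data.table::as.data.table().

```r
# Hypothetical sensitive variables and their possible values (assumed for illustration)
levels_list <- list(sex = c(0, 1), agegroup = c(1, 2, 3))

# Full grid of combinations: one row per combination, one column per variable
grid <- expand.grid(levels_list)

# Keep only the combinations of interest, here everything except sex = 1 with agegroup = 3
grid_subset <- grid[!(grid$sex == 1 & grid$agegroup == 3), ]

# With data.table installed, convert before passing the table as sens_values:
# sens_values <- data.table::as.data.table(grid_subset)
```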
Value
A list containing one data table for each element of modellist. Each data table contains the predicted values after counterfactual interventions on the sensitive variables.
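The idea behind the returned predictions can be illustrated with a small base R sketch. It does not use R6causal and is not the package's implementation: the linear Gaussian SCM, the mediator w, and all coefficients are assumptions chosen so that the background noise can be recovered in closed form. The sketch holds the unit's background variables fixed, sets the sensitive variable x to each value of interest, and compares the predictions of a model that uses a descendant of x with one that does not.

```r
# Hypothetical SCM (assumed for illustration only):
#   z := u_z,  x := 0.5 * z + u_x,  w := x + u_w,  y := w + z + u_y
set.seed(1)
n <- 5000
uz <- rnorm(n); ux <- rnorm(n); uw <- rnorm(n); uy <- rnorm(n)
z <- uz
x <- 0.5 * z + ux
w <- x + uw
y <- w + z + uy
train <- data.frame(x, z, w, y)

m1 <- lm(y ~ w + z, data = train)  # uses w, a descendant of the sensitive variable x
m2 <- lm(y ~ z, data = train)      # uses only z, which x does not affect

# Abduction: recover the background noise of w for one observed unit
obs <- train[1, ]
uw_obs <- obs$w - obs$x

# Intervention: set x to 0 and 1 while keeping z and u_w fixed
x_cf <- c(0, 1)
cf <- data.frame(x = x_cf, z = obs$z, w = x_cf + uw_obs)

# Prediction: compare each model across the counterfactual worlds
pred_m1 <- predict(m1, newdata = cf)
pred_m2 <- predict(m2, newdata = cf)
spread_m1 <- diff(range(pred_m1))  # roughly the fitted coefficient of w
spread_m2 <- diff(range(pred_m2))  # zero: m2 ignores x and its descendants
```

A model whose predictions do not change across the counterfactual values of the sensitive variable, like m2 here, passes this kind of check for the unit in question. For SCMs where the noise cannot be solved for directly, a simulation method such as "u_find" or "rejection" takes the place of the closed-form step.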
Examples
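# Simulate training data and new observations from the SCM object backdoor
trainingd <- backdoor$simulate(10000, return_simdata = TRUE)
newd <- backdoor$simulate(100, return_simdata = TRUE)
vnames <- backdoor$vnames

# m1 uses the sensitive variable x as a predictor, m2 does not
m1 <- lm(y ~ x + z, data = trainingd)
m2 <- lm(y ~ z, data = trainingd)

# Counterfactual predictions of both models for x = 0 and x = 1,
# conditional on the observed x and y in the first row of newd
fairlist <- fairness(modellist = list(m1, m2),
                     scm = backdoor,
                     sensitive = c("x"),
                     sens_values = data.table::data.table(x = c(0, 1)),
                     condition = newd[1, c("x", "y")],
                     condition_type = list(x = "cont",
                                           z = "cont",
                                           y = "cont"),
                     parents = NULL,
                     n = 20,
                     modeltype = "predict",
                     method = "u_find")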