CATESurface {EpiForsk}    R Documentation

Calculate CATE on a surface in the covariate space

Description

Calculates CATE estimates from a causal forest object on a specified surface within the covariate space.

Usage

CATESurface(
  forest,
  continuous_covariates,
  discrete_covariates,
  estimate_variance = TRUE,
  grid = 100,
  fixed_covariate_fct = median,
  other_discrete = NULL,
  max_predict_size = 1e+05,
  num_threads = 2
)

Arguments

forest

An object of class causal_forest, as returned by causal_forest(). Alternatively, an object of class regression_forest, as returned by regression_forest().

continuous_covariates

character, continuous covariates to use for the surface. Must match names in forest$X.orig.

discrete_covariates

character, discrete covariates to use for the surface. Note that discrete covariates are currently assumed to be one-hot encoded, with columns named {fct_nm}_{lvl_nm}. Names supplied to discrete_covariates should match fct_nm.
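The naming convention can be illustrated with a minimal base-R sketch. The data frame and the factor name smoker are hypothetical, and the encoding is done by hand here rather than with the package's own helper:

```r
# Hypothetical factor covariate (not part of the package's example data).
df <- data.frame(smoker = factor(c("never", "former", "current", "never")))

# One indicator column per level, named {fct_nm}_{lvl_nm}.
lvls <- levels(df$smoker)
one_hot <- sapply(lvls, function(l) as.integer(df$smoker == l))
colnames(one_hot) <- paste0("smoker", "_", lvls)
one_hot <- as.data.frame(one_hot)

names(one_hot)
```

With columns named this way, the value passed to discrete_covariates would be "smoker", i.e. the factor name without the level suffix.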

estimate_variance

boolean. If TRUE, the variance of the CATE estimates is computed.

grid

list, points at which to predict the CATE along the continuous covariates. Element i of the list should contain either a single integer, specifying the number of equally spaced points within the range of the i'th continuous covariate at which to calculate the CATE, or a numeric vector of manually specified points at which to calculate the CATE along the i'th continuous covariate. If all elements of grid specify a number of points, grid can be supplied as a numeric vector instead. If the list is named, the names must match the continuous covariates. grid will be reordered to match the order of continuous_covariates.
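The accepted forms of grid can be sketched as follows. The covariate values are hypothetical, and the use of seq() over the observed minimum and maximum is an assumption about what "equally spaced points within the range" means, not a statement about the package internals:

```r
# Hypothetical observed values of one continuous covariate.
v1 <- c(-1.5, 0.2, 0.9, 2.5)

# A single integer requests that many equally spaced points.
# Assumption: the points span the observed range of the covariate.
n_points <- 5
grid_v1 <- seq(min(v1), max(v1), length.out = n_points)

# Two ways of writing grid for continuous covariates V1 and V2:
grid_counts <- c(10, 20)                 # number of points for each covariate
grid_mixed  <- list(V1 = 10, V2 = -5:5)  # a count for V1, explicit points for V2
```

The named list form is the one used in the Examples section of this page.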

fixed_covariate_fct

Function applied to each covariate not in the sub-surface. It returns the fixed value of that covariate used when calculating the CATE. Must be specified in one of the following ways:

  • A named function, e.g. mean.

  • An anonymous function, e.g. \(x) x + 1 or function(x) x + 1.

  • A formula, e.g. ~ .x + 1. You must use .x to refer to the first argument. Only recommended if you require backward compatibility with older versions of R.

  • A string, integer, or list, e.g. "idx", 1, or list("idx", 1) which are shorthand for \(x) purrr::pluck(x, "idx"), \(x) purrr::pluck(x, 1), and \(x) purrr::pluck(x, "idx", 1) respectively. Optionally supply .default to set a default value if the indexed element is NULL or does not exist.
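As a rough illustration of the first two forms, the sketch below applies a named and an anonymous function to each column of a hypothetical data frame of covariates outside the surface. The column names and values are invented for the example, and sapply() merely stands in for whatever the function does internally:

```r
# Hypothetical covariates that are not part of the surface.
others <- data.frame(
  age = c(30, 42, 55, 61),
  bmi = c(21.5, 24.0, 27.3, 31.2)
)

# A named function, as with the default fixed_covariate_fct = median.
fixed_median <- sapply(others, median)

# An anonymous function, here returning the first quartile.
fixed_q25 <- sapply(others, function(x) unname(quantile(x, 0.25)))
```

Each call yields one fixed value per covariate, which is the role fixed_covariate_fct plays for the covariates held constant across the surface.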

other_discrete

A data frame, data frame extension (e.g. a tibble), or a lazy data frame (e.g. from dbplyr or dtplyr) with columns covs and lvl. Used to specify the level of each discrete covariate to use when calculating the CATE. Assumes the use of one-hot encoding. covs must contain the names of the discrete covariates, and lvl the level to use for each. Set to NULL if none of the fixed covariates are discrete covariates using one-hot encoding.
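A minimal sketch of such a data frame is shown below. The factor "region" and its level "north" are hypothetical, and the mapping to column names assumes the {fct_nm}_{lvl_nm} convention described for discrete_covariates:

```r
# One row per discrete covariate that is held fixed.
other_discrete <- data.frame(
  covs = c("X_d2", "region"),  # "region" is a hypothetical second factor
  lvl  = c("2", "north")       # level to use for each covariate
)

# Under the assumed naming convention, these rows refer to the columns:
fixed_columns <- paste0(other_discrete$covs, "_", other_discrete$lvl)
fixed_columns
```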

max_predict_size

integer, maximum number of examples to predict at a time. If the surface has more points than max_predict_size, the prediction is split up into an appropriate number of chunks.
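The chunking idea can be sketched in base R as below. The small numbers are chosen for readability, and this is only one way to split the points; it is not claimed to be how the package does it:

```r
# Hypothetical surface with 25 points, predicted at most 10 at a time.
n_points <- 25
max_predict_size <- 10

# Assign each point to a chunk, then split the row indices by chunk.
chunk_id <- ceiling(seq_len(n_points) / max_predict_size)
chunks <- split(seq_len(n_points), chunk_id)

# Number of points predicted in each pass.
chunk_sizes <- lengths(chunks)
```

Every chunk except possibly the last holds max_predict_size points, so memory use per prediction call stays bounded regardless of the size of the surface.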

num_threads

Number of threads used in prediction. If set to NULL, the software automatically selects an appropriate number.

Value

Tibble with the predicted CATEs on the specified surface in the covariate space. The tibble has a column for each covariate used to train the input forest, as well as the columns output by predict.causal_forest().

Author(s)

KIJA

Examples


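# Simulate three continuous covariates (named V1, V2, V3 by as.data.frame())
n <- 1000
p <- 3
X <- matrix(rnorm(n * p), n, p) |> as.data.frame()
# Simulate two discrete covariates and one-hot encode them
X_d <- data.frame(
  X_d1 = factor(sample(1:3, n, replace = TRUE)),
  X_d2 = factor(sample(1:3, n, replace = TRUE))
)
X_d <- DiscreteCovariatesToOneHot(X_d)
X <- cbind(X, X_d)
# Binary treatment W and binary outcome Y depending on V1, V2 and W
W <- rbinom(n, 1, 0.5)
event_prob <- 1 / (1 + exp(2 * (pmax(2 * X[, 1], 0) * W - X[, 2])))
Y <- rbinom(n, 1, event_prob)
# Fit the causal forest
cf <- grf::causal_forest(X, Y, W)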
cate_surface <- CATESurface(
  cf,
  continuous_covariates = paste0("V", 1:2),
  discrete_covariates = "X_d1",
  grid = list(
    V1 = 10,
    V2 = -5:5
  ),
  other_discrete = data.frame(
    covs = "X_d2",
    lvl = "1"  # a level that exists in X_d2 (levels are 1, 2, 3)
  )
)



[Package EpiForsk version 0.1.1 Index]