predict.multi_arm_causal_forest {grf}    R Documentation

Predict with a multi arm causal forest


Gets estimates of contrasts tau_k(x) using a trained multi arm causal forest (k = 1,...,K-1 where K is the number of treatments).


## S3 method for class 'multi_arm_causal_forest'
predict(
  object,
  newdata = NULL,
  num.threads = NULL,
  estimate.variance = FALSE,
  drop = FALSE,
  ...
)



object: The trained forest.


newdata: Points at which predictions should be made. If NULL, makes out-of-bag predictions on the training set instead (i.e., provides predictions at X_i using only trees that did not use the i-th training example). Note that this matrix should have the same number of columns as the training matrix, and that the columns must appear in the same order.


num.threads: Number of threads used in prediction. If set to NULL, the software automatically selects an appropriate amount.


estimate.variance: Whether variance estimates for \hat\tau(x) are desired (for confidence intervals). This option is currently only supported for univariate outcomes Y.


drop: If TRUE, coerce the prediction result to the lowest possible dimension. Default is FALSE.


...: Additional arguments (currently ignored).


A list with elements 'predictions': a 3d array of dimension [num.samples, K-1, M] with predictions for each contrast, for each outcome 1,...,M (singleton dimensions in this array can be dropped by setting 'drop = TRUE', by passing the 'drop' argument to '[', or with the shorthand '$predictions[,,]'), and optionally 'variance.estimates': a matrix with K-1 columns with variance estimates for each contrast (returned when 'estimate.variance = TRUE').


# Train a multi arm causal forest.
n <- 500
p <- 10
X <- matrix(rnorm(n * p), n, p)
W <- as.factor(sample(c("A", "B", "C"), n, replace = TRUE))
Y <- X[, 1] + X[, 2] * (W == "B") - 1.5 * X[, 2] * (W == "C") + rnorm(n)
mc.forest <- multi_arm_causal_forest(X, Y, W)

# Predict contrasts (out-of-bag) using the forest.
# Fitting several outcomes jointly is supported, and the returned prediction array has
# dimension [num.samples, num.contrasts, num.outcomes]. Since num.outcomes is one in
# this example, we use drop = TRUE to ignore this singleton dimension.
mc.pred <- predict(mc.forest, drop = TRUE)
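
# A minimal sketch of predicting at new points through the `newdata` argument.
# X.test is a hypothetical test matrix introduced here for illustration; it must have
# the same number of columns as X, in the same order.
X.test <- matrix(rnorm(100 * p), 100, p)
mc.pred.test <- predict(mc.forest, newdata = X.test, drop = TRUE)
# One row per test point and one column per contrast.
dim(mc.pred.test$predictions)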

# By default, the first ordinal treatment is used as baseline ("A" in this example),
# giving two contrasts tau_B = Y(B) - Y(A), tau_C = Y(C) - Y(A)
tau.hat <- mc.pred$predictions

plot(X[, 2], tau.hat[, "B - A"], ylab = "tau.contrast")
abline(0, 1, col = "red")
points(X[, 2], tau.hat[, "C - A"], col = "blue")
abline(0, -1.5, col = "red")
legend("topleft", c("B - A", "C - A"), col = c("black", "blue"), pch = 19)
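
# A sketch of pointwise 95% confidence intervals using estimate.variance = TRUE
# (supported for a single outcome only). Two assumptions made here for illustration:
# a normal approximation, and that the columns of variance.estimates follow the same
# order as the contrasts (so column 1 corresponds to "B - A").
mc.pred.var <- predict(mc.forest, estimate.variance = TRUE, drop = TRUE)
sigma.hat <- sqrt(mc.pred.var$variance.estimates)
ci.lower <- mc.pred.var$predictions - qnorm(0.975) * sigma.hat
ci.upper <- mc.pred.var$predictions + qnorm(0.975) * sigma.hat
head(cbind(lower = ci.lower[, 1], upper = ci.upper[, 1]))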

# The average treatment effect of the arms with "A" as baseline.
average_treatment_effect(mc.forest)

# The conditional response surfaces mu(k, x) for a single outcome can be reconstructed from
# the contrasts tau_k(x), the treatment propensities e_k(x), and the conditional mean m(x).
# Given treatment "A" as baseline, and writing W_k for the indicator of receiving arm k, we have:
# m(x) := E[Y | X = x] = E[Y(A) | X = x] + E[W_B (Y(B) - Y(A)) | X = x] + E[W_C (Y(C) - Y(A)) | X = x]
# which given unconfoundedness is equal to:
# m(x) = mu(A, x) + e_B(x) tau_B(x) + e_C(x) tau_C(x)
# Rearranging and plugging in the above expressions, we obtain the following estimates
# * mu(A, x) = m(x) - e_B(x) tau_B(x) - e_C(x) tau_C(x)
# * mu(B, x) = m(x) + (1 - e_B(x)) tau_B(x) - e_C(x) tau_C(x)
# * mu(C, x) = m(x) - e_B(x) tau_B(x) + (1 - e_C(x)) tau_C(x)
Y.hat <- mc.forest$Y.hat
W.hat <- mc.forest$W.hat

muA <- Y.hat - W.hat[, "B"] * tau.hat[, "B - A"] - W.hat[, "C"] * tau.hat[, "C - A"]
muB <- Y.hat + (1 - W.hat[, "B"]) * tau.hat[, "B - A"] - W.hat[, "C"] * tau.hat[, "C - A"]
muC <- Y.hat - W.hat[, "B"] * tau.hat[, "B - A"] + (1 - W.hat[, "C"]) * tau.hat[, "C - A"]

# These can also be obtained with some array manipulations.
# (the first column is always the baseline arm)
Y.hat.baseline <- Y.hat - rowSums(W.hat[, -1, drop = FALSE] * tau.hat)
mu.hat.matrix <- cbind(Y.hat.baseline, c(Y.hat.baseline) + tau.hat)
colnames(mu.hat.matrix) <- levels(W)
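
# A quick consistency check (a sketch, not part of the original example): the columns
# of mu.hat.matrix should agree with muA, muB and muC up to floating point error.
stopifnot(
  isTRUE(all.equal(c(muA), unname(mu.hat.matrix[, "A"]))),
  isTRUE(all.equal(c(muB), unname(mu.hat.matrix[, "B"]))),
  isTRUE(all.equal(c(muC), unname(mu.hat.matrix[, "C"])))
)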

# The reference level for contrast prediction can be changed with `relevel`.
# Fit and predict with treatment B as baseline:
W <- relevel(W, ref = "B")
mc.forest.B <- multi_arm_causal_forest(X, Y, W)
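# A sketch of the prediction step for the releveled forest (tau.hat.B is a name
# introduced here for illustration). The column names list the contrasts, which are
# now taken relative to "B".
tau.hat.B <- predict(mc.forest.B, drop = TRUE)$predictions
colnames(tau.hat.B)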

[Package grf version 2.3.2 Index]