metalearner_deepneural {DeepLearningCausal} | R Documentation |
metalearner_deepneural
Description
metalearner_deepneural
implements the S-learner and T-learner for estimating the
conditional average treatment effect (CATE) using deep neural networks. The resilient backpropagation (Rprop)
algorithm is used to train the neural networks.
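The two meta-learners differ only in how the outcome model is fit. The S-learner fits a single model mu(x, w) that includes the treatment indicator w as an extra covariate and estimates tau(x) = mu(x, 1) - mu(x, 0). The T-learner fits separate models mu1(x) and mu0(x) on the treated and control units and estimates tau(x) = mu1(x) - mu0(x). The base-R sketch below illustrates this logic on simulated data. It is not how metalearner_deepneural is implemented: ordinary linear regression (lm) stands in for the neural network, the data and variable names (sim, x, treat, y) are hypothetical, and no cross-validation folds are used.

```r
# Minimal sketch of S-learner and T-learner CATE estimation.
# Assumption: lm() stands in for the neural network used by
# metalearner_deepneural; only the meta-learner logic is illustrated.

set.seed(1)
n <- 200
sim <- data.frame(x = rnorm(n), treat = rbinom(n, 1, 0.5))
# Simulated outcome whose true CATE is 1 + x
sim$y <- 0.5 * sim$x + sim$treat * (1 + sim$x) + rnorm(n, sd = 0.1)
true_cate <- 1 + sim$x

# S-learner: one model with treatment as a covariate, then predict every
# observation under treat = 1 and treat = 0 and take the difference.
s_fit  <- lm(y ~ x * treat, data = sim)
mu1_s  <- predict(s_fit, newdata = transform(sim, treat = 1))
mu0_s  <- predict(s_fit, newdata = transform(sim, treat = 0))
cate_s <- mu1_s - mu0_s

# T-learner: separate models on treated and control units, both used to
# predict outcomes for every observation.
t_fit1 <- lm(y ~ x, data = sim[sim$treat == 1, ])
t_fit0 <- lm(y ~ x, data = sim[sim$treat == 0, ])
cate_t <- predict(t_fit1, newdata = sim) - predict(t_fit0, newdata = sim)

# Compare per-observation estimates with the true CATE
head(round(cbind(true_cate, cate_s, cate_t), 3))
```

Because lm() cannot discover interactions on its own, the S-learner formula above includes x * treat explicitly; a flexible learner such as a neural network can, in principle, learn such treatment-covariate interactions from the data.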
Usage
metalearner_deepneural(
  data,
  cov.formula,
  treat.var,
  meta.learner.type,
  stepmax = 1e+05,
  nfolds = 5,
  algorithm = "rprop+",
  hidden.layer = c(4, 2),
  linear.output = FALSE,
  binary.outcome = FALSE
)
Arguments
data |
data frame containing the outcome, treatment, and covariate variables. |
cov.formula |
formula describing the model, of the form y ~ x, where y is the outcome and x is the list of covariates (e.g. y ~ x1 + x2). |
treat.var |
string giving the name of the treatment variable. |
meta.learner.type |
string specifying the meta-learner: "S.Learner" for the S-learner or "T.Learner" for the T-learner. |
stepmax |
maximum number of steps for training the neural network. |
nfolds |
number of folds for cross-validation. Currently supports up to 5 folds. |
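algorithm |
string naming the algorithm used to train the neural network.
Default set to "rprop+", resilient backpropagation (Rprop) with weight backtracking. |
hidden.layer |
vector of integers giving the number of neurons in each hidden layer; for example, c(4, 2) specifies two hidden layers with 4 and 2 neurons. |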
linear.output |
logical specifying regression (TRUE) or classification (FALSE) model. |
binary.outcome |
logical specifying whether the predicted outcome variable takes binary values (TRUE) or proportions (FALSE). |
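The algorithm argument selects the training rule. With the default "rprop+", each weight keeps its own step size, which grows by a factor eta_plus while the sign of its gradient stays the same and shrinks by a factor eta_minus when the sign flips; on a sign flip the previous update is also reverted (weight backtracking). The base-R sketch below implements this update and applies it to a toy quadratic rather than to a neural network. The function name rprop_plus and its arguments are hypothetical; eta_plus = 1.2, eta_minus = 0.5, delta0 = 0.1 and delta_max = 50 are commonly used Rprop settings, not values read from DeepLearningCausal.

```r
# Minimal sketch of Rprop+ (resilient backpropagation with weight
# backtracking). Hypothetical helper, not part of DeepLearningCausal.
rprop_plus <- function(grad_fn, w, n_iter = 200,
                       eta_plus = 1.2, eta_minus = 0.5,
                       delta0 = 0.1, delta_min = 1e-6, delta_max = 50) {
  delta   <- rep(delta0, length(w))  # per-weight step sizes
  g_prev  <- rep(0, length(w))       # gradient from the previous iteration
  dw_prev <- rep(0, length(w))       # update applied in the previous iteration
  for (i in seq_len(n_iter)) {
    g    <- grad_fn(w)
    s    <- g * g_prev
    up   <- s > 0    # gradient sign unchanged: enlarge the step
    down <- s < 0    # gradient sign flipped: shrink the step and backtrack
    same <- s == 0   # no previous information: keep the step size
    dw <- numeric(length(w))
    delta[up]   <- pmin(delta[up] * eta_plus, delta_max)
    delta[down] <- pmax(delta[down] * eta_minus, delta_min)
    move <- up | same
    dw[move] <- -sign(g[move]) * delta[move]  # only the gradient sign is used
    dw[down] <- -dw_prev[down]                # revert the previous update
    g[down]  <- 0                             # skip adaptation next iteration
    w       <- w + dw
    g_prev  <- g
    dw_prev <- dw
  }
  w
}

# Usage: minimise f(w) = (w1 - 3)^2 + 10 * (w2 + 1)^2 via its gradient
grad_f <- function(w) c(2 * (w[1] - 3), 20 * (w[2] + 1))
w_hat <- rprop_plus(grad_f, w = c(0, 0))
print(w_hat)
```

Because only the sign of each partial derivative enters the update, the steeper second coordinate does not receive larger steps; both step sizes adapt by the same rule, which makes Rprop insensitive to the scale of the gradient.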
Value
A metalearner_deepneural object containing the predicted outcome values and the CATEs estimated by the meta-learner for each observation.
Examples
# load package and dataset
library(DeepLearningCausal)
data(exp_data)
# estimate CATEs with S Learner
set.seed(123456)
slearner_nn <- metalearner_deepneural(cov.formula = support_war ~ age + income +
                                        employed + job_loss,
                                      data = exp_data,
                                      treat.var = "strong_leader",
                                      meta.learner.type = "S.Learner",
                                      stepmax = 2e+9,
                                      nfolds = 5,
                                      algorithm = "rprop+",
                                      hidden.layer = c(1),
                                      linear.output = FALSE,
                                      binary.outcome = FALSE)
print(slearner_nn)
# estimate CATEs with T Learner
set.seed(123456)
tlearner_nn <- metalearner_deepneural(cov.formula = support_war ~ age + income +
                                        employed + job_loss,
                                      data = exp_data,
                                      treat.var = "strong_leader",
                                      meta.learner.type = "T.Learner",
                                      stepmax = 1e+9,
                                      nfolds = 5,
                                      algorithm = "rprop+",
                                      hidden.layer = c(2, 1),
                                      linear.output = FALSE,
                                      binary.outcome = FALSE)
print(tlearner_nn)