submod_train {StratifiedMedicine} | R Documentation |
Subgroup Identification: Train Model
Description
Wrapper function to train a subgroup model (submod). Outputs subgroup assignments and fitted model.
Usage
submod_train(
Y,
A,
X,
Xtest = NULL,
mu_train = NULL,
family = "gaussian",
submod = "lmtree",
hyper = NULL,
ple = "ranger",
ple.hyper = NULL,
meta = ifelse(family == "survival", "T-learner", "X-learner"),
propensity = FALSE,
pool = "no",
delta = ">0",
param = NULL,
resample = NULL,
R = 20,
resample_pool = NULL,
R_pool = 20,
stratify = ifelse(!is.null(A), "trt", "no"),
combine = "SS",
alpha_ovrl = 0.05,
alpha_s = 0.05,
verbose.resamp = FALSE,
efficient = FALSE,
...
)
Arguments
Y |
The outcome variable. Must be numeric or survival (e.g., Surv(time, cens)). |
A |
Treatment variable. The default supports binary treatment (numeric or factor). "ple_train" accommodates more than two treatment levels as well as binary treatments. |
X |
Covariate space. |
Xtest |
Test set. Default is NULL (no test predictions). Variable types should match X. |
mu_train |
Patient-level estimates in training set (see |
family |
Outcome type. Options include "gaussian" (default), "binomial", and "survival". |
submod |
Subgroup identification model function. Options include tree-methods that target the treatment by variable interaction directly ("lmtree", "glmtree", "mob_weib"), regress the CATE ("rpart_cate", "ctree_cate"), and target prognostic variables ("rpart", "ctree"). Default for family="gaussian" is "lmtree" (MOB with OLS loss). For "binomial" the default is "glmtree" (MOB with binomial loss). Default for "survival" is "lmtree" (log-rank transformation on survival outcomes and then fit MOB-OLS). "None" uses no submod. Currently only available for binary treatments or A=NULL. |
hyper |
Hyper-parameters for submod (must be list). Default is NULL. |
ple |
Base learner used to estimate patient-level quantities, such as the conditional average treatment effect (CATE), E(Y|A=1,X)-E(Y|A=0,X) = CATE(X). Default is random forest through "ranger". "None" uses no ple. See below for details on estimating the treatment contrasts. |
ple.hyper |
Hyper-parameters for the PLE function (must be list). Default is NULL. |
meta |
Using the ple model as a base learner, meta-learners can be used for estimating patient-level treatment differences. Options include "T-learner" (treatment specific models), "S-learner" (single model), and "X-learner". For family="gaussian" & "binomial", the default is "X-learner", which uses a two-stage regression approach (See Kunzel et al 2019). For "survival", the default is "T-learner". "X-learner" is currently not supported for survival outcomes. |
propensity |
Propensity score estimation, P(A=a|X). Default=FALSE, which uses the marginal estimates, P(A=a) (applicable for RCT data). If TRUE, the "ple" base learner is used to estimate P(A=a|X). |
pool |
Whether to pool the initially identified subgroups (ex: tree nodes). Default = "no". Other options include "trteff" or "trteff_boot" (check whether the naive or bootstrap treatment estimate is beyond the clinically meaningful threshold delta, ex: trteff_boot > 0), and optimal treatment regime (OTR) pooling, "otr:logistic", "otr:rf". "otr:logistic" fits weighted logistic regression with I(mu_1-mu_0>delta) as the outcome, the candidate subgroups as covariates, and weights=abs((mu_1-mu_0) - delta). "otr:rf" follows the same approach but with weighted random forest, and also includes X in the regression. Regardless of the pooling approach, the key output is "trt_assign", a data-frame with the initial subgroups and the pooled subgroups (ex: dopt=1, patient should receive A=1, vs dopt=0, patient should receive A=0). |
delta |
Threshold for defining benefitting vs non-benefitting patients. Only applicable for submod="otr", and if pooling is used (see "pool"). Default=">0". |
param |
Parameter estimation and inference function. Based on the discovered subgroups, obtain parameter estimates and corresponding variability metrics. Options include "lm" (unadjusted linear regression), "dr" (doubly-robust estimator), "gcomp" (G-computation, average the patient-level estimates), "cox" (Cox regression), and "rmst" (RMST-based estimates as in survRMST package). Default for "gaussian", "binomial" is "dr", while default for "survival" is "cox". Currently only available for binary treatments or A=NULL. |
resample |
Resampling method for resample-based treatment effect estimates and variability metrics. Options include "Bootstrap" and "CV" (cross-validation). Default=NULL (No resampling). |
R |
Number of resamples (default=NULL; R=100 for Permutation/Bootstrap and R=5 for CV). This resamples the entire PRISM procedure. |
resample_pool |
For submod only, resampling method for pooling step. Only applicable if resample_submod="Bootstrap" and/or pool="trteff_boot". |
R_pool |
Number of resamples for resample_pool. |
stratify |
Stratified resampling? Default="trt" (stratify by A). Other options include "sub" (stratify by the identified subgroups), "trt_sub" (stratify by A and the identified subgroups), and "no" (no stratification). |
combine |
Method of combining group-specific point-estimates. Options include "SS" (sample size weighting), and "maxZ" (see: Mehrotra and Marceau-West). This is used for pooling (ex: within dopt=1 groups, aggregate group-specific treatment estimates), and for calculating the overall population treatment effect estimate. |
alpha_ovrl |
Two-sided alpha level for overall population. Default=0.05 |
alpha_s |
Two-sided alpha level at subgroup level. Default=0.05 |
verbose.resamp |
Output iterations during resampling? Default=FALSE |
efficient |
If TRUE (default for PRISM), then models (filter, ple, submod) will store a reduced set of outputs for faster speed. |
... |
Any additional parameters, not currently passed through. |
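The "otr:logistic" pooling rule described under pool can be illustrated with a minimal base R sketch. This is not the package's internal code. The subgroup labels, the simulated patient-level estimates (cate, standing in for mu_1 - mu_0), and the 0.5 probability cutoff used to set dopt are all assumptions made for illustration.

```r
# Sketch of "otr:logistic"-style pooling on simulated data (illustrative only).
set.seed(1)
n <- 300

# Hypothetical initial subgroups (e.g., terminal nodes of a fitted tree)
subgrp <- factor(sample(c("3", "4", "5"), n, replace = TRUE))

# Hypothetical patient-level estimates of mu_1 - mu_0, varying by subgroup
grp_mean <- c("3" = -0.5, "4" = 0.2, "5" = 1.0)
cate <- rnorm(n, mean = grp_mean[as.character(subgrp)], sd = 0.5)

delta <- 0                           # threshold, corresponds to delta = ">0"
benefit <- as.integer(cate > delta)  # outcome: I(mu_1 - mu_0 > delta)
w <- abs(cate - delta)               # weights: abs((mu_1 - mu_0) - delta)

# Weighted logistic regression with the candidate subgroups as covariates.
# quasibinomial() gives the same point estimates as binomial() but avoids
# the non-integer weight warning.
fit <- glm(benefit ~ subgrp, family = quasibinomial(), weights = w)

# Predicted probability of benefit per subgroup; the 0.5 cutoff is an assumption
new_grps <- data.frame(subgrp = factor(levels(subgrp), levels = levels(subgrp)))
p_benefit <- predict(fit, newdata = new_grps, type = "response")
trt_assign <- data.frame(Subgrps = levels(subgrp),
                         dopt = as.integer(p_benefit > 0.5))
print(trt_assign)
```

Subgroups with dopt=1 would then be pooled into a single "should receive A=1" group, and those with dopt=0 into the complementary group.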
Details
submod_train currently fits a number of tree-based subgroup models, most of which aim to find subgroups with varying treatment effects (i.e. predictive variables). Let E(Y|A=1,X)-E(Y|A=0,X) = CATE(X) correspond to the estimated conditional average treatment effect. Current options include:
1. lmtree: Wrapper function for the function "lmtree" from the partykit package. Here, model-based partitioning (MOB) with an OLS loss function, Y~MOB_OLS(A,X), is used to identify prognostic and/or predictive variables. If the outcome Y is survival, then this outcome will first be transformed via log-rank scores (coin::logrank_trafo(Y)).
Default hyper-parameters are: hyper = list(alpha=0.05, maxdepth=4, parm=NULL, minsize=floor(dim(X)[1]*0.10)).
2. glmtree: Wrapper function for the function "glmtree" from the partykit package. Here, model-based partitioning (MOB) with GLM binomial + identity link loss function, (Y~MOB_GLM(A,X)), is used to identify prognostic and/or predictive variables.
Default hyper-parameters are: hyper = list(link="identity", alpha=0.05, maxdepth=4, parm=NULL, minsize=floor(dim(X)[1]*0.10)).
3. ctree / ctree_cate: Wrapper function for the function "ctree" from the partykit package. Here, conditional inference trees are used to identify either prognostic ("ctree"), Y~CTREE(X), or predictive variables, CATE(X) ~ CTREE(X).
Default hyper-parameters are: hyper=list(alpha=0.10, minbucket = floor(dim(X)[1]*0.10), maxdepth = 4).
4. rpart / rpart_cate: Recursive partitioning through the "rpart" R package. Here, recursive partitioning and regression trees are used to identify either prognostic ("rpart"), Y~rpart(X), or predictive variables ("rpart_cate"), CATE(X)~rpart(X).
Default hyper-parameters are: hyper=list(alpha=0.10, minbucket = floor(dim(X)[1]*0.10), maxdepth = 4).
5. mob_weib: Wrapper function for the function "mob" with a Weibull loss function using the partykit package. Here, model-based partitioning (MOB) with Weibull loss (survival), (Y~MOB_WEIB(A,X)), is used to identify prognostic and/or predictive variables.
Default hyper-parameters are: hyper = list(alpha=0.10, maxdepth=4, parm=NULL, minsize=floor(dim(X)[1]*0.10)).
6. otr: Optimal treatment regime approach using "ctree". Based on CATE estimates and clinically meaningful threshold delta (ex: >0), fit I(CATE>delta)~CTREE(X) with weights=abs(CATE-delta).
Default hyper-parameters are: hyper=list(alpha=0.10, minbucket = floor(dim(X)[1]*0.10), maxdepth = 4, delta=">0").
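The "rpart_cate" idea in item 4, CATE(X)~rpart(X), can be sketched outside the package as follows. This is an illustration under stated assumptions, not the wrapper's implementation: the data are simulated, the CATE is estimated with a simple T-learner built on lm() rather than the default "ranger" base learner, and only the minbucket and maxdepth defaults listed above are passed to rpart.

```r
# Sketch of an "rpart_cate"-style fit on simulated data (illustrative only).
library(rpart)
set.seed(42)
n <- 400

# Simulated covariates, randomized binary treatment, and continuous outcome.
# X1 is prognostic; X2 is predictive (treatment helps only when X2 > 0).
X <- data.frame(X1 = rnorm(n), X2 = rnorm(n))
A <- rbinom(n, 1, 0.5)
Y <- 1 + X$X1 + 1.5 * A * (X$X2 > 0) + rnorm(n)
dat <- cbind(Y = Y, X)

# T-learner: treatment-specific outcome models, then take the difference
m1 <- lm(Y ~ X1 + X2, data = dat[A == 1, ])
m0 <- lm(Y ~ X1 + X2, data = dat[A == 0, ])
cate_hat <- predict(m1, newdata = X) - predict(m0, newdata = X)

# Default-style hyper-parameters from the list above (minbucket, maxdepth)
hyper <- list(minbucket = floor(dim(X)[1] * 0.10), maxdepth = 4)

# Regress the estimated CATE on the covariates with a regression tree
tree <- rpart(cate_hat ~ X1 + X2,
              data = cbind(cate_hat = cate_hat, X),
              control = rpart.control(minbucket = hyper$minbucket,
                                      maxdepth = hyper$maxdepth))

# Terminal-node membership plays the role of the identified subgroups
subgrps <- tree$where
table(subgrps)
```

Because minbucket is 10% of the sample, every terminal node contains at least floor(n*0.10) patients, which caps the tree at ten subgroups.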
Value
Trained subgroup model and subgroup predictions/estimates for train/test sets.
mod - trained subgroup model
Subgrps.train - Identified subgroups (training set)
Subgrps.test - Identified subgroups (test set)
pred.train - Predictions (training set)
pred.test - Predictions (test set)
Rules - Definitions for subgroups, if provided in fitted submod output.
References
Zeileis A, Hothorn T, Hornik K (2008). Model-Based Recursive Partitioning. Journal of Computational and Graphical Statistics, 17(2), 492–514.
Seibold H, Zeileis A, Hothorn T (2016). Model-based recursive partitioning for subgroup analyses. International Journal of Biostatistics, 12, 45–63.
Hothorn T, Hornik K, Zeileis A (2006). Unbiased Recursive Partitioning: A Conditional Inference Framework. Journal of Computational and Graphical Statistics, 15(3), 651–674.
Zhao et al. (2012). Estimated individualized treatment rules using outcome weighted learning. Journal of the American Statistical Association, 107(409), 1106–1118.
Breiman L, Friedman JH, Olshen RA, Stone CJ (1984). Classification and Regression Trees. Wadsworth.
Examples
library(StratifiedMedicine)
## Continuous ##
dat_ctns = generate_subgrp_data(family="gaussian")
Y = dat_ctns$Y
X = dat_ctns$X
A = dat_ctns$A
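# Fit through the submod_train wrapper (MOB with OLS loss) #
mod1 = submod_train(Y=Y, A=A, X=X, Xtest=X, submod="lmtree")
# Subgroup sizes in the training set #
table(mod1$Subgrps.train)
# Plot the fitted tree #
plot(mod1$fit$mod)
# Treatment effect estimates, if available #
mod1$trt_eff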