PRISM: Patient Response Identifier for Stratified Medicine


PRISM algorithm. Given a data set of (Y, A, X) (outcome, treatment, covariates), PRISM identifies potential subgroups along with point estimates and variability metrics, with or without resampling (bootstrap or cross-validation based). This four-step procedure (filter, ple, submod, param) is flexible and accepts user inputs at each step.


The outcome variable. Must be numeric or survival (ex: Surv(time, cens)).


Treatment variable. (Default supports binary treatment, either numeric or factor). "ple_train" accommodates more than two treatments as well as binary treatments.


Covariate space.


Test set. Default is NULL (no test predictions). Variable types should match X.


Patient-level estimates in training set (see ple_train). Default=NULL


Outcome type. Options include "gaussian" (default), "binomial", and "survival".


Filter model to determine variables that are likely associated with the outcome and/or treatment. Outputs a potentially reduced list of variables (possibly fewer variables than X). Default is "glmnet" (elastic net). Other options include "ranger" (random forest based variable importance with p-values). See filter_train for more details. "None" uses no filter.


Base-learner used to estimate patient-level quantities, such as the conditional average treatment effect (CATE), E(Y|A=1,X)-E(Y|A=0,X) = CATE(X). Default is random forest through "ranger". "None" uses no ple. See below for details on estimating the treatment contrasts.


Subgroup identification model function. Options include tree-methods that target the treatment by variable interaction directly ("lmtree", "glmtree", "mob_weib"), regress the CATE ("rpart_cate", "ctree_cate"), and target prognostic variables ("rpart", "ctree"). Default for family="gaussian" is "lmtree" (MOB with OLS loss). For "binomial" the default is "glmtree" (MOB with binomial loss). Default for "survival" is "lmtree" (log-rank transformation on survival outcomes and then fit MOB-OLS). "None" uses no submod. Currently only available for binary treatments or A=NULL.


Parameter estimation and inference function. Based on the discovered subgroups, obtain parameter estimates and corresponding variability metrics. Options include "lm" (unadjusted linear regression), "dr" (doubly-robust estimator), "gcomp" (G-computation, average the patient-level estimates), "cox" (Cox regression), and "rmst" (RMST based estimates as in survRMST package). Default for "gaussian" and "binomial" is "dr", while default for "survival" is "cox". Currently only available for binary treatments or A=NULL.


Using the ple model as a base learner, meta-learners can be used for estimating patient-level treatment differences. Options include "T-learner" (treatment specific models), "S-learner" (single model), and "X-learner". For family="gaussian" & "binomial", the default is "X-learner", which uses a two-stage regression approach (See Kunzel et al 2019). For "survival", the default is "T-learner". "X-learner" is currently not supported for survival outcomes.
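The T-learner and S-learner described above can be sketched in base R. This is only an illustration of the meta-learner idea, not the package's implementation: it swaps the default random forest base learner for lm(), and the simulated data and all variable names are assumptions made for the example.

```r
## Minimal sketch of T- and S-learners (illustrative; lm() replaces "ranger")
set.seed(1)
n <- 500
X <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
A <- rbinom(n, size = 1, prob = 0.5)
# Simulated outcome: the true CATE is 1 + 2*x1 (assumption for this example)
Y <- X$x1 + A * (1 + 2 * X$x1) + rnorm(n)
dat <- data.frame(Y = Y, A = A, X)

# T-learner: fit one outcome model per treatment arm, then difference
fit1 <- lm(Y ~ x1 + x2, data = dat[dat$A == 1, ])
fit0 <- lm(Y ~ x1 + x2, data = dat[dat$A == 0, ])
cate_t <- predict(fit1, newdata = dat) - predict(fit0, newdata = dat)

# S-learner: fit a single model that includes A, then predict under A=1 and A=0
fit_s <- lm(Y ~ A * (x1 + x2), data = dat)
cate_s <- predict(fit_s, newdata = transform(dat, A = 1)) -
  predict(fit_s, newdata = transform(dat, A = 0))

# With a fully interacted linear model the two learners coincide;
# with flexible base learners (ex: random forest) they generally differ
summary(cate_t - cate_s)
```

The X-learner adds a second regression stage on imputed treatment effects (see Kunzel et al 2019) and is omitted here for brevity.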


Whether to pool the initial identified subgroups (ex: tree nodes). Default = "no". Other options include "trteff" or "trteff_boot" (check if the naive or bootstrap treatment estimate is beyond the clinically meaningful threshold delta, ex: trteff_boot > 0), and optimal treatment regime (OTR) pooling, "otr:logistic", "otr:rf". "otr:logistic" fits weighted logistic regression with I(mu_1-mu_0>delta) as the outcome, the candidate subgroups as covariates, and weights=abs((mu_1-mu_0) - delta). "otr:rf" follows the same approach but with weighted random forest, and also includes X in the regression. Regardless of the pooling approach, the key output is "trt_assign", a data-frame with the initial subgroups and the pooled subgroups (ex: dopt=1, patient should receive A=1, vs dopt=0, patient should receive A=0).
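The "otr:logistic" description above can be illustrated with a weighted logistic regression in base R. This is a sketch under stated assumptions, not the package's code: the simulated mu_0 and mu_1, the node labels, the use of quasibinomial() (to allow non-integer weights without warnings), and the 0.5 probability cutoff for dopt are all choices made for the example.

```r
## Sketch of OTR-style pooling via weighted logistic regression (illustrative)
set.seed(2)
n <- 400
# Hypothetical candidate subgroups (ex: terminal nodes of a tree)
subgrp <- factor(sample(c("node3", "node4", "node5"), n, replace = TRUE))
shift <- c(node3 = -0.5, node4 = 0.2, node5 = 1.0)
# Hypothetical patient-level estimates of E(Y|A=0,X) and E(Y|A=1,X)
mu_0 <- rnorm(n)
mu_1 <- mu_0 + shift[as.character(subgrp)] + rnorm(n, sd = 0.5)

delta <- 0
trt_diff <- mu_1 - mu_0
benefit <- as.numeric(trt_diff > delta)   # I(mu_1 - mu_0 > delta)
w <- abs(trt_diff - delta)                # weights = abs((mu_1 - mu_0) - delta)

# Weighted logistic regression with the candidate subgroups as covariates
fit <- glm(benefit ~ subgrp, family = quasibinomial(), weights = w)
p_hat <- predict(fit, type = "response")

# Assumed decision rule for this sketch: dopt=1 if fitted probability > 0.5
dopt <- as.numeric(p_hat > 0.5)
trt_assign <- unique(data.frame(Subgrps = subgrp, dopt = dopt))
trt_assign
```

Weighting by the distance from delta lets patients with estimates far from the threshold drive the pooled assignment, while near-threshold patients contribute little.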


Threshold for defining benefitting vs non-benefitting patients. Only applicable for submod="otr", and if pooling is used (see "pool"). Default=">0".


Propensity score estimation, P(A=a|X). Default=FALSE, which uses the marginal estimates, P(A=a) (applicable for RCT data). If TRUE, the "ple" base learner is used to estimate P(A=a|X).
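To show how the propensity choice enters a doubly-robust estimate, the sketch below compares marginal and model-based propensity scores inside a standard augmented inverse probability weighting (AIPW) estimator. It is a generic illustration in base R, not the package's "dr" code: the simulated data, the use of lm()/glm() in place of the "ple" base learner, and the helper name aipw are assumptions.

```r
## Sketch: marginal vs model-based propensity in an AIPW estimator (illustrative)
set.seed(3)
n <- 1000
x <- rnorm(n)
A <- rbinom(n, size = 1, prob = 0.5)      # randomized treatment (RCT-like)
Y <- 1 + x + 0.5 * A + rnorm(n)           # true average treatment effect = 0.5
dat <- data.frame(Y, A, x)

# Marginal propensity, P(A=1): analogous to propensity=FALSE
pi_marg <- rep(mean(A), n)
# Model-based propensity, P(A=1|X): analogous to propensity=TRUE
pi_mod <- predict(glm(A ~ x, family = binomial(), data = dat), type = "response")

# Arm-specific outcome models, predicted for everyone
mu1 <- predict(lm(Y ~ x, data = dat[dat$A == 1, ]), newdata = dat)
mu0 <- predict(lm(Y ~ x, data = dat[dat$A == 0, ]), newdata = dat)

# Hypothetical helper: AIPW point estimate and standard error
aipw <- function(pi_hat) {
  eif <- (mu1 - mu0) + A * (Y - mu1) / pi_hat -
    (1 - A) * (Y - mu0) / (1 - pi_hat)
  c(est = mean(eif), SE = sd(eif) / sqrt(n))
}

rbind(marginal = aipw(pi_marg), model_based = aipw(pi_mod))
```

Under randomization both versions target the same quantity; modelling P(A=a|X) matters mainly for observational data where treatment depends on X.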


Method of combining group-specific point-estimates. Options include "SS" (sample size weighting), and "maxZ" (see: Mehrotra and Marceau-West). This is used for pooling (ex: within dopt=1 groups, aggregate group-specific treatment estimates), and for calculating the overall population treatment effect estimate.
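As an illustration of sample size weighting, the sketch below pools subgroup-specific estimates with weights proportional to subgroup size. The numbers are made up, and the variance formula assumes independent subgroup estimates; the package's exact "SS" computation may differ.

```r
## Sketch of sample-size ("SS") weighting of subgroup estimates (illustrative)
# Hypothetical subgroup-level results
param_dat <- data.frame(Subgrps = c("3", "4", "5"),
                        N   = c(100, 250, 150),
                        est = c(-0.2, 0.3, 0.9),
                        SE  = c(0.20, 0.12, 0.16))

# Weights proportional to subgroup sample size
w <- param_dat$N / sum(param_dat$N)

# Pooled point estimate and standard error (assumes independent subgroups)
est_pool <- sum(w * param_dat$est)
SE_pool <- sqrt(sum(w^2 * param_dat$SE^2))

# Two-sided (1 - alpha) normal-approximation confidence interval
alpha <- 0.05
ci <- est_pool + c(-1, 1) * qnorm(1 - alpha / 2) * SE_pool
c(est = est_pool, SE = SE_pool, LCL = ci[1], UCL = ci[2])
```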


Two-sided alpha level for overall population. Default=0.05


Two-sided alpha level at subgroup level. Default=0.05


Hyper-parameters for the filter function (must be list). Default is NULL.


Hyper-parameters for the PLE function (must be list). Default is NULL.


Hyper-parameters for the submod function (must be list). Default is NULL.


Resampling method for resample-based treatment effect estimates and variability metrics. Options include "Bootstrap" and "CV" (cross-validation). Default=NULL (No resampling).


Stratified resampling? Default="trt" (stratify by A). Other options include "sub" (stratify by the identified subgroups), "trt_sub" (stratify by A and the identified subgroups), and "no" (no stratification).
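Stratifying by treatment means each bootstrap resample keeps the observed number of patients per arm. A minimal base R sketch of that idea follows; the variable names are hypothetical and this is not the package's internal resampling code.

```r
## Sketch of a bootstrap resample stratified by treatment (illustrative)
set.seed(4)
A <- rep(c(0, 1), times = c(60, 40))   # hypothetical treatment vector

# Split row indices by arm, then resample with replacement within each arm
strata <- split(seq_along(A), A)
boot_idx <- unlist(lapply(strata, function(idx) {
  sample(idx, size = length(idx), replace = TRUE)
}), use.names = FALSE)

# Arm sizes are preserved in the resample
table(A[boot_idx])
```

Stratifying by the identified subgroups ("sub") or by both ("trt_sub") would follow the same pattern with a different splitting factor.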


Number of resamples (default=NULL; R=100 for Permutation/Bootstrap and R=5 for CV). This resamples the entire PRISM procedure.


For submod only, resampling method for treatment effect estimates. Options include "Bootstrap" or NULL (no resampling).


Number of resamples for resample_submod


For submod only, resampling method for pooling step. Only applicable if resample_submod="Bootstrap" and/or pool="trteff_boot".


Number of resamples for resample_pool


Bootstrap calibration for nominal alpha (Loh et al 2016). Default=FALSE. For TRUE, outputs the calibrated alpha level and calibrated CIs for the overall population and subgroups. Not applicable for permutation or CV resampling.


Grid of alpha values for calibration. Default=NULL, which uses seq(alpha/1000,alpha,by=0.005) for alpha_ovrl/alpha_s.
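For reference, the default grid expression above can be evaluated directly. The snippet assumes alpha=0.05 and simply evaluates the seq() call quoted in the description.

```r
## Evaluate the default calibration grid for alpha = 0.05
alpha <- 0.05
alpha_grid <- seq(alpha / 1000, alpha, by = 0.005)
alpha_grid
length(alpha_grid)
```

Because the grid starts at alpha/1000 and steps by 0.005, its largest value is slightly below alpha itself.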


Filter function during re-sampling. Default=NULL (uses "filter"). If "None", the "filter" model is not trained in each resample; instead, the filtered variables from the observed data "filter" step are used (less computationally expensive).


Ple function during re-sampling. Default=NULL (uses "ple"). If "None", the "ple" model is not trained in each resample; instead, the original model estimates are resampled (less computationally expensive).


Detail progress of PRISM? Default=TRUE


Output iterations during resampling? Default=FALSE


Seed for PRISM run (Default=777)


If TRUE (default for PRISM), then models (filter, ple, submod) will store a reduced set of outputs for faster speed.


PRISM is a general framework with five key steps:

0. Estimand: Determine the question of interest (ex: mean treatment difference)

1. Filter (filter): Reduce covariate space by removing noise covariates. Options include elastic net ("glmnet") and random forest variable importance ("ranger").

2. Patient-Level Estimates (ple): Estimate counterfactual patient-level quantities, for example, the conditional average treatment effect (CATE), E(Y|A=1,X)-E(Y|A=0,X). This calls the "ple_train" function, and follows the framework of Kunzel et al 2019. Base-learners include random forest ("ranger"), BART ("bart"), elastic net ("glmnet"), and linear models (LM, GLM, or Cox regression). Meta-learners include the "S-Learner" (single model), "T-learner" (treatment specific models), and "X-learner" (2-stage approach).

3. Subgroup Model (submod): Currently uses tree-based methods to identify predictive and/or prognostic subgroups. Options include MOB OLS ("lmtree"), MOB GLM ("glmtree"), MOB Weibull ("mob_weib"), conditional inference trees ("ctree", Y~ctree(X); "ctree_cate", CATE~ctree(X)), and recursive partitioning and regression trees ("rpart", Y~rpart(X); "rpart_cate", CATE~rpart(X)), and optimal treatment regimes ("otr").

4. Treatment Effect Estimation (param): For the overall population and the discovered subgroups (if any), obtain treatment effect point-estimates and variability metrics. Options include: cox regression ("cox"), double robust estimator ("dr"), linear regression ("lm"), average of patient-level estimates ("gcomp"), and restricted mean survival time ("rmst").

Steps 1-4 also support user-specified models. If treatment is provided (A!=NULL), the default settings are as follows:

Y is continuous (family="gaussian"): Elastic Net Filter ==> X-learner with random forest ==> MOB (OLS) ==> Double Robust estimator

Y is binary (family="binomial"): Elastic Net Filter ==> X-learner with random forest ==> MOB (GLM) ==> Double Robust estimator

Y is right-censored (family="survival"): Elastic Net Filter ==> T-learner with random forest ==> MOB (Weibull) ==> Cox regression

If treatment is not provided (A=NULL), the default settings are as follows:

Y is continuous (family="gaussian"): Elastic Net Filter ==> Random Forest ==> ctree ==> linear regression

Y is binary (family="binomial"): Elastic Net Filter ==> Random Forest ==> ctree ==> linear regression

Y is right-censored (family="survival"): Elastic Net Filter ==> Survival Random Forest ==> ctree ==> RMST


Trained PRISM object. Includes filter, ple, submod, and param outputs.


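## Load library ##
library(StratifiedMedicine)
library(ggplot2)   # geom_vline() in the plots below
library(survival)  # Surv() in the survival example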

## Examples: Continuous Outcome ##

dat_ctns = generate_subgrp_data(family="gaussian")
Y = dat_ctns$Y
X = dat_ctns$X
A = dat_ctns$A

# Run Default: glmnet, ranger (X-learner), lmtree, dr #
res0 = PRISM(Y=Y, A=A, X=X)

res1 = PRISM(Y=Y, A=A, X=X, filter="None")

# Search for Prognostic Only (omit A from function) #

res3 = PRISM(Y=Y, X=X)

## With bootstrap (default filter) ##

res_boot = PRISM(Y=Y, A=A, X=X, resample = "Bootstrap", R=50, verbose.resamp = TRUE)
# Plot of distributions and P(est>0) #
plot(res_boot, type="resample", estimand = "E(Y|A=1)-E(Y|A=0)") +
  geom_vline(xintercept = 0)
aggregate(I(est>0)~Subgrps, data=res_boot$resamp_dist, FUN="mean")

## Examples: Binary Outcome ##

dat_bin = generate_subgrp_data(family="binomial")
Y = dat_bin$Y
X = dat_bin$X
A = dat_bin$A

# Run Default: glmnet, ranger, glmtree, dr #
res0 = PRISM(Y=Y, A=A, X=X)


## Examples: Survival Data ##

require(TH.data); require(coin)
data("GBSG2", package = "TH.data")
surv.dat = GBSG2
# Design Matrices #
Y = with(surv.dat, Surv(time, cens))
X = surv.dat[,!(colnames(surv.dat) %in% c("time", "cens")) ]
# Simulated (random) binary treatment #
A = rbinom( n = dim(X)[1], size=1, prob=0.5  )

# PRISM: glmnet ==> Random Forest to estimate Treatment-Specific RMST
# ==> MOB (Weibull) ==> Cox for HRs #
res_weib = PRISM(Y=Y, A=A, X=X)
plot(res_weib, type="PLE:waterfall")

# PRISM: glmnet ==> Random Forest to estimate Treatment-Specific RMST
# RPART_CATE: Regress RMST on RPART for subgroups #
res_cate = PRISM(Y=Y, A=A, X=X, submod="rpart_cate")

# PRISM: ENET ==> CTREE ==> Cox; with bootstrap #
res_ctree1 = PRISM(Y=Y, A=A, X=X, ple="None", submod = "ctree",
                   resample="Bootstrap", R=50, verbose.resamp = TRUE)
plot(res_ctree1, type="resample", estimand="HR(A=1 vs A=0)") +
  geom_vline(xintercept = 1)
aggregate(I(est<0)~Subgrps, data=res_ctree1$resamp_dist, FUN="mean")

