mc.crisk2.pwbart {BART}    R Documentation

Predicting new observations with a previously fitted BART model

Description

BART is a Bayesian “sum-of-trees” model.
For a numeric response $y$, we have $y = f(x) + \epsilon$, where $\epsilon \sim N(0,\sigma^2)$.

$f$ is the sum of many tree models. The goal is to have very flexible inference for the unknown function $f$.

In the spirit of “ensemble models”, each tree is constrained by a prior to be a weak learner so that it contributes a small amount to the overall fit.
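The sum-of-trees idea can be illustrated with a toy sketch in base R. This only illustrates the model form $y = f(x) + \epsilon$ with $f$ a sum of weak trees; it is not BART's prior or its MCMC sampler. The number of trees m, the use of depth-1 trees (stumps), and the names cuts, left, right and f are hypothetical choices made for this illustration.

```r
## Toy illustration only (not BART's sampler): f(x) is a sum of m
## depth-1 trees ("stumps"), each contributing a small step to the fit.
set.seed(1)
m <- 200                      # hypothetical number of trees
cuts  <- runif(m)             # split point of each stump
left  <- rnorm(m, sd = 0.05)  # small leaf value used when x <= cut
right <- rnorm(m, sd = 0.05)  # small leaf value used when x >  cut

## Evaluate every stump at each x and sum over the m trees
f <- function(x) {
  vapply(x, function(xi) sum(ifelse(xi <= cuts, left, right)), numeric(1))
}

x <- seq(0, 1, length.out = 5)
y <- f(x) + rnorm(length(x), sd = 0.1)  # y = f(x) + epsilon
```

Because every leaf value is small, no single stump dominates; the flexibility of $f$ comes from adding many weak learners together.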

Usage

mc.crisk2.pwbart( x.test, x.test2,
                 treedraws, treedraws2,
                 binaryOffset=0, binaryOffset2=0,
                 mc.cores=2L, type='pbart',
                 transposed=FALSE, nice=19L
               )

Arguments

x.test

Matrix of covariates to predict $y$ for cause 1.

x.test2

Matrix of covariates to predict $y$ for cause 2.

treedraws

The $treedraws component of the fitted model for cause 1.

treedraws2

The tree draws for cause 2 (the $treedraws2 component of the fitted model in the Examples).

binaryOffset

Mean to add to the $y$ prediction for cause 1.

binaryOffset2

Mean to add to the $y$ prediction for cause 2.

mc.cores

Number of threads to utilize.

type

Whether to employ the Albert-Chib probit model, 'pbart', or the Holmes-Held logistic model, 'lbart'.

transposed

When running pwbart or mc.pwbart in parallel, it is more memory-efficient to transpose x.test prior to calling the internal versions of these functions.

nice

Set the job niceness. The default niceness is 19: niceness ranges from 0 (highest priority) to 19 (lowest priority).
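How binaryOffset and type act on the predictions can be sketched in base R. This is a sketch under assumptions, not the package internals: it assumes binaryOffset is a single scalar added to every posterior draw of the tree-sum prediction, and that type selects the probit link (pnorm) for 'pbart' or the logistic link (plogis) for 'lbart' when latent predictions are converted to probabilities. The names tree.sum and to.prob and the offset value are hypothetical.

```r
## Hypothetical stand-in for tree-sum draws: ndpost rows, one column
## per test observation
set.seed(2)
ndpost <- 4
n.test <- 3
tree.sum <- matrix(rnorm(ndpost * n.test, sd = 0.5), nrow = ndpost)

binaryOffset <- qnorm(0.2)  # hypothetical offset on the probit scale

## Assumed: the same scalar offset is added to every draw and column
yhat <- tree.sum + binaryOffset

## Hypothetical helper: probit link for 'pbart', logistic for 'lbart'
to.prob <- function(yhat, type = c("pbart", "lbart")) {
  type <- match.arg(type)
  if (type == "pbart") pnorm(yhat) else plogis(yhat)
}

prob.p <- to.prob(yhat, "pbart")  # matrix of probabilities, same shape
prob.l <- to.prob(yhat, "lbart")
```

Both links map a latent value of 0 to a probability of 0.5; they differ in how quickly probabilities approach 0 and 1 in the tails.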

Details

BART is a Bayesian MCMC method. At each MCMC iteration, we produce a draw from the joint posterior $(f,\sigma) \mid (x,y)$ in the numeric $y$ case and from the posterior of $f$ alone in the binary $y$ case.

Thus, unlike a lot of other modelling methods in R, we do not produce a single model object from which fits and summaries may be extracted. The output consists of values $f^*(x)$ (and $\sigma^*$ in the numeric case), where $*$ denotes a particular draw. The $x$ is either a row from the training data (x.train) or the test data (x.test).
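Because the output is a collection of draws rather than a single fit, summaries are computed from the draws directly. A minimal sketch, assuming a hypothetical stand-in matrix yhat.test with kept draws in rows and test observations in columns:

```r
set.seed(3)
ndpost <- 1000
## Hypothetical stand-in draws: column 1 centred at 0, column 2 at 1
yhat.test <- matrix(rnorm(ndpost * 2, mean = c(0, 1)), nrow = ndpost,
                    byrow = TRUE)

## Posterior mean of f(x) for each test row (mean over the draws)
yhat.test.mean <- colMeans(yhat.test)

## Equal-tailed 95% credible intervals, one column per test row
ci <- apply(yhat.test, 2, quantile, probs = c(0.025, 0.975))
```

Any other posterior summary (medians, probabilities of exceeding a threshold, and so on) can be computed the same way, column by column.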

Value

Returns an object of type crisk2bart which is essentially a list with components:

yhat.test

A matrix with ndpost rows and nrow(x.test) columns. Each row corresponds to a draw $f^*$ from the posterior of $f$ and each column corresponds to a row of x.test. The $(i,j)$ value is $f^*(x)$ for the $i^{th}$ kept draw of $f$ and the $j^{th}$ row of x.test.
Burn-in is dropped.

surv.test

Test data fits for the survival probability.

surv.test.mean

Mean of surv.test over the posterior samples.

prob.test

The probability of suffering cause 1, which is occasionally useful, e.g., in calculating the concordance.

prob.test2

The probability of suffering cause 2, which is occasionally useful, e.g., in calculating the concordance.

cif.test

The cumulative incidence function of cause 1, $F_1(t, x)$, where $x$ ranges over the rows of the test data.

cif.test2

The cumulative incidence function of cause 2, $F_2(t, x)$, where $x$ ranges over the rows of the test data.

yhat.test.mean

Test data fits: the column means of yhat.test, i.e., the posterior mean for each row of x.test.

cif.test.mean

Column means of cif.test for cause 1.

cif.test2.mean

Column means of cif.test2 for cause 2.
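The survival and cumulative incidence components can be related through a standard discrete-time competing-risks identity. The sketch below assumes, as a hypothesis rather than a description of the package source, that at each distinct event time $t_k$ the cause 1 model supplies $p_1(t_k)$, the probability of cause 1 given the subject is still at risk, and the cause 2 model supplies $p_2(t_k)$, the probability of cause 2 given the subject is at risk and did not suffer cause 1 at $t_k$. The vectors p1 and p2 are made-up values for one test subject and one posterior draw.

```r
## Hypothetical per-interval probabilities at K = 4 distinct event times
p1 <- c(0.10, 0.20, 0.10, 0.30)  # P(cause 1 at t_k | still at risk)
p2 <- c(0.05, 0.10, 0.20, 0.10)  # P(cause 2 at t_k | at risk, no cause 1)

## Event-free survival: S(t_k) = prod_{l <= k} (1 - p1_l) (1 - p2_l)
surv <- cumprod((1 - p1) * (1 - p2))
surv.prev <- c(1, head(surv, -1))  # S(t_{k-1}), with S(t_0) = 1

## F_1(t_k) = sum_{l <= k} S(t_{l-1}) p1_l
cif1 <- cumsum(surv.prev * p1)
## F_2(t_k) = sum_{l <= k} S(t_{l-1}) (1 - p1_l) p2_l
cif2 <- cumsum(surv.prev * (1 - p1) * p2)
```

Under these assumptions surv + cif1 + cif2 equals 1 at every time point, since each subject is either still event-free or has experienced exactly one of the two causes.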

See Also

pwbart, crisk2.bart, mc.crisk2.bart

Examples


data(transplant)

delta <- (as.numeric(transplant$event)-1)
## recode so that delta=1 is cause of interest; delta=2 otherwise
delta[delta==1] <- 4
delta[delta==2] <- 1
delta[delta>1] <- 2
table(delta, transplant$event)

times <- pmax(1, ceiling(transplant$futime/7)) ## weeks
##times <- pmax(1, ceiling(transplant$futime/30.5)) ## months
table(times)

typeO <- 1*(transplant$abo=='O')
typeA <- 1*(transplant$abo=='A')
typeB <- 1*(transplant$abo=='B')
typeAB <- 1*(transplant$abo=='AB')
table(typeA, typeO)

x.train <- cbind(typeO, typeA, typeB, typeAB)

x.test <- cbind(1, 0, 0, 0)
dimnames(x.test)[[2]] <- dimnames(x.train)[[2]]

## parallel::mcparallel/mccollect do not exist on windows
if(.Platform$OS.type=='unix') {
    ## test BART with token run to ensure installation works
    post <- mc.crisk2.bart(x.train=x.train, times=times, delta=delta,
                           seed=99, mc.cores=2, nskip=5, ndpost=5,
                           keepevery=1)

    pre <- surv.pre.bart(x.train=x.train, x.test=x.test,
                         times=times, delta=delta)

    K <- post$K

    pred <- mc.crisk2.pwbart(pre$tx.test, pre$tx.test,
                             post$treedraws, post$treedraws2,
                             post$binaryOffset, post$binaryOffset2)
}

## Not run: 

## run one long MCMC chain in one process
## set.seed(99)
## post <- crisk2.bart(x.train=x.train, times=times, delta=delta, x.test=x.test)

## in the interest of time, consider speeding it up by parallel processing
## run "mc.cores" number of shorter MCMC chains in parallel processes
post <- mc.crisk2.bart(x.train=x.train,
                       times=times, delta=delta,
                       x.test=x.test, seed=99, mc.cores=8)

check <- mc.crisk2.pwbart(post$tx.test, post$tx.test,
                          post$treedraws, post$treedraws2,
                          post$binaryOffset,
                          post$binaryOffset2, mc.cores=8)
## check <- predict(post, newdata=post$tx.test, newdata2=post$tx.test2,
##                  mc.cores=8)

print(c(post$surv.test.mean[1], check$surv.test.mean[1],
        post$surv.test.mean[1]-check$surv.test.mean[1]), digits=22)

print(all(round(post$surv.test.mean, digits=9)==
    round(check$surv.test.mean, digits=9)))

print(c(post$cif.test.mean[1], check$cif.test.mean[1],
        post$cif.test.mean[1]-check$cif.test.mean[1]), digits=22)

print(all(round(post$cif.test.mean, digits=9)==
    round(check$cif.test.mean, digits=9)))

print(c(post$cif.test2.mean[1], check$cif.test2.mean[1],
        post$cif.test2.mean[1]-check$cif.test2.mean[1]), digits=22)

print(all(round(post$cif.test2.mean, digits=9)==
    round(check$cif.test2.mean, digits=9)))



## End(Not run)

[Package BART version 2.9.9 Index]