predict.wbart {BART} | R Documentation |
Predicting new observations with a previously fitted BART model
Description
BART is a Bayesian “sum-of-trees” model.
For a numeric response y, we have y = f(x) + \epsilon, where \epsilon \sim N(0,\sigma^2).
f is the sum of many tree models.
The goal is to have very flexible inference for the unknown function f.
In the spirit of “ensemble models”, each tree is constrained by a prior to be a weak learner so that it contributes a small amount to the overall fit.
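The model described above can be written out as a worked equation. This is a sketch in the notation commonly used for BART, not notation taken from this page: the number of trees m, the tree structures T_j, the leaf parameters M_j, and the single-tree function g are assumed symbols.

```latex
\[
  y = f(x) + \epsilon, \qquad \epsilon \sim N(0, \sigma^2), \qquad
  f(x) = \sum_{j=1}^{m} g(x;\, T_j, M_j)
\]
```

Here each g(x; T_j, M_j) is the prediction of one tree, and the prior keeps each such term small relative to the overall fit.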
Usage
## S3 method for class 'wbart'
predict(object, newdata, mc.cores=1, openmp=(mc.cores.openmp()>0), ...)
Arguments
object |
Object of class 'wbart' returned by a previous BART fit. |
newdata |
Matrix of covariates at which to predict the response y. |
mc.cores |
Number of threads to utilize. |
openmp |
Logical value dictating whether OpenMP is used for parallel
processing. This depends on whether OpenMP is available
on your system, which, by default, is verified with mc.cores.openmp. |
... |
Other arguments which will be passed on to |
Details
BART is a Bayesian MCMC method.
At each MCMC iteration, we produce a draw from the joint posterior (f,\sigma) | (x,y) in the numeric y case and just f in the binary y case.
Thus, unlike a lot of other modelling methods in R, we do not produce a single model object from which fits and summaries may be extracted. The output consists of values f^*(x) (and \sigma^* in the numeric case) where * denotes a particular draw.
The x is either a row from the training data (x.train) or the test data (x.test).
Value
Returns a matrix of predictions corresponding to newdata.
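Because the result is a set of posterior draws rather than a single fit, it usually has to be summarized before use. The sketch below shows one way to do that in base R. It assumes the layout implied by the call apply(pred, 2, mean) in the Examples: one row per posterior draw and one column per row of newdata. The matrix pred here is a simulated stand-in, not output from predict, and the names ndraws, nnew and truth are hypothetical.

```r
## Simulated stand-in for a matrix of posterior draws (NOT real BART output).
## Assumed layout: rows = posterior draws, columns = rows of newdata.
set.seed(1)
ndraws = 1000                          # hypothetical number of posterior draws
nnew = 5                               # hypothetical number of rows in newdata
truth = seq(2, 10, length.out = nnew)  # hypothetical values of f at the new points
pred = matrix(rnorm(ndraws * nnew, mean = rep(truth, each = ndraws), sd = 0.5),
              nrow = ndraws, ncol = nnew)

## Posterior mean of f(x) for each new observation: average over draws (rows)
post.mean = colMeans(pred)

## Pointwise 95% posterior intervals: quantiles over draws, one column per observation
post.ci = apply(pred, 2, quantile, probs = c(0.025, 0.975))

## Collect the summaries, one row per new observation
pred.summary = data.frame(mean = post.mean,
                          lower = post.ci[1, ],
                          upper = post.ci[2, ])
print(pred.summary)
```

With a real fit, the same colMeans and quantile calls would be applied to the matrix returned by predict, provided that matrix has the assumed orientation.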
See Also
wbart, mc.wbart, pwbart, mc.pwbart, mc.cores.openmp
Examples
##simulate data (example from Friedman MARS paper)
f = function(x){
  10*sin(pi*x[,1]*x[,2]) + 20*(x[,3]-.5)^2 + 10*x[,4] + 5*x[,5]
}
sigma = 1.0 #y = f(x) + sigma*z , z~N(0,1)
n = 100 #number of observations
set.seed(99)
x=matrix(runif(n*10),n,10) #10 variables, only first 5 matter
y=f(x)+sigma*rnorm(n) #add the N(0,sigma^2) noise described above
##test BART with token run to ensure installation works
set.seed(99)
post = wbart(x,y,nskip=5,ndpost=5)
x.test = matrix(runif(500*10),500,10)
## Not run:
##run BART
set.seed(99)
post = wbart(x,y)
x.test = matrix(runif(500*10),500,10)
pred = predict(post, x.test, mu=mean(y))
plot(apply(pred, 2, mean), f(x.test))
## End(Not run)