softbart {SoftBart}    R Documentation

Fits the SoftBart model

Description

Runs the Markov chain for the semiparametric Gaussian model

Y = r(X) + \epsilon

and collects the output, where r(x) is modeled using a soft BART model.
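
Written out per observation, the model above can be sketched as follows. The i.i.d. normal error term with common standard deviation \sigma is an assumption read from the phrase "Gaussian model" and from the example code, which generates noise as sigma * rnorm(n); the index i and sample size n are notation introduced here for illustration.

```latex
Y_i = r(X_i) + \epsilon_i, \qquad
\epsilon_i \overset{\text{iid}}{\sim} \mathcal{N}(0, \sigma^2), \qquad
i = 1, \dots, n
```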

Usage

softbart(X, Y, X_test, hypers = NULL, opts = Opts(), verbose = TRUE)

Arguments

X

A matrix of training data covariates.

Y

A vector of training data responses.

X_test

A matrix of test data covariates.

hypers

A list of hyperparameter values obtained from the Hypers function.

opts

A list of MCMC chain settings obtained from the Opts function.

verbose

If TRUE, progress of the chain will be printed to the console.
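
The shapes that the arguments above suggest can be checked before fitting. The following is a minimal base R sketch, not part of SoftBart: check_inputs is a hypothetical helper, and the requirements that X has one row per element of Y and that X_test has the same number of columns as X are assumptions inferred from the argument descriptions rather than documented behavior.

```r
# Hypothetical helper (not a SoftBart function): verifies that the inputs
# have mutually consistent shapes before they are passed to softbart().
check_inputs <- function(X, Y, X_test) {
  stopifnot(
    is.matrix(X), is.matrix(X_test),  # covariates are supplied as matrices
    is.numeric(Y),                    # responses are a numeric vector
    nrow(X) == length(Y),             # assumed: one response per training row
    ncol(X) == ncol(X_test)           # assumed: same predictors in train and test
  )
  invisible(TRUE)
}

# Small illustrative inputs: 20 training rows, 5 test rows, 3 predictors
set.seed(1)
X <- matrix(runif(20 * 3), nrow = 20)
Y <- rnorm(20)
X_test <- matrix(runif(5 * 3), nrow = 5)

check_inputs(X, Y, X_test)
```

If a check fails, stopifnot() raises an error naming the violated condition, which can be easier to diagnose than a failure inside the sampler.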

Value

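Returns a list of MCMC output. The components used in the Examples section include y_hat_train_mean and y_hat_test_mean, the fitted mean values for the training and test data.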

Examples


## NOTE: SET NUMBER OF BURN IN AND SAMPLE ITERATIONS HIGHER IN PRACTICE

num_burn <- 10 ## Should be ~ 5000
num_save <- 10 ## Should be ~ 5000

set.seed(1234)
f_fried <- function(x) 10 * sin(pi * x[,1] * x[,2]) + 20 * (x[,3] - 0.5)^2 + 
  10 * x[,4] + 5 * x[,5]

gen_data <- function(n_train, n_test, P, sigma) {
  X <- matrix(runif(n_train * P), nrow = n_train)
  mu <- f_fried(X)
  X_test <- matrix(runif(n_test * P), nrow = n_test)
  mu_test <- f_fried(X_test)
  Y <- mu + sigma * rnorm(n_train)
  Y_test <- mu_test + sigma * rnorm(n_test)
  
  return(list(X = X, Y = Y, mu = mu, X_test = X_test, Y_test = Y_test, mu_test = mu_test))
}

## Simulate dataset
sim_data <- gen_data(250, 100, 1000, 1)

## Fit the model
fit <- softbart(X = sim_data$X, Y = sim_data$Y, X_test = sim_data$X_test, 
                hypers = Hypers(sim_data$X, sim_data$Y, num_tree = 50, temperature = 1),
                opts = Opts(num_burn = num_burn, num_save = num_save, update_tau = TRUE))

## Plot the fit (note: interval estimates are not prediction intervals, 
## so they do not cover the predictions at the nominal rate)
plot(fit)

## Look at posterior model inclusion probabilities for each predictor. 

post_probs <- posterior_probs(fit)[["post_probs"]]
plot(post_probs, 
     col = ifelse(post_probs > 0.5, scales::muted("blue"), scales::muted("green")), 
     pch = 20)


## Compare the fitted means with the true mean function on the test and training sets
rmse(fit$y_hat_test_mean, sim_data$mu_test)
rmse(fit$y_hat_train_mean, sim_data$mu)
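
The comparison above relies on a root mean squared error. As a rough illustration of that quantity, here is a base R sketch; rmse_sketch is a hypothetical stand-in that follows the usual definition, and it is only assumed, not confirmed by this page, that the package's rmse function computes the same thing.

```r
# Hypothetical stand-in using the usual RMSE definition:
# the square root of the mean squared difference between two vectors.
rmse_sketch <- function(x, y) {
  stopifnot(length(x) == length(y))  # both vectors must align elementwise
  sqrt(mean((x - y)^2))
}

# Toy values: only the third element differs, by 2
est   <- c(1, 2, 3)
truth <- c(1, 2, 5)
rmse_sketch(est, truth)
```

Comparing the fitted means with mu and mu_test, rather than with Y and Y_test, measures how well the regression function is recovered instead of how well noisy responses are predicted.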


[Package SoftBart version 1.0.1 Index]