gsoftbart_regression {SoftBart}    R Documentation

General SoftBart Regression

Description

Fits the general (Soft) BART (GBART) model, which combines the BART model with a linear predictor. That is, it fits the semiparametric Gaussian regression model

Y = r(X) + Z^\top \beta + \epsilon

where the function r(x) is modeled using a BART ensemble.
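
To make the roles of r(X) and Z^\top \beta concrete, here is a minimal simulation sketch (not part of the package); the function r and all variable names below are illustrative only.

## Sketch only: simulate data from Y = r(X) + Z'beta + epsilon
set.seed(1)
n    <- 200
X    <- matrix(runif(n * 2), ncol = 2)   ## predictors entering nonlinearly
Z    <- matrix(rnorm(n * 2), ncol = 2)   ## predictors entering linearly
beta <- c(2, -1)                         ## linear coefficients
r    <- function(x) sin(2 * pi * x[, 1]) + x[, 2]^2   ## hypothetical r(x)
Y    <- r(X) + drop(Z %*% beta) + rnorm(n)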

Usage

gsoftbart_regression(
  formula,
  linear_formula,
  data,
  test_data,
  num_tree = 20,
  k = 2,
  hypers = NULL,
  opts = NULL,
  remove_intercept = TRUE,
  verbose = TRUE,
  warn = TRUE
)

Arguments

formula

A model formula with a numeric variable on the left-hand side and the nonlinear predictors on the right-hand side.

linear_formula

A model formula with the linear predictors on the right-hand side (the left-hand side is not used); see the sketch after this argument list.

data

A data frame consisting of the training data.

test_data

A data frame consisting of the testing data.

num_tree

The number of trees used in the ensemble.

k

Determines the standard deviation of the leaf node parameters, which is given by 3 / (k * sqrt(num_tree)).

hypers

A list of hyperparameters constructed with the Hypers() function (its num_tree, k, and sigma_mu settings are overridden by the corresponding arguments of gsoftbart_regression).

opts

A list of options for running the chain constructed with the Opts() function (its update_sigma setting is overridden by gsoftbart_regression).

remove_intercept

If TRUE then any intercept term in the linear formula will be removed, with the overall location of the outcome captured by the nonparametric function.

verbose

If TRUE, progress of the chain will be printed to the console.

warn

If TRUE, remind the user that the linear predictors should generally not also be included in the formula for the nonlinear part.
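
As a hedged illustration of how the two formulas divide the predictors (the variables y, x1, x2, z1, z2 and the data frames df and df_test are hypothetical), a call might look like the following; the final comment notes the leaf standard deviation implied by the defaults.

## Sketch only: y, x1, x2 enter the BART ensemble; z1, z2 enter linearly.
## df and df_test are assumed data frames containing all of these columns.
fit <- gsoftbart_regression(
  formula        = y ~ x1 + x2,
  linear_formula = ~ z1 + z2,
  data           = df,
  test_data      = df_test,
  num_tree       = 50,
  k              = 2,
  opts           = Opts(num_burn = 5000, num_save = 5000)
)
## With the defaults num_tree = 20 and k = 2, the leaf standard deviation
## is 3 / (2 * sqrt(20)), roughly 0.34.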

Value

Returns a list of posterior samples; among its components are mu_test (posterior draws of the conditional mean on the test data) and beta (posterior draws of the linear coefficients), both of which are used in the example below.

Examples

## NOTE: SET NUMBER OF BURN IN AND SAMPLE ITERATIONS HIGHER IN PRACTICE

num_burn <- 10 ## Should be ~ 5000
num_save <- 10 ## Should be ~ 5000

set.seed(1234)
f_fried <- function(x) 10 * sin(pi * x[,1] * x[,2]) + 20 * (x[,3] - 0.5)^2 +
  10 * x[,4] + 5 * x[,5]

gen_data <- function(n_train, n_test, P, sigma) {
  X <- matrix(runif(n_train * P), nrow = n_train)
  mu <- f_fried(X)
  X_test <- matrix(runif(n_test * P), nrow = n_test)
  mu_test <- f_fried(X_test)
  Y <- mu + sigma * rnorm(n_train)
  Y_test <- mu_test + sigma * rnorm(n_test)
  
  return(list(X = X, Y = Y, mu = mu, X_test = X_test, Y_test = Y_test,
              mu_test = mu_test))
}

## Simulate dataset
sim_data <- gen_data(250, 250, 100, 1)

df <- data.frame(X = sim_data$X, Y = sim_data$Y)
df_test <- data.frame(X = sim_data$X_test, Y = sim_data$Y_test)

## Fit the model

opts <- Opts(num_burn = num_burn, num_save = num_save)
fitted_reg <- gsoftbart_regression(Y ~ . - X.4 - X.5, ~ X.4 + X.5, df, df_test, opts = opts)

## Plot results

plot(colMeans(fitted_reg$mu_test), sim_data$mu_test)
abline(a = 0, b = 1)
plot(fitted_reg$beta[,1])
plot(fitted_reg$beta[,2])
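
## A possible follow-up (not in the original example): summarize the posterior
## draws of the linear coefficients; the true values in this simulation are
## 10 (X.4) and 5 (X.5).
colMeans(fitted_reg$beta)
apply(fitted_reg$beta, 2, quantile, probs = c(0.025, 0.975))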

[Package SoftBart version 1.0.1]