partial_dependence_regression {SoftBart}    R Documentation
Partial Dependence Function for SoftBART Regression
Description
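Computes the partial dependence function of a fitted SoftBART regression model for a single covariate over a grid of values. For each MCMC iteration, the model's predictions are averaged over the remaining covariates in test_data.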
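As a rough illustration of the computation: for each value v in grid, the covariate var_str is set to v in every row of test_data, predictions are made, and those predictions are averaged over rows separately for each posterior draw. The sketch below is not the SoftBart source. It assumes a hypothetical predictor, predict_draws(), that returns a draws-by-rows matrix, and it uses a toy linear "posterior" so that it runs without fitting a model. The names pdp_sketch() and make_toy_predictor() are illustrative only.

```r
## Hypothetical stand-in for a fitted model: returns a
## (num_draws x nrow(newdata)) matrix, with f_t(x) = slope_t * X.4 + X.1 per draw t.
make_toy_predictor <- function(num_draws = 5) {
  slopes <- seq(0.5, 1.5, length.out = num_draws)  # one toy "posterior draw" per slope
  function(newdata) {
    outer(slopes, newdata$X.4) +
      matrix(newdata$X.1, nrow = num_draws, ncol = nrow(newdata), byrow = TRUE)
  }
}

## Sketch of a partial dependence computation (hypothetical helper).
pdp_sketch <- function(predict_draws, test_data, var_str, grid) {
  draws_by_grid <- sapply(grid, function(val) {
    test_data[[var_str]] <- val          # fix the covariate at val in every row
    rowMeans(predict_draws(test_data))   # average over rows within each draw
  })
  mu <- matrix(draws_by_grid, ncol = length(grid))  # rows = draws, cols = grid values
  pred_df <- data.frame(
    sample = rep(seq_len(nrow(mu)), times = ncol(mu)),
    grid_value = rep(grid, each = nrow(mu)),
    mu = as.vector(mu)                   # column-major order matches the two columns above
  )
  list(pred_df = pred_df, mu = mu, grid = grid)
}

res <- pdp_sketch(make_toy_predictor(3),
                  data.frame(X.1 = c(0, 1), X.4 = c(0.2, 0.8)),
                  "X.4", c(0, 1))
print(res$mu)
```

Because var_str is overwritten in every row, the only variation left within a draw comes from the other covariates. Averaging over them gives that draw's partial dependence value.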
Usage
partial_dependence_regression(fit, test_data, var_str, grid)
Arguments
fit
A fitted model of type softbart_regression.
test_data
A data set used to form the baseline distribution of covariates for the partial dependence function.
var_str
A string giving the name of the column in test_data (the predictor) for which to compute the partial dependence function.
grid
The values of the predictor at which to compute the partial dependence function.
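One common way to choose grid is to use empirical quantiles of the covariate in test_data. This keeps the partial dependence function from being evaluated far outside the observed data. Below is a minimal sketch. The helper quantile_grid() and its 5%/95% default bounds are illustrative choices, not SoftBart defaults.

```r
## Hypothetical helper: evenly spaced empirical quantiles of x, used as a PDP grid.
quantile_grid <- function(x, n = 10, lower = 0.05, upper = 0.95) {
  probs <- seq(lower, upper, length.out = n)
  unique(unname(quantile(x, probs = probs)))  # drop duplicates for discrete covariates
}

print(quantile_grid(0:100, n = 3))
```

With the data from the Examples, quantile_grid(df_test$X.4) could be passed as grid in place of an evenly spaced sequence.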
Value
Returns a list with the following components:
- pred_df: a data.frame containing columns for an MCMC iteration ID (sample), the value on the grid, and the partial dependence function value.
- mu: a matrix containing the same information as pred_df, with rows corresponding to iterations and columns corresponding to grid values.
- grid: the grid used as input.
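The mu component can be summarized across MCMC iterations to give a posterior mean curve and a pointwise credible band. This minimal sketch assumes mu is a matrix with iterations in rows and grid values in columns, as described above. summarize_pdp() is a hypothetical helper, not part of SoftBart.

```r
## Hypothetical helper: pointwise posterior summaries of a draws-by-grid matrix.
summarize_pdp <- function(mu, grid, level = 0.95) {
  alpha <- (1 - level) / 2
  data.frame(
    grid  = grid,
    mean  = colMeans(mu),                                          # posterior mean per grid value
    lower = apply(mu, 2, quantile, probs = alpha, names = FALSE),  # lower band
    upper = apply(mu, 2, quantile, probs = 1 - alpha, names = FALSE)
  )
}

toy_mu <- matrix(c(1, 2, 3, 4, 5, 6), nrow = 3)  # 3 toy draws x 2 grid values
print(summarize_pdp(toy_mu, c(0, 1)))
```

For the Examples below, summarize_pdp(pdp_x4$mu, pdp_x4$grid) would return one row per grid value.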
Examples
library(SoftBart)

## NOTE: SET NUMBER OF BURN-IN AND SAMPLE ITERATIONS HIGHER IN PRACTICE
num_burn <- 10 ## Should be ~ 5000
num_save <- 10 ## Should be ~ 5000
set.seed(1234)
f_fried <- function(x) 10 * sin(pi * x[,1] * x[,2]) + 20 * (x[,3] - 0.5)^2 +
  10 * x[,4] + 5 * x[,5]

gen_data <- function(n_train, n_test, P, sigma) {
  X <- matrix(runif(n_train * P), nrow = n_train)
  mu <- f_fried(X)
  X_test <- matrix(runif(n_test * P), nrow = n_test)
  mu_test <- f_fried(X_test)
  Y <- mu + sigma * rnorm(n_train)
  Y_test <- mu_test + sigma * rnorm(n_test)  ## test responses use the test means
  return(list(X = X, Y = Y, mu = mu, X_test = X_test, Y_test = Y_test,
              mu_test = mu_test))
}
## Simulate dataset
sim_data <- gen_data(250, 250, 10, 1)
df <- data.frame(X = sim_data$X, Y = sim_data$Y)
df_test <- data.frame(X = sim_data$X_test, Y = sim_data$Y_test)
## Fit the model
opts <- Opts(num_burn = num_burn, num_save = num_save)
fitted_reg <- softbart_regression(Y ~ ., df, df_test, opts = opts)
## Compute PDP and plot
grid <- seq(from = 0, to = 1, length.out = 10)
pdp_x4 <- partial_dependence_regression(fitted_reg, df_test, "X.4", grid)
plot(pdp_x4$grid, colMeans(pdp_x4$mu), xlab = "X.4", ylab = "Partial dependence")