predictFromMultipleChains {batchmix}    R Documentation
Predict from multiple MCMC chains
Description
Applies a burn-in to multiple chains from `runMCMCChains` and combines them to find a point estimate.
Usage
predictFromMultipleChains(
mcmc_outputs,
burn,
point_estimate_method = "median",
chains_already_processed = FALSE
)
Arguments
mcmc_outputs
Output from `runMCMCChains`.
burn
The number of MCMC samples to drop as part of a burn-in.
point_estimate_method
Summary statistic used to define the point estimate. Must be `"mean"` or `"median"`. `"median"` is the default.
chains_already_processed
Logical indicating whether the chains have already had a burn-in applied.
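To make the roles of `burn` and `point_estimate_method` more concrete, the base-R sketch below illustrates the general idea: discard the early samples of each chain, pool the remaining samples, and summarise each item-class probability with the mean or the median. It assumes a hypothetical input format (a list of $N \times K \times S$ arrays of sampled allocation probabilities) and a hypothetical helper `combineChainsSketch()` whose `n_drop` argument counts recorded samples. It is an illustration only, not the batchmix implementation.

```r
# Hypothetical helper (NOT part of batchmix): pools several chains after a
# burn-in and returns an N x K matrix of point estimates.
# Each chain is assumed to be an N x K x S array of sampled allocation
# probabilities (N items, K classes, S recorded samples).
combineChainsSketch <- function(chains, n_drop, point_estimate_method = "median") {
  point_estimate_method <- match.arg(point_estimate_method, c("mean", "median"))

  # Burn-in: drop the first `n_drop` recorded samples of every chain
  kept <- lapply(chains, function(x) {
    n_samples <- dim(x)[3]
    if (n_drop >= n_samples) stop("`n_drop` must be smaller than the number of samples")
    x[, , seq(n_drop + 1, n_samples), drop = FALSE]
  })

  # Pool the retained samples of all chains along the sample dimension;
  # unlist() keeps column-major order, so whole sample slices are appended
  dims <- dim(kept[[1]])
  n_kept <- sum(vapply(kept, function(x) dim(x)[3], integer(1)))
  pooled <- array(unlist(kept), dim = c(dims[1], dims[2], n_kept))

  # Point estimate of each item-class probability across all pooled samples
  summary_fn <- if (point_estimate_method == "mean") mean else median
  apply(pooled, c(1, 2), summary_fn)
}

# Two made-up chains: 3 items, 2 classes, 4 recorded samples each.
# The first sample of each chain is an early value that the burn-in removes.
make_chain <- function(p_class1) {
  x <- array(NA_real_, dim = c(3, 2, 4))
  x[, 1, ] <- p_class1
  x[, 2, ] <- 1 - p_class1
  x[, , 1] <- 0.5
  x
}
chain_a <- make_chain(0.8)
chain_b <- make_chain(0.6)

# With a burn-in of one sample, class 1 is estimated at about 0.7 per item
combineChainsSketch(list(chain_a, chain_b), n_drop = 1)
```

Without the burn-in (`n_drop = 0`) the early values pull the estimates down, and the mean and median then disagree, which is the kind of sensitivity the two arguments control.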
Value
A named list of quantities related to prediction/clustering:
* `allocation_probability`: List containing an $N \times K$ matrix if the model is semi-supervised. The matrix holds the point estimate of the allocation probability of each data point to each class.
* `prob`: Vector of length $N$ giving the point estimate of the probability of allocation to the most probable class.
* `pred`: Vector of length $N$ giving the predicted class of each sample. If the model is unsupervised, the `salso` function from Dahl et al. (2021) is applied to the sampled partitions using its default settings.
* `samples`: List of sampled allocations for each view. Columns correspond to the items being clustered and rows to MCMC samples.
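For the semi-supervised case, the relationship between `allocation_probability`, `prob` and `pred` can be sketched in base R as follows: each item's predicted class is the column holding its largest point-estimate probability, and `prob` is that largest value. The matrix is made-up illustrative data, and this is a sketch of the relationship described above rather than the package's own code; it does not cover the unsupervised case, where `salso` is used instead.

```r
# Made-up N x K point-estimate allocation matrix (N = 3 items, K = 2 classes)
allocation_probability <- matrix(
  c(0.70, 0.30,
    0.20, 0.80,
    0.55, 0.45),
  nrow = 3, byrow = TRUE
)

# Predicted class: index of the largest probability in each row
pred <- apply(allocation_probability, 1, which.max)

# Probability of allocation to that most probable class
prob <- apply(allocation_probability, 1, max)

print(pred)  # [1] 1 2 1
print(prob)  # [1] 0.70 0.80 0.55
```

Note that `which.max` returns the first maximum, so ties are resolved in favour of the lower class index.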
Examples
library(batchmix)

# Data dimensions
N <- 600
P <- 4
K <- 5
B <- 7
# Generating model parameters
mean_dist <- 2.25
batch_dist <- 0.3
group_means <- seq(1, K) * mean_dist
batch_shift <- rnorm(B, mean = batch_dist, sd = batch_dist)
std_dev <- rep(2, K)
batch_var <- rep(1.2, B)
group_weights <- rep(1 / K, K)
batch_weights <- rep(1 / B, B)
dfs <- c(4, 7, 15, 60, 120)
my_data <- generateBatchData(
N,
P,
group_means,
std_dev,
batch_shift,
batch_var,
group_weights,
batch_weights,
type = "MVT",
group_dfs = dfs
)
X <- my_data$observed_data
true_labels <- my_data$group_IDs
fixed <- my_data$fixed
batch_vec <- my_data$batch_IDs
alpha <- 1
initial_labels <- generateInitialLabels(alpha, K, fixed, true_labels)
# Sampling parameters
R <- 1000
thin <- 25
burn <- 100
n_chains <- 2
# Density choice
type <- "MVT"
# MCMC samples and BIC vector
mcmc_outputs <- runMCMCChains(
X,
n_chains,
R,
thin,
batch_vec,
type,
initial_labels = initial_labels,
fixed = fixed
)
ensemble_mod <- predictFromMultipleChains(mcmc_outputs, burn)