predictFromMultipleChains {batchmix}    R Documentation

Predict from multiple MCMC chains

Description

Applies a burn-in to multiple MCMC chains produced by “runMCMCChains“ and combines them to find a point estimate.

Usage

predictFromMultipleChains(
  mcmc_outputs,
  burn,
  point_estimate_method = "median",
  chains_already_processed = FALSE
)

Arguments

mcmc_outputs

Output from “runMCMCChains“

burn

The number of MCMC samples to drop from each chain as burn-in.

point_estimate_method

Summary statistic used to define the point estimate. Must be “'mean'“ or “'median'“. “'median'“ is the default.

chains_already_processed

Logical indicating whether the chains have already had a burn-in applied. Defaults to “FALSE“.
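
To make the roles of “burn“ and “chains_already_processed“ concrete, the following base R sketch drops a burn-in from each of several toy chains and then pools what remains. The helper “drop_burn_in“ and the toy chains are hypothetical and are not part of batchmix. The sketch also assumes that “burn“ counts recorded samples, which may differ from how the package accounts for thinning.

```r
# Hypothetical helper (not part of batchmix): discard the first `burn`
# recorded samples (rows) from a matrix of MCMC samples.
drop_burn_in <- function(samples, burn) {
  stopifnot(burn >= 0, burn < nrow(samples))
  samples[seq(burn + 1, nrow(samples)), , drop = FALSE]
}

set.seed(1)
n_samples <- 40 # recorded samples per chain (toy value)
n_items <- 6    # items being clustered (toy value)

# Two toy chains of sampled allocations: rows are samples, columns are items
toy_chains <- lapply(1:2, function(i) {
  matrix(sample(1:3, n_samples * n_items, replace = TRUE), nrow = n_samples)
})

# Apply the burn-in to each chain, then stack the retained samples
burn <- 10
kept <- lapply(toy_chains, drop_burn_in, burn = burn)
pooled <- do.call(rbind, kept)

dim(pooled)
```

If the chains had already been trimmed in this way, applying the burn-in a second time would discard valid samples, which is the situation “chains_already_processed = TRUE“ is meant to flag.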

Value

A named list of quantities related to prediction/clustering:

* “allocation_probability“: List with an $(N \times K)$ matrix if the model is semi-supervised. The matrix holds the point estimate of the probability of allocating each data point to each class.

* “prob“: $N$ vector of the point estimate of the probability of being allocated to the class with the highest probability.

* “pred“: $N$ vector of the predicted class for each sample. If the model is unsupervised, the “salso“ function from Dahl et al. (2021) is applied to the sampled partitions with its default settings.

* “samples“: List of sampled allocations for each view. Columns correspond to items being clustered, rows to MCMC samples.
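
The relationship between “allocation_probability“, “prob“ and “pred“ in the semi-supervised case can be sketched in base R. This is an illustration under stated assumptions, not the package's implementation: the array “alloc_samples“, its dimensions and the helper “point_estimate“ are invented for the example, and an elementwise summary over pooled samples stands in for the point estimate.

```r
set.seed(1)
N <- 5  # items (toy value)
K <- 3  # classes (toy value)
S <- 50 # pooled post-burn-in samples (toy value)

# Toy sampled allocation probabilities: an S x N x K array in which the
# class probabilities for each (sample, item) pair sum to one.
raw <- array(rgamma(S * N * K, shape = 1), dim = c(S, N, K))
alloc_samples <- sweep(raw, c(1, 2), apply(raw, c(1, 2), sum), "/")

# Hypothetical helper: summarise over samples to obtain an N x K matrix
point_estimate <- function(samples, method = c("median", "mean")) {
  method <- match.arg(method)
  summary_fn <- if (method == "median") median else mean
  apply(samples, c(2, 3), summary_fn)
}

allocation_probability <- point_estimate(alloc_samples, "median")

# Probability of the most likely class, and the predicted class, per item
prob <- apply(allocation_probability, 1, max)
pred <- apply(allocation_probability, 1, which.max)
```

In this sketch the rows of the estimate sum to one when the mean is used, whereas the elementwise median need not sum to one.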

Examples


library(batchmix)

# Data dimensions
N <- 600
P <- 4
K <- 5
B <- 7

# Generating model parameters
mean_dist <- 2.25
batch_dist <- 0.3
group_means <- seq(1, K) * mean_dist
batch_shift <- rnorm(B, mean = batch_dist, sd = batch_dist)
std_dev <- rep(2, K)
batch_var <- rep(1.2, B)
group_weights <- rep(1 / K, K)
batch_weights <- rep(1 / B, B)
dfs <- c(4, 7, 15, 60, 120)

my_data <- generateBatchData(
  N,
  P,
  group_means,
  std_dev,
  batch_shift,
  batch_var,
  group_weights,
  batch_weights,
  type = "MVT",
  group_dfs = dfs
)


X <- my_data$observed_data

true_labels <- my_data$group_IDs
fixed <- my_data$fixed
batch_vec <- my_data$batch_IDs

alpha <- 1
initial_labels <- generateInitialLabels(alpha, K, fixed, true_labels)

# Sampling parameters
R <- 1000
thin <- 25
burn <- 100
n_chains <- 2

# Density choice
type <- "MVT"

# MCMC samples and BIC vector
mcmc_outputs <- runMCMCChains(
  X,
  n_chains,
  R,
  thin,
  batch_vec,
  type,
  initial_labels = initial_labels,
  fixed = fixed
)

# Apply the burn-in and combine the chains into a single point estimate
ensemble_mod <- predictFromMultipleChains(mcmc_outputs, burn)


[Package batchmix version 2.2.1 Index]