misclustering_error {mtlgmm}    R Documentation

Calculate the misclustering error given the predicted cluster labels.

Description

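Calculate the misclustering error rate of predicted cluster labels against the true labels on each task, and return either the maximum rate, the per-task rates, or the average rate, depending on type.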
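Cluster labels are identifiable only up to relabelling, so a misclustering error is usually defined as the smallest error over label permutations. The sketch below shows that idea for a single task with two clusters labelled 1 and 2. It assumes that this is how the per-task rate is defined; `task_error_sketch` is a hypothetical helper, not the mtlgmm implementation.

```r
# Hypothetical sketch, not the mtlgmm implementation: misclustering error
# of one task with two clusters labelled 1 and 2. Because cluster labels are
# only identifiable up to a swap, take the smaller error over the two
# possible label assignments.
task_error_sketch <- function(pred, truth) {
  stopifnot(length(pred) == length(truth))
  err <- mean(pred != truth)  # error rate under the labels as given
  min(err, 1 - err)           # 1 - err is the error after swapping 1 <-> 2
}

# Toy labels: the prediction is the truth with labels swapped,
# except for the last point.
truth <- c(1, 1, 1, 2, 2, 2, 2, 1)
pred  <- c(2, 2, 2, 1, 1, 1, 1, 1)
task_error_sketch(pred, truth)  # returns 0.125
```

Without the swap, the raw error for this toy example would be 0.875, even though the prediction recovers the partition for seven of the eight points.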

Usage

misclustering_error(y_pred, y_test, type = c("max", "all", "avg"))

Arguments

y_pred

a list of predicted cluster label vectors, one per task

y_test

a list of true cluster label vectors, one per task, in the same order as y_pred

type

Which misclustering error rate to return: "max", "all", or "avg". Default: "max".

  • max: the maximum of the misclustering error rates over all tasks

  • all: a vector of the misclustering error rates, one per task

  • avg: the average of the misclustering error rates over all tasks
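The three choices of type differ only in how the per-task error rates are summarised. A minimal sketch, assuming the per-task rates have already been computed as a numeric vector; `aggregate_errors_sketch` is a hypothetical name, and the error rates below are made up for illustration.

```r
# Hypothetical sketch, not an mtlgmm function: summarise per-task
# misclustering error rates according to `type`.
aggregate_errors_sketch <- function(errors, type = c("max", "all", "avg")) {
  type <- match.arg(type)  # with no argument, picks the first choice, "max"
  switch(type,
         max = max(errors),   # worst-case task
         all = errors,        # one rate per task
         avg = mean(errors))  # average over tasks
}

errors <- c(0.10, 0.04, 0.22, 0.08)  # made-up per-task error rates
aggregate_errors_sketch(errors)         # 0.22
aggregate_errors_sketch(errors, "avg")  # 0.11
aggregate_errors_sketch(errors, "all")  # the input vector itself
```

Using match.arg mirrors the Usage line above: leaving type unspecified selects "max", the first entry of the choices vector.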

Value

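Depends on type: a single number for "max" or "avg", and a numeric vector with one misclustering error rate per task for "all".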

References

Tian, Y., Weng, H., & Feng, Y. (2022). Unsupervised Multi-task and Transfer Learning on Gaussian Mixture Models. arXiv preprint arXiv:2209.15224.

See Also

mtlgmm, tlgmm, data_generation, predict_gmm, initialize, alignment, alignment_swap, estimation_error.

Examples

library(mtlgmm)
set.seed(23, kind = "L'Ecuyer-CMRG")
## Consider a 5-task multi-task learning problem in the setting "MTL-1"
data_list <- data_generation(K = 5, outlier_K = 1, simulation_no = "MTL-1", h_w = 0.1,
                             h_mu = 1, n = 100)  # generate the data
## Use the first 50 samples of each task for training and the rest for testing
x_train <- sapply(seq_along(data_list$data$x), function(k){
  data_list$data$x[[k]][1:50, ]
}, simplify = FALSE)
x_test <- sapply(seq_along(data_list$data$x), function(k){
  data_list$data$x[[k]][-(1:50), ]
}, simplify = FALSE)
y_test <- sapply(seq_along(data_list$data$x), function(k){
  data_list$data$y[[k]][-(1:50)]
}, simplify = FALSE)

## Fit the multi-task GMM on the training data
fit <- mtlgmm(x = x_train, C1_w = 0.05, C1_mu = 0.2, C1_beta = 0.2,
              C2_w = 0.05, C2_mu = 0.2, C2_beta = 0.2, kappa = 1/3, initial_method = "EM",
              trim = 0.1, lambda_choice = "fixed", step_size = "lipschitz")

## Predict cluster labels on the test data of each task
y_pred <- sapply(seq_along(data_list$data$x), function(i){
  predict_gmm(w = fit$w[i], mu1 = fit$mu1[, i], mu2 = fit$mu2[, i],
              beta = fit$beta[, i], newx = x_test[[i]])
}, simplify = FALSE)

## Maximum misclustering error over the non-outlier tasks
misclustering_error(y_pred[-data_list$data$outlier_index],
                    y_test[-data_list$data$outlier_index], type = "max")

[Package mtlgmm version 0.1.0 Index]