vae_loss_correlated {ML2Pvae}    R Documentation

A custom loss function for a VAE learning a multivariate normal distribution with a full covariance matrix

Description

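Builds a Keras-compatible loss function for a variational autoencoder (VAE) whose latent skills are modeled as a multivariate normal distribution with a full (non-diagonal) covariance matrix. The loss includes a KL divergence term, scaled by kl_weight, that compares the distribution produced by the encoder with the target distribution described by skill_mean, inv_skill_cov and det_skill_cov.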

Usage

vae_loss_correlated(
  encoder,
  inv_skill_cov,
  det_skill_cov,
  skill_mean,
  kl_weight,
  rec_dim
)

Arguments

encoder

the encoder model of the VAE, used to obtain z_mean and z_log_cholesky from inputs

inv_skill_cov

a constant tensor matrix holding the inverse of the covariance matrix being learned

det_skill_cov

a constant tensor scalar representing the determinant of the covariance matrix being learned

skill_mean

a constant tensor vector representing the means of the latent skills being learned

kl_weight

scalar weight applied to the KL divergence term of the loss

rec_dim

the number of nodes in the input/output of the VAE
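
For orientation, the closed-form KL divergence between two k-dimensional multivariate normal distributions is given below. It depends on the target covariance only through its inverse and determinant, which is presumably why those two quantities are passed in as precomputed constants. The mapping of symbols to arguments is an assumption made for illustration: mu_0 and Sigma_0 stand for the encoder's z_mean and the covariance implied by z_log_cholesky (whose exact parameterization is not stated on this page), while mu_1, Sigma_1^{-1} and det Sigma_1 stand for skill_mean, inv_skill_cov and det_skill_cov.

```latex
D_{\mathrm{KL}}\bigl(\mathcal{N}(\mu_0, \Sigma_0)\,\|\,\mathcal{N}(\mu_1, \Sigma_1)\bigr)
= \frac{1}{2}\left[
    \operatorname{tr}\!\left(\Sigma_1^{-1}\Sigma_0\right)
    + (\mu_1 - \mu_0)^{\top}\Sigma_1^{-1}(\mu_1 - \mu_0)
    - k
    + \ln\frac{\det\Sigma_1}{\det\Sigma_0}
  \right]
```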

Value

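A function whose parameters match the Keras loss format, i.e. function(y_true, y_pred), so that it can be supplied as a custom loss when compiling a Keras model.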
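
The shape of such a return value can be sketched in base R without Keras or TensorFlow. This is a hypothetical illustration, not the package's implementation: the names kl_mvn and make_loss_sketch are invented here, the reconstruction term is assumed to be binary cross-entropy scaled by rec_dim, and the KL value is computed once up front rather than from the encoder's output for each input.

```r
# Hypothetical base-R sketch; NOT the package's TensorFlow implementation.

# KL( N(mu0, sigma0) || N(skill_mean, skill_cov) ), where the target covariance
# enters only through its precomputed inverse and determinant.
kl_mvn <- function(mu0, sigma0, skill_mean, inv_skill_cov, det_skill_cov) {
  k <- length(mu0)
  d <- skill_mean - mu0
  trace_term <- sum(diag(inv_skill_cov %*% sigma0))
  quad_term <- as.numeric(t(d) %*% inv_skill_cov %*% d)
  log_det_term <- log(det_skill_cov / det(sigma0))
  0.5 * (trace_term + quad_term - k + log_det_term)
}

# Loss factory: returns a closure taking (y_true, y_pred), the argument
# order Keras uses for custom loss functions.
make_loss_sketch <- function(kl_value, kl_weight, rec_dim) {
  function(y_true, y_pred) {
    eps <- 1e-7
    p <- pmin(pmax(y_pred, eps), 1 - eps)  # keep log() finite
    # Assumed reconstruction term: mean binary cross-entropy times rec_dim.
    bce <- -mean(y_true * log(p) + (1 - y_true) * log(1 - p))
    rec_dim * bce + kl_weight * kl_value
  }
}

# Target distribution: two correlated skills.
skill_cov <- matrix(c(1, 0.5, 0.5, 1), nrow = 2)
skill_mean <- c(0, 0)
inv_skill_cov <- solve(skill_cov)
det_skill_cov <- det(skill_cov)

# Identical distributions give (numerically) zero divergence.
kl_same <- kl_mvn(skill_mean, skill_cov, skill_mean, inv_skill_cov, det_skill_cov)

# Shifting the mean makes the divergence positive.
kl_shift <- kl_mvn(c(1, 0), skill_cov, skill_mean, inv_skill_cov, det_skill_cov)

loss_fn <- make_loss_sketch(kl_shift, kl_weight = 1, rec_dim = 4)
loss_value <- loss_fn(c(1, 0, 1, 1), c(0.9, 0.2, 0.8, 0.6))
```

Returning a closure lets the constants (kl_weight, rec_dim and the covariance quantities) be fixed once, while the returned function keeps the two-argument form a Keras loss is expected to have.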


[Package ML2Pvae version 1.0.0.1 Index]