vae_loss_correlated {ML2Pvae}	R Documentation
A custom loss function for a VAE learning a multivariate normal distribution with a full covariance matrix
Description
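Constructs a custom loss function for a variational autoencoder (VAE) whose latent skills follow a multivariate normal distribution with a full (not necessarily diagonal) covariance matrix. The loss combines a reconstruction term with a KL divergence term, scaled by kl_weight, computed against a multivariate normal with mean skill_mean and a covariance matrix described by its inverse (inv_skill_cov) and determinant (det_skill_cov).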
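For reference, the standard closed-form KL divergence between the encoder distribution q = N(μ, Σ) and a k-dimensional multivariate normal p = N(μ_p, Σ_p) is shown below. It needs only the inverse and the determinant of Σ_p, which matches the package asking for inv_skill_cov (Σ_p⁻¹), det_skill_cov (det Σ_p) and skill_mean (μ_p). The Cholesky form of ln det Σ assumes Σ = LLᵀ with L lower triangular; how z_log_cholesky encodes L, and the exact reconstruction term, are not stated on this page, so the total loss line is a hedged summary rather than the package's definition.

```latex
\mathrm{KL}(q \,\|\, p) = \tfrac{1}{2}\Big[\operatorname{tr}\!\big(\Sigma_p^{-1}\Sigma\big)
  + (\mu_p - \mu)^\top \Sigma_p^{-1} (\mu_p - \mu) - k
  + \ln\det\Sigma_p - \ln\det\Sigma\Big],
\qquad \ln\det\Sigma = 2\sum_{i=1}^{k} \ln L_{ii}

\mathcal{L} = \mathcal{L}_{\mathrm{rec}} + \texttt{kl\_weight}\cdot \mathrm{KL}(q \,\|\, p)
```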
Usage
vae_loss_correlated(
encoder,
inv_skill_cov,
det_skill_cov,
skill_mean,
kl_weight,
rec_dim
)
Arguments
encoder: the encoder model of the VAE, used to obtain z_mean and z_log_cholesky from the inputs
inv_skill_cov: a constant tensor (matrix) containing the inverse of the covariance matrix being learned
det_skill_cov: a constant tensor (scalar) containing the determinant of the covariance matrix being learned
skill_mean: a constant tensor (vector) containing the means of the latent skills being learned
kl_weight: the weight applied to the KL divergence term of the loss
rec_dim: the number of nodes in the input and output layers of the VAE
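inv_skill_cov and det_skill_cov are both derived from the same covariance matrix, so they should be computed together. The following is a minimal base-R sketch using a hypothetical two-skill covariance with correlation 0.5 (example values only). The conversion to constant tensors is left as a comment, since it depends on the keras setup; keras::k_constant() is one option.

```r
# Hypothetical example: covariance of two latent skills with correlation 0.5
skill_cov <- matrix(c(1, 0.5,
                      0.5, 1), nrow = 2, byrow = TRUE)
skill_mean_vec <- c(0, 0)

# chol() stops with an error if skill_cov is not positive definite,
# which guarantees the inverse exists and the determinant is positive
invisible(chol(skill_cov))

inv_skill_cov_mat <- solve(skill_cov)  # inverse covariance matrix
det_skill_cov_val <- det(skill_cov)    # scalar determinant

# Before calling vae_loss_correlated(), these values would be wrapped as
# constant tensors, e.g. keras::k_constant(inv_skill_cov_mat).
```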
Value
Returns a function whose parameters follow the Keras loss-function format, (y_true, y_pred), so it can be supplied as the loss when compiling the VAE.
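The returned value is a closure: the constant arguments are captured once, and Keras later calls the inner function with (y_true, y_pred). The sketch below shows that pattern in plain base R on numeric matrices instead of tensors. Its details are assumptions rather than the package's code: make_loss_sketch, encoder_fn and z_chol are hypothetical names, the encoder is assumed to return lower-triangular Cholesky factors directly, and the reconstruction term is assumed to be rec_dim times the mean binary cross-entropy.

```r
# Hypothetical keras-style loss factory, written in base R for illustration.
# encoder_fn(x) must return list(z_mean = n x k matrix,
#                                z_chol = list of n lower-triangular k x k factors).
make_loss_sketch <- function(encoder_fn, inv_skill_cov, det_skill_cov,
                             skill_mean, kl_weight, rec_dim) {
  k <- length(skill_mean)
  function(y_true, y_pred) {
    enc <- encoder_fn(y_true)

    # Clamp predictions away from 0 and 1 so log() stays finite
    eps <- 1e-7
    p <- pmin(pmax(y_pred, eps), 1 - eps)

    # Mean binary cross-entropy per sample, scaled by the number of outputs
    bce <- -rowMeans(y_true * log(p) + (1 - y_true) * log(1 - p))
    rec <- rec_dim * bce

    # Closed-form KL between N(z_mean, L L^T) and N(skill_mean, skill_cov)
    kl <- vapply(seq_len(nrow(y_true)), function(i) {
      L <- enc$z_chol[[i]]
      sigma <- L %*% t(L)
      diff <- skill_mean - enc$z_mean[i, ]
      log_det_q <- 2 * sum(log(diag(L)))
      0.5 * (sum(diag(inv_skill_cov %*% sigma)) +
             drop(t(diff) %*% inv_skill_cov %*% diff) -
             k + log(det_skill_cov) - log_det_q)
    }, numeric(1))

    rec + kl_weight * kl  # one loss value per sample
  }
}
```

When the encoder reproduces the prior exactly (same mean and covariance), the KL term vanishes and only the reconstruction term remains. Shifting the encoder mean away from skill_mean increases the loss by the quadratic term.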
[Package ML2Pvae version 1.0.0.1 Index]