vi_csiszar_vimco {tfprobability}    R Documentation
Use VIMCO to lower the variance of the gradient of csiszar_function(Avg(logu))
Description
This function generalizes VIMCO (Mnih and Rezende, 2016) to Csiszar f-Divergences.
Usage
vi_csiszar_vimco(
f,
p_log_prob,
q,
num_draws,
num_batch_draws = 1,
seed = NULL,
name = NULL
)
Arguments
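f: Function representing a Csiszar-function in log-space.
p_log_prob: Function representing the natural-log of the probability under distribution p. (In variational inference, p is the joint distribution.)
q: tf$Distribution-like instance; must implement sample(n, seed) and log_prob(x). (In variational inference, q is the approximate posterior distribution.)
num_draws: Integer scalar number of draws used to approximate the f-Divergence expectation. At least 2 draws are needed for the leave-one-out terms below to be defined.
num_batch_draws: Integer scalar number of independent batches of num_draws draws; the VIMCO objective is averaged over these batches.
seed: Integer seed for q$sample.
name: String prefixed to Ops created by this function.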
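The f argument expects a Csiszar function written in log-space, i.e. a function of logu = log(u) rather than of u. As a minimal base-R sketch (kl_reverse_logspace and kl_forward_logspace are hypothetical names, not tfprobability functions), the reverse and forward KL generators f(u) = -log(u) and f(u) = u * log(u) become:

```r
# Csiszar functions in log-space: each takes logu = log(u) and returns f(u).
# Hypothetical helper names; tfprobability provides vi_kl_reverse() and
# vi_kl_forward(), which may differ (e.g. through self-normalization options).

# Reverse KL generator: f(u) = -log(u), i.e. f(logu) = -logu.
kl_reverse_logspace <- function(logu) {
  -logu
}

# Forward KL generator: f(u) = u * log(u), i.e. f(logu) = exp(logu) * logu.
kl_forward_logspace <- function(logu) {
  exp(logu) * logu
}

# Both satisfy the Csiszar property f(1) = 0, which in log-space is f(0) = 0.
kl_reverse_logspace(0)
kl_forward_logspace(0)
```

The two are duals of each other: kl_forward_logspace(l) equals exp(l) * kl_reverse_logspace(-l), which is the usual u * f(1/u) relationship.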
Details
Note: if q.reparameterization_type is tfd.FULLY_REPARAMETERIZED, consider using monte_carlo_csiszar_f_divergence (exposed in this package as vi_monte_carlo_variational_loss()).
The VIMCO loss is:
vimco = f(log Avg{u[i] : i=0,...,m-1})
where,
  logu[i] = log( p(x, h[i]) / q(h[i] | x) ),   u[i] = exp(logu[i])
  h[i] iid~ q(H | x),   m = num_draws
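The loss can be sketched in a few lines of base R, assuming logu already holds the m = num_draws log-ratios for one batch. Here Avg is taken over u[i] = exp(logu[i]) and the result is kept in log-space (a log-mean-exp), consistent with the f(log Avg{h[j;i]}) term in the gradient below. The helper names are hypothetical, and treating num_batch_draws as independent columns whose objectives are averaged is an assumption of this sketch.

```r
# Numerically stable log(mean(exp(x))): subtract the maximum before exp().
log_mean_exp <- function(x) {
  mx <- max(x)
  mx + log(mean(exp(x - mx)))
}

# VIMCO objective for a single batch of m draws: f(log Avg{u[i]}).
vimco_objective <- function(f, logu) {
  f(log_mean_exp(logu))
}

# Assumption: with num_batch_draws > 1, logu is a num_draws x num_batch_draws
# matrix (one column per batch) and the per-batch objectives are averaged.
vimco_objective_batched <- function(f, logu) {
  mean(apply(logu, 2, function(col) vimco_objective(f, col)))
}

# Reverse KL in log-space, f(logu) = -logu (hypothetical helper).
kl_reverse_logspace <- function(logu) -logu

logu <- log(c(1, 2, 3, 6))                    # u = 1, 2, 3, 6, so Avg{u} = 3
vimco_objective(kl_reverse_logspace, logu)    # -log(3), about -1.098612
```

Subtracting the maximum inside log_mean_exp keeps the computation finite even when the log-ratios are large, e.g. log_mean_exp(c(1000, 1000)) returns 1000 instead of Inf.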
Interestingly, the VIMCO gradient is not the naive gradient of vimco.
Rather, it is characterized by:
  grad[vimco] - variance_reducing_term
where,
  variance_reducing_term = Sum{ grad[log q(h[i] | x)] * (vimco - f(log Avg{h[j;i] : j=0,...,m-1})) : i=0,...,m-1 }
  h[j;i] = u[j]                              for j != i
  h[j;i] = GeometricAverage{u[k] : k != i}   for j == i
(We omitted stop_gradient for brevity. See the implementation for more details.)
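To make the variance-reducing term concrete, the following base-R sketch builds h[j;i] literally from its definition and returns the per-draw coefficients vimco - f(log Avg{h[j;i] : j}) that multiply grad[log q(h[i] | x)]. It computes only these scalar weights; the gradients themselves and the stop_gradient handling are left to an autodiff framework. The function names are hypothetical.

```r
# Literal swap-out average, one entry per draw i:
#   h[j;i] = u[j]                              for j != i
#   h[j;i] = GeometricAverage{u[k] : k != i}   for j == i
swap_out_log_avg <- function(logu) {
  m <- length(logu)
  stopifnot(m >= 2)  # the leave-i-out average needs at least one other draw
  u <- exp(logu)     # naive exponentiation, for clarity only
  vapply(seq_len(m), function(i) {
    h <- u
    h[i] <- exp(mean(logu[-i]))  # GeometricAverage{u[k] : k != i}
    log(mean(h))                 # log Avg{h[j;i] : j}
  }, numeric(1))
}

# Coefficient of grad[log q(h[i] | x)] in variance_reducing_term, per draw i.
vimco_score_weights <- function(f, logu) {
  vimco <- f(log(mean(exp(logu))))   # f(log Avg{u[i]})
  vimco - f(swap_out_log_avg(logu))  # vimco - f(log Avg{h[j;i] : j})
}

# Reverse KL in log-space, f(logu) = -logu (hypothetical helper).
kl_reverse_logspace <- function(logu) -logu

vimco_score_weights(kl_reverse_logspace, log(c(1, 4)))
```

With u = (1, 4), swapping out u[1] gives h = (4, 4) and swapping out u[2] gives h = (1, 1), so the weights are log(4/2.5) and -log(2.5). When all u[i] are equal, every weight is zero, as expected of a leave-one-out baseline.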
The Avg{h[j;i] : j} term is a kind of "swap-out average" in which the i-th element has been replaced by the leave-i-out geometric average.
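A literal construction of the swap-out average exponentiates logu directly and overflows in double precision once a log-ratio exceeds roughly 709. A minimal sketch that stays in log-space is shown below; the names are hypothetical, and the per-i log-sum-exp is an illustrative choice, not necessarily how the package computes it. It is quadratic in m per batch, so it illustrates the precision concern rather than the cost described for the package's implementation.

```r
# Numerically stable log(sum(exp(x))).
log_sum_exp <- function(x) {
  mx <- max(x)
  if (!is.finite(mx)) return(mx)  # all -Inf (or an Inf present)
  mx + log(sum(exp(x - mx)))
}

# log Avg{h[j;i] : j} for every i, computed without leaving log-space.
swap_out_log_avg_stable <- function(logu) {
  m <- length(logu)
  stopifnot(m >= 2)  # the leave-i-out geometric average needs another draw
  # log GeometricAverage{u[k] : k != i} = mean of logu with entry i removed
  log_geo <- (sum(logu) - logu) / (m - 1)
  vapply(seq_len(m), function(i) {
    # swap entry i for the leave-i-out geometric average, then log-mean-exp
    log_sum_exp(c(logu[-i], log_geo[i])) - log(m)
  }, numeric(1))
}

# exp(800) overflows to Inf in double precision, but this result stays finite.
swap_out_log_avg_stable(c(800, 801, 799))
```

Shifting every logu by a constant shifts every output by the same constant, which gives a quick sanity check: swap_out_log_avg_stable(c(800, 801, 799)) matches swap_out_log_avg_stable(c(0, 1, -1)) + 800.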
This implementation prefers numerical precision over efficiency: its cost scales as O(num_draws * num_batch_draws * prod(batch_shape) * prod(event_shape)), but the constant may be fairly large (perhaps around 12).
Value
vimco: The Csiszar f-Divergence generalized VIMCO objective.
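References
Andriy Mnih and Danilo J. Rezende. Variational Inference for Monte Carlo Objectives. In International Conference on Machine Learning (ICML), 2016.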
See Also
Other vi-functions: vi_amari_alpha(), vi_arithmetic_geometric(), vi_chi_square(), vi_dual_csiszar_function(), vi_fit_surrogate_posterior(), vi_jeffreys(), vi_jensen_shannon(), vi_kl_forward(), vi_kl_reverse(), vi_log1p_abs(), vi_modified_gan(), vi_monte_carlo_variational_loss(), vi_pearson(), vi_squared_hellinger(), vi_symmetrized_csiszar_function()