| tfb_ffjord {tfprobability} | R Documentation |
Implements a continuous normalizing flow X->Y defined via an ODE.
Description
This bijector implements a continuous dynamics transformation parameterized by a differential equation, where the initial and terminal conditions correspond to the domain (X) and image (Y), respectively. The governing equations are given under Details.
Usage
tfb_ffjord(
state_time_derivative_fn,
ode_solve_fn = NULL,
trace_augmentation_fn = tfp$bijectors$ffjord$trace_jacobian_hutchinson,
initial_time = 0,
final_time = 1,
validate_args = FALSE,
dtype = tf$float32,
name = "ffjord"
)
Arguments
state_time_derivative_fn |
|
ode_solve_fn |
|
trace_augmentation_fn |
|
initial_time |
Scalar float representing the time to which the x value of the bijector corresponds. |
final_time |
Scalar float representing the time to which the y value of the bijector corresponds. |
validate_args |
Logical, default FALSE. Whether to validate input with asserts. If validate_args is FALSE, and the inputs are invalid, correct behavior is not guaranteed. |
dtype |
|
name |
name prefixed to Ops created by this class. |
Details
d/dt[state(t)] = state_time_derivative_fn(t, state(t))
state(initial_time) = X
state(final_time) = Y
For this transformation the value of log_det_jacobian follows another
differential equation, reducing it to computation of the trace of the jacobian
along the trajectory
state_time_derivative = state_time_derivative_fn(t, state(t))
d/dt[log_det_jac(t)] = Tr(jacobian(state_time_derivative, state(t)))
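The two differential equations above can be integrated side by side. The following is a minimal base-R sketch on plain numeric vectors, not the TFP implementation: the matrix A, the helper trace_jacobian_fd(), and integrate_flow() are hypothetical names introduced for illustration, and fixed-step Euler integration with a finite-difference Jacobian trace is assumed purely for simplicity.

```r
# Hypothetical linear dynamics: d/dt state = A %*% state, so the Jacobian is A.
A <- matrix(c(0.5, 0.3, -1.0, 0.2), nrow = 2, byrow = TRUE)
state_time_derivative_fn <- function(time, state) as.vector(A %*% state)

# Trace of the Jacobian of f(time, .) at `state`, using central differences.
trace_jacobian_fd <- function(f, time, state, eps = 1e-5) {
  sum(vapply(seq_along(state), function(i) {
    e <- replace(numeric(length(state)), i, eps)  # perturb coordinate i only
    (f(time, state + e)[i] - f(time, state - e)[i]) / (2 * eps)
  }, numeric(1)))
}

# Fixed-step Euler integration of the state and log_det_jac together.
integrate_flow <- function(x, initial_time = 0, final_time = 1, n_steps = 1000) {
  dt <- (final_time - initial_time) / n_steps
  time <- initial_time
  state <- x
  log_det_jac <- 0
  for (i in seq_len(n_steps)) {
    # d/dt[log_det_jac(t)] = Tr(jacobian(state_time_derivative, state(t)))
    log_det_jac <- log_det_jac +
      dt * trace_jacobian_fd(state_time_derivative_fn, time, state)
    # d/dt[state(t)] = state_time_derivative_fn(t, state(t))
    state <- state + dt * state_time_derivative_fn(time, state)
    time <- time + dt
  }
  list(y = state, log_det_jac = log_det_jac)
}

result <- integrate_flow(c(1, 0))
```

For linear dynamics the Jacobian trace is constant, so result$log_det_jac should be close to Tr(A) * (final_time - initial_time), which is 0.7 here.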
The FFJORD constructor takes two function arguments, ode_solve_fn and
trace_augmentation_fn, that customize integration of the
differential equation and trace estimation, respectively.
Differential equation integration is performed by a call to ode_solve_fn.
A custom ode_solve_fn must accept the following arguments:
ode_fn(time, state): Differential equation to be solved.
initial_time: Scalar float or floating Tensor representing the initial time.
initial_state: Floating Tensor representing the initial state.
solution_times: 1D floating Tensor of solution times.
It must return a Tensor of shape [solution_times$shape, initial_state$shape]
representing state values evaluated at solution_times. In addition,
ode_solve_fn must support nested structures. For more details see the
interface of tfp$math$ode$Solver$solve().
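The calling convention described above can be illustrated with a small fixed-step Runge-Kutta solver. This is a sketch under stated assumptions: rk4_ode_solve_fn() and its max_step argument are hypothetical, and it operates on plain numeric vectors only, so it shows the argument order and output layout (one row per solution time) rather than something that can be passed to tfb_ffjord() directly, which needs a solver that works on Tensors and nested structures.

```r
# Hypothetical fixed-step RK4 solver following the ode_solve_fn argument order.
rk4_ode_solve_fn <- function(ode_fn, initial_time, initial_state, solution_times,
                             max_step = 0.01) {
  # One row per solution time, one column per state component.
  states <- matrix(NA_real_, nrow = length(solution_times),
                   ncol = length(initial_state))
  t <- initial_time
  y <- initial_state
  for (k in seq_along(solution_times)) {
    # Split the interval up to the next solution time into equal steps.
    n_steps <- max(1, ceiling(abs(solution_times[k] - t) / max_step))
    h <- (solution_times[k] - t) / n_steps
    for (i in seq_len(n_steps)) {
      k1 <- ode_fn(t, y)
      k2 <- ode_fn(t + h / 2, y + h / 2 * k1)
      k3 <- ode_fn(t + h / 2, y + h / 2 * k2)
      k4 <- ode_fn(t + h, y + h * k3)
      y <- y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
      t <- t + h
    }
    t <- solution_times[k]  # avoid accumulating rounding error in time
    states[k, ] <- y
  }
  states
}

# Exponential decay dy/dt = -y starting from y(0) = 1.
solution <- rk4_ode_solve_fn(function(time, state) -state,
                             initial_time = 0, initial_state = 1,
                             solution_times = c(0.5, 1))
```

Here solution is a 2 x 1 matrix whose rows should approximate exp(-0.5) and exp(-1).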
The trace estimate is computed simultaneously with state_time_derivative
by an augmented_state_time_derivative_fn that is generated by
trace_augmentation_fn. trace_augmentation_fn takes
state_time_derivative_fn, state$shape and state$dtype as arguments and
returns an augmented_state_time_derivative_fn callable that computes both
state_time_derivative and the unreduced trace_estimation.
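The default trace_augmentation_fn shown under Usage is named after Hutchinson's estimator, which approximates Tr(J) by averaging v' J v over random probe vectors v. A minimal base-R sketch of that idea follows; hutchinson_trace() and jvp_fn are hypothetical names, Rademacher (+1/-1) probes are an assumption made here, and TFP's own implementation works on Tensors and may sample differently.

```r
# Hypothetical Hutchinson-style estimator: Tr(J) is approximated by mean(v' J v).
# jvp_fn(v) must return the Jacobian-vector product J %*% v as a numeric vector.
hutchinson_trace <- function(jvp_fn, dim, num_samples = 1000) {
  estimates <- vapply(seq_len(num_samples), function(i) {
    v <- sample(c(-1, 1), dim, replace = TRUE)  # Rademacher probe vector
    sum(v * jvp_fn(v))                          # v' J v
  }, numeric(1))
  mean(estimates)
}

set.seed(1)
J <- matrix(c(0.5, 0.3, -1.0, 0.2), nrow = 2, byrow = TRUE)
estimate <- hutchinson_trace(function(v) as.vector(J %*% v), dim = 2)
exact <- sum(diag(J))
```

The estimate is stochastic: it approaches the exact trace as num_samples grows, which is why a single probe per ODE step gives a cheap but noisy log_det_jacobian.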
Custom ode_solve_fn and trace_augmentation_fn examples:
# custom_solver_fn: `function(f, t_initial, t_solutions, y_initial, ...)`
# ... : Placeholder for additional arguments to pass to custom_solver_fn;
#       replace it with the actual arguments before running.
ode_solve_fn <- function(ode_fn, initial_time, initial_state, solution_times) {
  custom_solver_fn(ode_fn, initial_time, solution_times, initial_state, ...)
}
ffjord <- tfb_ffjord(state_time_derivative_fn, ode_solve_fn = ode_solve_fn)
# state_time_derivative_fn: `function(time, state)`
# trace_jac_fn: `function(time, state)` unreduced jacobian trace function
trace_augmentation_fn <- function(ode_fn, state_shape, state_dtype) {
  augmented_ode_fn <- function(time, state) {
    list(ode_fn(time, state), trace_jac_fn(time, state))
  }
  augmented_ode_fn
}
ffjord <- tfb_ffjord(state_time_derivative_fn, trace_augmentation_fn = trace_augmentation_fn)
For more details on FFJORD and continuous normalizing flows see Chen et al. (2018), Grathwohl et al. (2018).
Value
A bijector instance.
References
Chen, T. Q., Rubanova, Y., Bettencourt, J., & Duvenaud, D. K. (2018). Neural ordinary differential equations. In Advances in neural information processing systems (pp. 6571-6583)
See Also
For usage examples see tfb_forward(), tfb_inverse(), tfb_inverse_log_det_jacobian().
Other bijectors:
tfb_absolute_value(),
tfb_affine_linear_operator(),
tfb_affine_scalar(),
tfb_affine(),
tfb_ascending(),
tfb_batch_normalization(),
tfb_blockwise(),
tfb_chain(),
tfb_cholesky_outer_product(),
tfb_cholesky_to_inv_cholesky(),
tfb_correlation_cholesky(),
tfb_cumsum(),
tfb_discrete_cosine_transform(),
tfb_expm1(),
tfb_exp(),
tfb_fill_scale_tri_l(),
tfb_fill_triangular(),
tfb_glow(),
tfb_gompertz_cdf(),
tfb_gumbel_cdf(),
tfb_gumbel(),
tfb_identity(),
tfb_inline(),
tfb_invert(),
tfb_iterated_sigmoid_centered(),
tfb_kumaraswamy_cdf(),
tfb_kumaraswamy(),
tfb_lambert_w_tail(),
tfb_masked_autoregressive_default_template(),
tfb_masked_autoregressive_flow(),
tfb_masked_dense(),
tfb_matrix_inverse_tri_l(),
tfb_matvec_lu(),
tfb_normal_cdf(),
tfb_ordered(),
tfb_pad(),
tfb_permute(),
tfb_power_transform(),
tfb_rational_quadratic_spline(),
tfb_rayleigh_cdf(),
tfb_real_nvp_default_template(),
tfb_real_nvp(),
tfb_reciprocal(),
tfb_reshape(),
tfb_scale_matvec_diag(),
tfb_scale_matvec_linear_operator(),
tfb_scale_matvec_lu(),
tfb_scale_matvec_tri_l(),
tfb_scale_tri_l(),
tfb_scale(),
tfb_shifted_gompertz_cdf(),
tfb_shift(),
tfb_sigmoid(),
tfb_sinh_arcsinh(),
tfb_sinh(),
tfb_softmax_centered(),
tfb_softplus(),
tfb_softsign(),
tfb_split(),
tfb_square(),
tfb_tanh(),
tfb_transform_diagonal(),
tfb_transpose(),
tfb_weibull_cdf(),
tfb_weibull()