tfd_joint_distribution_named_auto_batched {tfprobability}    R Documentation
Joint distribution parameterized by named distribution-making functions.
Description
This class provides automatic vectorization and alternative semantics for tfd_joint_distribution_named(), which in many cases allows for simplifications in the model specification.
Usage
tfd_joint_distribution_named_auto_batched(
model,
batch_ndims = 0,
use_vectorized_map = TRUE,
validate_args = FALSE,
name = NULL
)
Arguments
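model: A named list whose elements are distributions or distribution-making functions. The argument names of each function refer to other elements of the list, whose values are passed in when the model is evaluated.
batch_ndims: Integer number of batch dimensions. The batch shapes of all component distributions must be such that their prefixes of length batch_ndims broadcast to a consistent joint batch shape. Default value: 0.
use_vectorized_map: Logical. Whether to use tf$vectorized_map to automatically vectorize evaluation of the model. This lets the model specification describe a single sample, which is often simpler, but some TensorFlow ops may not be supported. Default value: TRUE.
validate_args: Logical. When TRUE, distribution parameters are checked for validity despite possibly degrading runtime performance. When FALSE, invalid inputs may silently render incorrect outputs. Default value: FALSE.
name: Name prefixed to Ops created by this class.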
Details
Automatic vectorization
Auto-vectorized variants of JointDistribution allow the user to avoid
explicitly annotating a model's vectorization semantics.
When using manually-vectorized joint distributions, each operation in the model must account for the possibility of batch dimensions in Distributions and their samples. By contrast, auto-vectorized models need only describe a single sample from the joint distribution; any batch evaluation is automated using tf$vectorized_map as required. In many cases this allows for significant simplifications. For example, the following manually-vectorized tfd_joint_distribution_named() model:
    model <- tfd_joint_distribution_named(list(
      x = tfd_normal(loc = 0, scale = tf$ones(3L)),
      y = tfd_normal(loc = 0, scale = 1),
      z = function(y, x) tfd_normal(
        loc = x[reticulate::py_ellipsis(), 1:2] + y[reticulate::py_ellipsis(), tf$newaxis],
        scale = 1)
    ))
can be written in auto-vectorized form as
    model <- tfd_joint_distribution_named_auto_batched(list(
      x = tfd_normal(loc = 0, scale = tf$ones(3L)),
      y = tfd_normal(loc = 0, scale = 1),
      z = function(y, x) tfd_normal(loc = x[1:2] + y, scale = 1)
    ))
in which we were able to avoid explicitly accounting for batch dimensions when indexing and slicing computed quantities in the definition of z.
Note: auto-vectorization is still experimental and some TensorFlow ops may be unsupported. It can be disabled by setting use_vectorized_map = FALSE.
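The following is a minimal sketch of using the auto-vectorized model above, assuming the tensorflow and tfprobability R packages are installed with a working TensorFlow Probability backend. It draws one joint sample and a batch of five, then evaluates the log density of each. The object names s1, s5, lp1 and lp5 are illustrative only, and the shapes in the comments follow from the semantics described in this section rather than from captured output.

```r
# Sketch only: assumes the tensorflow and tfprobability R packages are
# installed with a working TensorFlow Probability backend.
library(tensorflow)
library(tfprobability)

# Auto-vectorized model: every component is written for a single joint draw.
model <- tfd_joint_distribution_named_auto_batched(list(
  x = tfd_normal(loc = 0, scale = tf$ones(3L)),
  y = tfd_normal(loc = 0, scale = 1),
  # Plain indexing: x[1:2] is the first two elements of a single draw of x
  z = function(y, x) tfd_normal(loc = x[1:2] + y, scale = 1)
))

# One joint draw: a named list with x of shape (3), y scalar, z of shape (2).
s1 <- tfd_sample(model)

# Five joint draws: tf$vectorized_map supplies the leading sample dimension,
# giving x of shape (5, 3), y of shape (5) and z of shape (5, 2).
s5 <- tfd_sample(model, 5L)

# The log density of one draw is a scalar; of five draws, one value per draw.
lp1 <- tfd_log_prob(model, s1)
lp5 <- tfd_log_prob(model, s5)
```

Note that the model function for z never mentions the sample dimension; that bookkeeping is what use_vectorized_map = TRUE takes over.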
Alternative batch semantics
This class also provides alternative semantics for specifying a batch of independent (non-identical) joint distributions.
Instead of simply summing the log_probs of component distributions (which may have different shapes), it first reduces the component log_probs to ensure that jd$log_prob(jd$sample()) always returns a scalar, unless batch_ndims is explicitly set to a nonzero value (in which case the result will have the corresponding tensor rank).
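To illustrate the reduction described above, here is a sketch that builds the same two-component model with batch_ndims = 0 and batch_ndims = 1 and evaluates both on one draw. It assumes the tensorflow and tfprobability R packages with a working TensorFlow Probability backend; make_model is a hypothetical helper introduced only for this example.

```r
# Sketch only: assumes the tensorflow and tfprobability R packages are
# installed with a working TensorFlow Probability backend.
library(tensorflow)
library(tfprobability)

# Hypothetical helper: the same two-component model, built with a given
# batch_ndims. Both components have three independent elements.
make_model <- function(batch_ndims) {
  tfd_joint_distribution_named_auto_batched(
    list(
      a = tfd_normal(loc = tf$zeros(3L), scale = 1),
      b = function(a) tfd_normal(loc = a, scale = 1)
    ),
    batch_ndims = batch_ndims
  )
}

jd0 <- make_model(0L)  # default: no global batch dimensions
jd1 <- make_model(1L)  # the leading dimension (size 3) is a batch dimension

# Evaluate both models on the same draw so the results are comparable.
s <- tfd_sample(jd1)

# batch_ndims = 0: component log-densities are fully reduced to a scalar.
lp0 <- tfd_log_prob(jd0, s)

# batch_ndims = 1: one log-density per batch member, i.e. shape (3).
lp1 <- tfd_log_prob(jd1, s)

# Summing the batched result should recover the scalar one (up to rounding).
abs(sum(as.array(lp1)) - as.numeric(lp0))
```

Both models describe the same density; batch_ndims only controls how far the per-component log-densities are reduced.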
The essential changes are:
- An event of JointDistributionNamedAutoBatched is the list of tensors produced by $sample(); thus, the event_shape is the list containing the shapes of sampled tensors. These combine both the event and batch dimensions of the component distributions. By contrast, the event shape of a base JointDistribution does not include batch dimensions of component distributions.
- The batch_shape is a global property of the entire model, rather than a per-component property as in base JointDistributions. The global batch shape must be a prefix of the batch shapes of each component; the length of this prefix is specified by the optional argument batch_ndims. If batch_ndims is not specified, the model has batch shape ().
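A short sketch of how these two changes show up when inspecting shapes, assuming the tensorflow and tfprobability R packages with a working TensorFlow Probability backend, and that event_shape and batch_shape are reached through reticulate's $ accessor. The list name components is illustrative; the shapes in the comments follow from the rules above rather than from captured output.

```r
# Sketch only: assumes the tensorflow and tfprobability R packages are
# installed with a working TensorFlow Probability backend.
library(tensorflow)
library(tfprobability)

# Illustrative components: each has three independent elements.
components <- list(
  a = tfd_normal(loc = tf$zeros(3L), scale = 1),
  b = function(a) tfd_normal(loc = a, scale = 1)
)

# Default batch_ndims = 0: each component's batch dimension of size 3 is
# folded into the event, so event_shape is list(a = (3), b = (3)) and the
# global batch_shape is ().
jd0 <- tfd_joint_distribution_named_auto_batched(components)
jd0$event_shape
jd0$batch_shape

# batch_ndims = 1: the shared leading dimension becomes the global batch
# shape (3), and each component's event shape is ().
jd1 <- tfd_joint_distribution_named_auto_batched(components, batch_ndims = 1L)
jd1$event_shape
jd1$batch_shape
```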
Value
a distribution instance.
See Also
For usage examples see e.g. tfd_sample(), tfd_log_prob(), tfd_mean().
Other distributions: tfd_autoregressive(), tfd_batch_reshape(), tfd_bates(), tfd_bernoulli(), tfd_beta_binomial(), tfd_beta(), tfd_binomial(), tfd_categorical(), tfd_cauchy(), tfd_chi2(), tfd_chi(), tfd_cholesky_lkj(), tfd_continuous_bernoulli(), tfd_deterministic(), tfd_dirichlet_multinomial(), tfd_dirichlet(), tfd_empirical(), tfd_exp_gamma(), tfd_exp_inverse_gamma(), tfd_exponential(), tfd_gamma_gamma(), tfd_gamma(), tfd_gaussian_process_regression_model(), tfd_gaussian_process(), tfd_generalized_normal(), tfd_geometric(), tfd_gumbel(), tfd_half_cauchy(), tfd_half_normal(), tfd_hidden_markov_model(), tfd_horseshoe(), tfd_independent(), tfd_inverse_gamma(), tfd_inverse_gaussian(), tfd_johnson_s_u(), tfd_joint_distribution_named(), tfd_joint_distribution_sequential_auto_batched(), tfd_joint_distribution_sequential(), tfd_kumaraswamy(), tfd_laplace(), tfd_linear_gaussian_state_space_model(), tfd_lkj(), tfd_log_logistic(), tfd_log_normal(), tfd_logistic(), tfd_mixture_same_family(), tfd_mixture(), tfd_multinomial(), tfd_multivariate_normal_diag_plus_low_rank(), tfd_multivariate_normal_diag(), tfd_multivariate_normal_full_covariance(), tfd_multivariate_normal_linear_operator(), tfd_multivariate_normal_tri_l(), tfd_multivariate_student_t_linear_operator(), tfd_negative_binomial(), tfd_normal(), tfd_one_hot_categorical(), tfd_pareto(), tfd_pixel_cnn(), tfd_poisson_log_normal_quadrature_compound(), tfd_poisson(), tfd_power_spherical(), tfd_probit_bernoulli(), tfd_quantized(), tfd_relaxed_bernoulli(), tfd_relaxed_one_hot_categorical(), tfd_sample_distribution(), tfd_sinh_arcsinh(), tfd_skellam(), tfd_spherical_uniform(), tfd_student_t_process(), tfd_student_t(), tfd_transformed_distribution(), tfd_triangular(), tfd_truncated_cauchy(), tfd_truncated_normal(), tfd_uniform(), tfd_variational_gaussian_process(), tfd_vector_diffeomixture(), tfd_vector_exponential_diag(), tfd_vector_exponential_linear_operator(), tfd_vector_laplace_diag(), tfd_vector_laplace_linear_operator(), tfd_vector_sinh_arcsinh_diag(), tfd_von_mises_fisher(), tfd_von_mises(), tfd_weibull(), tfd_wishart_linear_operator(), tfd_wishart_tri_l(), tfd_wishart(), tfd_zipf()