random_categorical {keras3} | R Documentation
Draws samples from a categorical distribution.
Description
This function takes as input logits, a 2-D input tensor with shape
(batch_size, num_classes). Each row of the input represents a categorical
distribution, and each column holds the unnormalized log-probability (logit)
of one class.
The function outputs a 2-D tensor with shape (batch_size, num_samples),
where each row contains samples drawn from the corresponding row of logits.
Each entry in a row is an independent sample from that row's distribution.
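The sampling behavior described above can be sketched in base R, which is handy for checking shapes and frequencies without a Keras backend. This is a minimal illustration, not how keras3 implements random_categorical(). It assumes each row is turned into probabilities with a softmax. The helper name sample_categorical() is hypothetical. Its seed argument only calls set.seed() on R's own generator, which is unlike keras3's seed. It also returns 0-based class indices, an assumed convention that you should check against real keras3 output before relying on it.

```r
# Hypothetical base-R sketch of categorical sampling from logits.
# It is NOT the keras3 implementation of random_categorical().
sample_categorical <- function(logits, num_samples, seed = NULL) {
  # Seeds R's global RNG (not a Keras seed); used only for reproducibility.
  if (!is.null(seed)) set.seed(seed)
  logits <- as.matrix(logits)
  num_classes <- ncol(logits)
  out <- matrix(0L, nrow = nrow(logits), ncol = num_samples)
  for (i in seq_len(nrow(logits))) {
    # Row-wise softmax; subtracting the row max avoids overflow in exp().
    p <- exp(logits[i, ] - max(logits[i, ]))
    p <- p / sum(p)
    # num_samples independent draws (with replacement) from this row.
    # Subtracting 1L gives 0-based class indices (assumed convention).
    out[i, ] <- sample.int(num_classes, num_samples, replace = TRUE, prob = p) - 1L
  }
  out
}

# Row 1: probabilities 0.7 / 0.2 / 0.1; row 2: the last class dominates.
logits <- rbind(log(c(0.7, 0.2, 0.1)), c(0, 0, 5))
draws <- sample_categorical(logits, num_samples = 1000, seed = 42)
dim(draws)                # 2 1000, i.e. (batch_size, num_samples)
table(draws[1, ]) / 1000  # close to 0.7, 0.2, 0.1
```

Each row of draws is sampled only from its own row of logits. Within a row, the 1000 columns are independent draws, which is why the empirical frequencies approach the softmax probabilities.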
Usage
random_categorical(logits, num_samples, dtype = "int32", seed = NULL)
Arguments
logits |
2-D tensor with shape (batch_size, num_classes). Each row should define a categorical distribution with the unnormalized log-probabilities for all classes. |
num_samples |
Int, the number of independent samples to draw for each row of the input. This will be the second dimension of the output tensor's shape. |
dtype |
Optional dtype of the output tensor. Defaults to "int32". |
seed |
An R integer or instance of |
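The logits only need to be unnormalized log-probabilities. Adding a constant to every entry of a row therefore leaves that row's distribution unchanged. The following is a minimal base-R sketch of that normalization, assuming the standard row-wise softmax interpretation. logits_to_probs() is a hypothetical helper for illustration, not part of keras3.

```r
# Hypothetical helper: row-wise softmax turning unnormalized
# log-probabilities (one distribution per row) into probabilities.
logits_to_probs <- function(logits) {
  logits <- as.matrix(logits)
  # apply() returns one max per row; that vector recycles down the columns,
  # so element [i, j] has row i's max subtracted.
  shifted <- logits - apply(logits, 1, max)
  p <- exp(shifted)
  p / rowSums(p)  # rowSums() recycles the same way, normalizing each row
}

logits <- rbind(c(1, 2, 3), c(0, 0, 0))
probs <- logits_to_probs(logits)
rowSums(probs)                                  # 1 1
all.equal(logits_to_probs(logits + 10), probs)  # TRUE
```

The second row, with all logits equal, gives a uniform distribution over the three classes.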
Value
A 2-D tensor with shape (batch_size, num_samples).
See Also
Other random:
random_beta()
random_binomial()
random_dropout()
random_gamma()
random_integer()
random_normal()
random_seed_generator()
random_shuffle()
random_truncated_normal()
random_uniform()