attention_bahdanau {tfaddons}    R Documentation

Bahdanau Attention

Description

Implements Bahdanau-style (additive) attention.

Usage

attention_bahdanau(
  object,
  units,
  memory = NULL,
  memory_sequence_length = NULL,
  normalize = FALSE,
  probability_fn = "softmax",
  kernel_initializer = "glorot_uniform",
  dtype = NULL,
  name = "BahdanauAttention",
  ...
)

Arguments

object

Model or layer object

units

The depth of the query mechanism.

memory

The memory to query; usually the output of an RNN encoder. This tensor should be shaped [batch_size, max_time, ...].

memory_sequence_length

(optional): Sequence lengths for the batch entries in memory. If provided, the memory tensor rows are masked with zeros for values past the respective sequence lengths.

normalize

Boolean. Whether to normalize the energy term.

probability_fn

(optional) String, the name of the function used to convert the attention score to probabilities. The default is "softmax", which maps to tf.nn.softmax. The only other option is "hardmax", implemented as hardmax() within this module; any other value raises a validation error. A short illustration of the two options follows the argument list.

kernel_initializer

(optional) The name of the initializer for the attention kernel.

dtype

The data type for the query and memory layers of the attention mechanism.

name

Name to use when creating ops.

...

Additional arguments commonly used when creating a layer.
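
To illustrate the difference between the two probability_fn options, here is a minimal sketch in plain R, independent of the package; softmax_fn and hardmax_fn are hypothetical helpers that mirror the semantics of tf.nn.softmax and this module's hardmax():

# softmax: dense probability distribution over all positions
softmax_fn <- function(x) exp(x) / sum(exp(x))
# hardmax: one-hot vector at the position of the maximum score
hardmax_fn <- function(x) as.numeric(seq_along(x) == which.max(x))

scores <- c(1.2, 0.3, 2.5)
softmax_fn(scores)  # ~ 0.20 0.08 0.72  (every position gets some weight)
hardmax_fn(scores)  # 0 0 1             (all weight on the argmax)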

Details

This attention has two forms. The first is standard Bahdanau attention, as described in: Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio. "Neural Machine Translation by Jointly Learning to Align and Translate." ICLR 2015. https://arxiv.org/abs/1409.0473

The second is the normalized form, inspired by the weight-normalization paper: Tim Salimans, Diederik P. Kingma. "Weight Normalization: A Simple Reparameterization to Accelerate Training of Deep Neural Networks." https://arxiv.org/abs/1602.07868

To enable the second form, construct the object with normalize = TRUE.
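
For orientation, the two score (energy) functions can be sketched as follows; the notation follows the cited papers rather than this package's internals, with q the query, m_i a memory row, and W_q, W_m, v, g, b learned parameters:

\[ \mathrm{score}(q, m_i) = v^\top \tanh(W_q\, q + W_m\, m_i) \]

\[ \mathrm{score}_{\mathrm{norm}}(q, m_i) = \frac{g}{\lVert v \rVert}\, v^\top \tanh\big(W_q\, q + W_m\, m_i + b\big) \]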

Value

The created attention mechanism object.
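
Examples

A minimal construction sketch; shapes and values are illustrative, and it assumes the keras and tfaddons packages are installed with a TensorFlow backend available. Passing object = NULL is assumed here to return the mechanism itself rather than composing it with a model.

## Not run:
library(keras)
library(tfaddons)

# Illustrative encoder output: [batch_size, max_time, depth]
batch_size <- 4L; max_time <- 10L; depth <- 16L
memory <- k_random_uniform(c(batch_size, max_time, depth))

# Build the weight-normalized Bahdanau attention mechanism.
attn <- attention_bahdanau(
  object = NULL,                  # assumed to return the mechanism itself
  units = 32,                     # depth of the query mechanism
  memory = memory,                # encoder outputs to attend over
  memory_sequence_length = rep(max_time, batch_size),
  normalize = TRUE                # enable the second (normalized) form
)
## End(Not run)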


[Package tfaddons version 0.10.0 Index]