nn_lstm {torch}    R Documentation

Applies a multi-layer long short-term memory (LSTM) RNN to an input sequence.

Description

For each element in the input sequence, each layer computes the LSTM update given under Details.

Usage

nn_lstm(
  input_size,
  hidden_size,
  num_layers = 1,
  bias = TRUE,
  batch_first = FALSE,
  dropout = 0,
  bidirectional = FALSE,
  ...
)

Arguments

input_size

The number of expected features in the input x

hidden_size

The number of features in the hidden state h

num_layers

Number of recurrent layers. E.g., setting num_layers=2 would mean stacking two LSTMs together to form a stacked LSTM, with the second LSTM taking in outputs of the first LSTM and computing the final results. Default: 1

bias

If FALSE, then the layer does not use bias weights b_ih and b_hh. Default: TRUE

batch_first

If TRUE, then the input and output tensors are provided as (batch, seq, feature) instead of (seq, batch, feature). Default: FALSE

dropout

If non-zero, introduces a Dropout layer on the outputs of each LSTM layer except the last layer, with dropout probability equal to dropout. Default: 0

bidirectional

If TRUE, becomes a bidirectional LSTM. Default: FALSE

...

Currently unused.
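As a rough illustration of how batch_first, num_layers and bidirectional affect tensor shapes, the following sketch builds a small module and inspects what it returns. The sizes (3, 5, 10, 20) are arbitrary values chosen for this example, and the shapes in the comments assume the usual layout in which the final states are not affected by batch_first.

```r
library(torch)

if (torch_is_installed()) {
  # Two stacked bidirectional layers with batch-first input (illustrative sizes).
  rnn <- nn_lstm(
    input_size = 10, hidden_size = 20, num_layers = 2,
    batch_first = TRUE, bidirectional = TRUE
  )
  x <- torch_randn(3, 5, 10)  # (batch = 3, seq = 5, feature = 10)
  out <- rnn(x)

  out[[1]]$shape       # (batch, seq, 2 * hidden_size) = 3 5 40
  out[[2]][[1]]$shape  # h_n: (num_layers * 2, batch, hidden_size) = 4 3 20
}
```

Note that only the input and output tensors follow the batch-first layout here; h_n and c_n keep the layer dimension first.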

Details

\begin{array}{ll}
i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi}) \\
f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{(t-1)} + b_{hf}) \\
g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{(t-1)} + b_{hg}) \\
o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{(t-1)} + b_{ho}) \\
c_t = f_t \odot c_{(t-1)} + i_t \odot g_t \\
h_t = o_t \odot \tanh(c_t)
\end{array}

where h_t is the hidden state at time t, c_t is the cell state at time t, x_t is the input at time t, h_{(t-1)} is the hidden state of the layer at time t-1 or the initial hidden state at time 0, and i_t, f_t, g_t, o_t are the input, forget, cell, and output gates, respectively. \sigma is the sigmoid function, and \odot is the elementwise (Hadamard) product.
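To make the update concrete, here is a minimal base R sketch of a single time step for one layer and one sequence. It is not the package's implementation: the helper names sigmoid and lstm_step are hypothetical, and stacking the four gates as row blocks in the order i, f, g, o is an assumption of this sketch.

```r
# Logistic sigmoid used by the input, forget, and output gates.
sigmoid <- function(z) 1 / (1 + exp(-z))

# One LSTM time step (hypothetical helper, written for illustration only).
# W_ih: (4 * H) x input_size, W_hh: (4 * H) x H, b_ih and b_hh: length 4 * H,
# with the gates stacked as row blocks in the order i, f, g, o.
lstm_step <- function(x_t, h_prev, c_prev, W_ih, W_hh, b_ih, b_hh) {
  H <- length(h_prev)
  z <- W_ih %*% x_t + b_ih + W_hh %*% h_prev + b_hh
  i_t <- sigmoid(z[1:H])
  f_t <- sigmoid(z[(H + 1):(2 * H)])
  g_t <- tanh(z[(2 * H + 1):(3 * H)])
  o_t <- sigmoid(z[(3 * H + 1):(4 * H)])
  c_t <- f_t * c_prev + i_t * g_t  # elementwise products
  h_t <- o_t * tanh(c_t)
  list(h = h_t, c = c_t)
}

set.seed(1)
input_size <- 3
hidden_size <- 2
W_ih <- matrix(rnorm(4 * hidden_size * input_size), 4 * hidden_size, input_size)
W_hh <- matrix(rnorm(4 * hidden_size * hidden_size), 4 * hidden_size, hidden_size)
b_ih <- rnorm(4 * hidden_size)
b_hh <- rnorm(4 * hidden_size)

# Start from a zero state, as the module does when no state is supplied.
state <- lstm_step(
  rnorm(input_size), rep(0, hidden_size), rep(0, hidden_size),
  W_ih, W_hh, b_ih, b_hh
)
```

Feeding state$h and state$c back in as h_prev and c_prev for the next input would unroll the recurrence over a sequence.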

Inputs

Inputs: input, (h_0, c_0)

If (h_0, c_0) is not provided, both h_0 and c_0 default to zero.

Outputs

Outputs: output, (h_n, c_n)
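The sketch below shows one way to call the module with and without an initial state and to unpack the result. It assumes the module returns a list of the form list(output, list(h_n, c_n)); the variable names and sizes are chosen for this example only.

```r
library(torch)

if (torch_is_installed()) {
  rnn <- nn_lstm(input_size = 10, hidden_size = 20, num_layers = 2)
  x <- torch_randn(5, 3, 10)  # (seq, batch, feature), since batch_first = FALSE

  # Omitting the state should match passing explicit zero states.
  res_default <- rnn(x)
  zeros <- torch_zeros(2, 3, 20)  # (num_layers, batch, hidden_size)
  res_zeros <- rnn(x, list(zeros, zeros))
  same_output <- torch_allclose(res_default[[1]], res_zeros[[1]])

  # Unpack: per-step output, then the final hidden and cell states.
  output <- res_default[[1]]
  h_n <- res_default[[2]][[1]]
  c_n <- res_default[[2]][[2]]

  # For a unidirectional LSTM, the last time step of `output` is the
  # final hidden state of the top layer.
  last_step_matches <- torch_allclose(output[5, , ], h_n[2, , ])
}
```

With dropout left at 0 the forward pass is deterministic, so both comparisons are expected to hold.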

Attributes

Note

All weights and biases are initialized from \mathcal{U}(-\sqrt{k}, \sqrt{k}), where k = \frac{1}{\mbox{hidden\_size}}.
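One way to check this bound on a freshly created module is sketched below. The small tolerance of 1e-6 is an assumption of this example, added to absorb single-precision rounding.

```r
library(torch)

if (torch_is_installed()) {
  hidden_size <- 20
  rnn <- nn_lstm(input_size = 10, hidden_size = hidden_size)
  bound <- sqrt(1 / hidden_size)  # sqrt(k) with k = 1 / hidden_size

  # Largest absolute value of each parameter, compared against the bound.
  within_bound <- sapply(rnn$parameters, function(p) {
    p$detach()$abs()$max()$item() <= bound + 1e-6
  })
  within_bound
}
```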

Examples

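if (torch_is_installed()) {
  # 2-layer LSTM with 10 input features and 20 hidden features
  rnn <- nn_lstm(10, 20, 2)
  input <- torch_randn(5, 3, 10) # (seq_len, batch, input_size)
  h0 <- torch_randn(2, 3, 20)    # (num_layers, batch, hidden_size)
  c0 <- torch_randn(2, 3, 20)    # (num_layers, batch, hidden_size)
  # The result holds the output sequence and the final (h_n, c_n) states
  output <- rnn(input, list(h0, c0))
}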

[Package torch version 0.13.0 Index]