Layer {keras3}    R Documentation

Define a custom Layer class.

Description

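A layer is a callable object that takes one or more tensors as input and outputs one or more tensors. It involves computation, defined in the call() method, and state (weight variables). State can be created in initialize(), for instance via self$add_weight(), or in the optional build() method, which is invoked by the first call() to the layer and receives the shape(s) of the input(s), which may not be known at initialization time.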

Layers are recursively composable: If you assign a Layer instance as an attribute of another Layer, the outer layer will start tracking the weights created by the inner layer. Nested layers should be instantiated in the initialize() method or build() method.
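
As a minimal sketch of this composition, the layer below creates two layer_dense() instances in initialize() and assigns them as attributes. The names MLPBlock, layer_mlp_block, hidden_units and output_units are hypothetical and chosen only for illustration.

```r
library(keras3)

# Hypothetical outer layer that owns two inner dense layers.
layer_mlp_block <- Layer(
  "MLPBlock",
  initialize = function(hidden_units = 8, output_units = 1) {
    super$initialize()
    # Inner layers assigned as attributes are tracked by the outer layer.
    self$hidden <- layer_dense(units = hidden_units, activation = "relu")
    self$out <- layer_dense(units = output_units)
  },
  call = function(inputs) {
    self$out(self$hidden(inputs))
  }
)

# All arguments are passed by name, so this returns a layer instance.
block <- layer_mlp_block(hidden_units = 8, output_units = 1)

# The first call builds the inner layers and creates their weights.
y <- block(op_ones(c(2, 4)))

# Kernel and bias of each inner dense layer are visible on the outer layer.
n_weights <- length(block$weights)
```

After the first call, block$weights should list the weights of both inner layers, even though the outer layer never calls add_weight() itself.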

Users will just instantiate a layer and then treat it as a callable.

Usage

Layer(
  classname,
  initialize = NULL,
  call = NULL,
  build = NULL,
  get_config = NULL,
  ...,
  public = list(),
  private = list(),
  inherit = NULL,
  parent_env = parent.frame()
)

Arguments

classname

String, the name of the custom class. (Conventionally, CamelCase).

initialize, call, build, get_config

Recommended methods to implement. See description and details sections.

..., public

Additional methods or public members of the custom class.

private

Named list of R objects (typically, functions) to include in instance private environments. Private methods will have all the same symbols in scope as public methods (see section "Symbols in scope"). Each instance will have its own private environment. Any objects in private will be invisible to the Keras framework and the Python runtime.

inherit

What the custom class will subclass. By default, the base Keras Layer class.

parent_env

The R environment that all class methods will have as a grandparent.

Value

A composing layer constructor, with similar behavior to other layer functions like layer_dense(). The first argument of the returned function is object, which makes it possible to initialize() and call() the layer in one step while composing layers with the pipe, like:

layer_foo <- Layer("Foo", ....)
output <- inputs |> layer_foo()

To only initialize() a layer instance and not call() it, pass a missing or NULL value to object, or pass all arguments to initialize() by name.

layer <- layer_dense(units = 2, activation = "relu")
layer <- layer_dense(NULL, 2, activation = "relu")
layer <- layer_dense(, 2, activation = "relu")

# then you can call() the layer in a separate step
outputs <- inputs |> layer()
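
The two modes described above can be sketched with a small custom layer. This is an illustration under assumptions: ScaleBy, layer_scale_by and factor are hypothetical names, and the model is only there to show the constructor composing with the pipe.

```r
library(keras3)

# Hypothetical weightless layer that multiplies its input by a constant.
layer_scale_by <- Layer(
  "ScaleBy",
  initialize = function(factor = 2) {
    super$initialize()
    self$factor <- factor
  },
  call = function(inputs) {
    op_multiply(inputs, self$factor)
  }
)

# Mode 1: `object` is supplied by the pipe, so the layer is
# initialized and called in one step.
inputs <- keras_input(shape = 4)
outputs <- inputs |>
  layer_scale_by(factor = 0.5) |>
  layer_dense(units = 1)
model <- keras_model(inputs, outputs)

# Mode 2: all arguments are named, so only initialize() runs
# and a layer instance is returned.
scaler <- layer_scale_by(factor = 0.5)
scaled <- scaler(op_ones(c(2, 4)))
```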

Symbols in scope

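All R function custom methods (public and private) will have a set of symbols in scope, including self (the custom class instance), super (the custom class superclass), and private (an R environment specific to the class instance).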
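
A short sketch of how these symbols might be used together. It assumes, as described for the private argument, that private methods see the same symbols as public ones. AddOffset, layer_add_offset, offset and shift are hypothetical names.

```r
library(keras3)

layer_add_offset <- Layer(
  "AddOffset",
  initialize = function(offset = 1) {
    super$initialize()      # `super`: the parent class
    self$offset <- offset   # `self`: the layer instance
  },
  call = function(inputs) {
    private$shift(inputs)   # `private`: per-instance private environment
  },
  private = list(
    # A private helper; `self` is assumed to be in scope here as well.
    shift = function(x) op_add(x, self$offset)
  )
)

layer <- layer_add_offset(offset = 3)
y <- layer(op_ones(c(2, 2)))
```

The helper shift() lives only in the instance's private environment, so it is not exposed as an attribute of the Keras layer object.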

Attributes

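We recommend that custom Layers implement the following methods: initialize(), which defines layer attributes and creates any state that does not depend on input shapes; build(input_shape), which creates weights that depend on the shape(s) of the input(s); call(), which defines the computation; and get_config(), which returns a named list with the configuration used to initialize the layer.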
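
A minimal sketch of a layer that implements initialize(), build(), call() and get_config() together. The names Linear and layer_linear are hypothetical, and merging into super$get_config() with modifyList() is one common pattern, not the only way to build the config.

```r
library(keras3)

layer_linear <- Layer(
  "Linear",
  initialize = function(units = 32, ...) {
    # Forward standard layer arguments (e.g. `name`) to the base class.
    super$initialize(...)
    self$units <- units
  },
  build = function(input_shape) {
    # The kernel shape depends on the last input dimension.
    self$kernel <- self$add_weight(
      shape = shape(tail(input_shape, 1), self$units),
      initializer = "glorot_uniform",
      trainable = TRUE,
      name = "kernel"
    )
  },
  call = function(inputs) {
    op_matmul(inputs, self$kernel)
  },
  get_config = function() {
    # Start from the base config and add the custom argument.
    config <- super$get_config()
    modifyList(config, list(units = self$units))
  }
)

layer <- layer_linear(units = 16)
config <- layer$get_config()
```

The returned config should contain units alongside the entries supplied by the base class, which is what allows the layer to be re-created from its configuration.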

Examples

Here's a basic example: a layer with two variables, w and b, that returns y <- (w %*% x) + b. It shows how to implement build() and call(). Variables set as attributes of a layer are tracked as weights of the layers (in layer$weights).

layer_simple_dense <- Layer(
  "SimpleDense",
  initialize = function(units = 32) {
    super$initialize()
    self$units <- units
  },

  # Create the state of the layer (weights)
  build = function(input_shape) {
    self$kernel <- self$add_weight(
      shape = shape(tail(input_shape, 1), self$units),
      initializer = "glorot_uniform",
      trainable = TRUE,
      name = "kernel"
    )
    self$bias <- self$add_weight(
      shape = shape(self$units),
      initializer = "zeros",
      trainable = TRUE,
      name = "bias"
    )
  },

  # Defines the computation
  call = function(inputs) {
    op_matmul(inputs, self$kernel) + self$bias
  }
)

# Instantiates the layer.
# Supply missing `object` arg to skip invoking `call()` and instead return
# the Layer instance
linear_layer <- layer_simple_dense(, 4)

# This will call `build(input_shape)` and create the weights,
# and then invoke `call()`.
y <- linear_layer(op_ones(c(2, 2)))
stopifnot(length(linear_layer$weights) == 2)

# These weights are trainable, so they're listed in `trainable_weights`:
stopifnot(length(linear_layer$trainable_weights) == 2)

Besides trainable weights, which are updated via backpropagation during training, layers can also have non-trainable weights. These weights are meant to be updated manually during call(). Here's an example layer that computes the running sum of its inputs:

layer_compute_sum <- Layer(
  classname = "ComputeSum",

  initialize = function(input_dim) {
    super$initialize()

    # Create a non-trainable weight.
    self$total <- self$add_weight(
      shape = shape(),
      initializer = "zeros",
      trainable = FALSE,
      name = "total"
    )
  },

  call = function(inputs) {
    self$total$assign(self$total + op_sum(inputs))
    self$total
  }
)

my_sum <- layer_compute_sum(, 2)
x <- op_ones(c(2, 2))
y <- my_sum(x)

stopifnot(exprs = {
  all.equal(my_sum$weights,               list(my_sum$total))
  all.equal(my_sum$non_trainable_weights, list(my_sum$total))
  all.equal(my_sum$trainable_weights,     list())
})
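
To see the manual update accumulate, a layer of this kind can be called more than once. The sketch below is self-contained and uses the hypothetical names RunningSum and layer_running_sum. Reading the value back with op_convert_to_numpy() is an assumption about a convenient way to inspect the variable in eager mode.

```r
library(keras3)

layer_running_sum <- Layer(
  "RunningSum",
  initialize = function() {
    super$initialize()
    # Scalar non-trainable weight holding the running total.
    self$total <- self$add_weight(
      shape = shape(),
      initializer = "zeros",
      trainable = FALSE,
      name = "total"
    )
  },
  call = function(inputs) {
    self$total$assign(self$total + op_sum(inputs))
    self$total
  }
)

running_sum <- layer_running_sum()
x <- op_ones(c(2, 2))

# Each call adds sum(x) to the stored total.
running_sum(x)
running_sum(x)

total <- as.numeric(op_convert_to_numpy(running_sum$total))
```

Because the total is stored in a weight rather than a local variable, it persists across calls and is saved with the layer's other weights.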

Methods available

Readonly properties:

Data descriptors (Attributes):

See Also

Other layers:
layer_activation()
layer_activation_elu()
layer_activation_leaky_relu()
layer_activation_parametric_relu()
layer_activation_relu()
layer_activation_softmax()
layer_activity_regularization()
layer_add()
layer_additive_attention()
layer_alpha_dropout()
layer_attention()
layer_average()
layer_average_pooling_1d()
layer_average_pooling_2d()
layer_average_pooling_3d()
layer_batch_normalization()
layer_bidirectional()
layer_category_encoding()
layer_center_crop()
layer_concatenate()
layer_conv_1d()
layer_conv_1d_transpose()
layer_conv_2d()
layer_conv_2d_transpose()
layer_conv_3d()
layer_conv_3d_transpose()
layer_conv_lstm_1d()
layer_conv_lstm_2d()
layer_conv_lstm_3d()
layer_cropping_1d()
layer_cropping_2d()
layer_cropping_3d()
layer_dense()
layer_depthwise_conv_1d()
layer_depthwise_conv_2d()
layer_discretization()
layer_dot()
layer_dropout()
layer_einsum_dense()
layer_embedding()
layer_feature_space()
layer_flatten()
layer_flax_module_wrapper()
layer_gaussian_dropout()
layer_gaussian_noise()
layer_global_average_pooling_1d()
layer_global_average_pooling_2d()
layer_global_average_pooling_3d()
layer_global_max_pooling_1d()
layer_global_max_pooling_2d()
layer_global_max_pooling_3d()
layer_group_normalization()
layer_group_query_attention()
layer_gru()
layer_hashed_crossing()
layer_hashing()
layer_identity()
layer_integer_lookup()
layer_jax_model_wrapper()
layer_lambda()
layer_layer_normalization()
layer_lstm()
layer_masking()
layer_max_pooling_1d()
layer_max_pooling_2d()
layer_max_pooling_3d()
layer_maximum()
layer_mel_spectrogram()
layer_minimum()
layer_multi_head_attention()
layer_multiply()
layer_normalization()
layer_permute()
layer_random_brightness()
layer_random_contrast()
layer_random_crop()
layer_random_flip()
layer_random_rotation()
layer_random_translation()
layer_random_zoom()
layer_repeat_vector()
layer_rescaling()
layer_reshape()
layer_resizing()
layer_rnn()
layer_separable_conv_1d()
layer_separable_conv_2d()
layer_simple_rnn()
layer_spatial_dropout_1d()
layer_spatial_dropout_2d()
layer_spatial_dropout_3d()
layer_spectral_normalization()
layer_string_lookup()
layer_subtract()
layer_text_vectorization()
layer_tfsm()
layer_time_distributed()
layer_torch_module_wrapper()
layer_unit_normalization()
layer_upsampling_1d()
layer_upsampling_2d()
layer_upsampling_3d()
layer_zero_padding_1d()
layer_zero_padding_2d()
layer_zero_padding_3d()
rnn_cell_gru()
rnn_cell_lstm()
rnn_cell_simple()
rnn_cells_stack()


[Package keras3 version 0.2.0 Index]