layer_jax_model_wrapper {keras3}                                R Documentation

Keras Layer that wraps a JAX model.

Description

This layer enables the use of JAX components within Keras when using JAX as the backend for Keras.

Usage

layer_jax_model_wrapper(
  object,
  call_fn,
  init_fn = NULL,
  params = NULL,
  state = NULL,
  seed = NULL,
  ...
)

Arguments

object

Object to compose the layer with. A tensor, array, or sequential model.

call_fn

The function to call the model. See description above for the list of arguments it takes and the outputs it returns.

init_fn

The function to call to initialize the model. See description above for the list of arguments it takes and the outputs it returns. If NULL, then params and/or state must be provided.

params

A PyTree containing all the model trainable parameters. This allows passing trained parameters or controlling the initialization. If both params and state are NULL, init_fn() is called at build time to initialize the trainable parameters of the model.

state

A PyTree containing all the model non-trainable state. This allows passing learned state or controlling the initialization. If both params and state are NULL, and call_fn() takes a state argument, then init_fn() is called at build time to initialize the non-trainable state of the model.

seed

Seed for random number generator. Optional.

...

For forward/backward compatibility.

Value

The return value depends on the value provided for the first argument. If object is:

* a keras_model_sequential(), then the layer is added to the sequential model (which is modified in place). To enable piping, the sequential model is also returned, invisibly.

* a keras_input(), then the output tensor from calling layer(input) is returned.

* NULL or missing, then a Layer instance is returned.

Model function

This layer accepts JAX models in the form of a function, call_fn(), which must take the following arguments with these exact names:

* params: trainable parameters of the model.

* state (optional): non-trainable state of the model. Can be omitted if the model has no non-trainable state.

* rng (optional): a jax.random.PRNGKey instance. Can be omitted if the model does not need RNGs, neither during training nor during inference.

* inputs: inputs to the model, a JAX array or a PyTree of arrays.

* training (optional): an argument specifying if we're in training mode or inference mode, TRUE is passed in training mode. Can be omitted if the model behaves the same in training mode and inference mode.

The inputs argument is mandatory. Inputs to the model must be provided via a single argument. If the JAX model takes multiple inputs as separate arguments, they must be combined into a single structure, for instance in a tuple() or a dict().
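For instance, a minimal sketch of packing two inputs into one PyTree (my_two_input_fn, image, and mask are hypothetical names):

call_fn <- function(params, inputs) {
  # inputs is a single named list (a PyTree); unpack it here
  my_two_input_fn(params, inputs$image, inputs$mask)
}
# The layer would then be called with inputs like:
# list(image = image_array, mask = mask_array)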

Model weights initialization

The initialization of the params and state of the model can be handled by this layer, in which case the init_fn() argument must be provided. This allows the model to be initialized dynamically with the right shape. Alternatively, and if the shape is known, the params argument and optionally the state argument can be used to create an already initialized model.
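For example, the two routes look like this (a sketch; my_call_fn, my_init_fn, and trained_params are placeholders):

# Let the layer initialize params (and state) at build time via init_fn():
layer <- layer_jax_model_wrapper(call_fn = my_call_fn, init_fn = my_init_fn)

# Or pass already-initialized (e.g. trained) parameters directly:
layer <- layer_jax_model_wrapper(call_fn = my_call_fn, params = trained_params)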

The init_fn() function, if provided, must take the following arguments with these exact names:

* rng: a jax.random.PRNGKey instance.

* inputs: a JAX array or a PyTree of arrays with placeholder values to provide the shape of the inputs.

and must return:

* params: trainable parameters of the model.

* state (optional): non-trainable state of the model.

Models with non-trainable state

For JAX models that have non-trainable state:

* call_fn() must have a state argument.

* call_fn() must return a tuple() containing the outputs of the model and the new non-trainable state of the model.

* init_fn() must return a tuple() containing the initial trainable params of the model and the initial non-trainable state of the model.

This code shows a possible combination of call_fn() and init_fn() signatures for a model with non-trainable state. In this example, the model has a training argument and an rng argument in call_fn().

stateful_call <- function(params, state, rng, inputs, training) {
  outputs <- ....
  new_state <- ....
  tuple(outputs, new_state)
}

stateful_init <- function(rng, inputs) {
  initial_params <- ....
  initial_state <- ....
  tuple(initial_params, initial_state)
}
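Such a pair could then be wrapped as follows (a sketch using the functions above):

keras_layer <- layer_jax_model_wrapper(call_fn = stateful_call,
                                       init_fn = stateful_init)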

Models without non-trainable state

For JAX models with no non-trainable state:

* call_fn() must not have a state argument.

* call_fn() must return only the outputs of the model.

* init_fn() must return only the initial trainable params of the model.

This code shows a possible combination of call_fn() and init_fn() signatures for a model without non-trainable state. In this example, the model does not have a training argument and does not have an rng argument in call_fn().

stateless_call <- function(params, inputs) {
  outputs <- ....
  outputs
}

stateless_init <- function(rng, inputs) {
  initial_params <- ....
  initial_params
}
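Here too the pair can be wrapped directly, or, if trained parameters are already at hand (trained_params is hypothetical), init_fn() can be omitted:

keras_layer <- layer_jax_model_wrapper(call_fn = stateless_call,
                                       params = trained_params)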

Conforming to the required signature

If a model has a different signature than the one required by JaxLayer, one can easily write a wrapper function to adapt the arguments. This example shows a model that has multiple inputs as separate arguments, expects multiple RNGs in a dict, and has a deterministic argument with the opposite meaning of training. To conform, the inputs are combined in a single structure using a tuple, the RNG is split and used to populate the expected dict, and the Boolean flag is negated:

jax <- import("jax")
my_model_fn <- function(params, rngs, input1, input2, deterministic) {
  ....
  if (!deterministic) {
    dropout_rng <- rngs$dropout
    keep <- jax$random$bernoulli(dropout_rng, dropout_rate, x$shape)
    x <- jax$numpy$where(keep, x / dropout_rate, 0)
    ....
  }
  ....
  return(outputs)
}

my_model_wrapper_fn <- function(params, rng, inputs, training) {
  c(input1, input2) %<-% inputs
  c(rng1, rng2) %<-% jax$random$split(rng)
  rngs <- list(dropout = rng1, preprocessing = rng2)
  deterministic <- !training
  my_model_fn(params, rngs, input1, input2, deterministic)
}

keras_layer <- layer_jax_model_wrapper(call_fn = my_model_wrapper_fn,
                                       params = initial_params)
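The wrapped layer can then be used like any other Keras layer, for instance with the functional API (a sketch; the input shapes are hypothetical):

input1 <- keras_input(shape = c(32))  # hypothetical shapes
input2 <- keras_input(shape = c(16))
outputs <- keras_layer(tuple(input1, input2))
model <- keras_model(list(input1, input2), outputs)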

Usage with Haiku modules

JaxLayer enables the use of Haiku components in the form of haiku.Module. This is achieved by transforming the module per the Haiku pattern and then passing module.apply in the call_fn parameter and module.init in the init_fn parameter if needed.

If the model has non-trainable state, it should be transformed with haiku.transform_with_state. If the model has no non-trainable state, it should be transformed with haiku.transform. Additionally, and optionally, if the module does not use RNGs in "apply", it can be transformed with haiku.without_apply_rng.
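A sketch of the three transformations (my_module_fn stands for a function that instantiates and calls the module, as in the example below):

# module with non-trainable state:
transformed <- hk$transform_with_state(my_module_fn)

# module without non-trainable state:
transformed <- hk$transform(my_module_fn)

# module whose apply does not use RNGs:
transformed <- hk$without_apply_rng(hk$transform(my_module_fn))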

The following example shows how to create a JaxLayer from a Haiku module that uses random number generators via hk.next_rng_key() and takes a training positional argument:

# reticulate::py_install("dm-haiku", "r-keras")
hk <- import("haiku")
jax <- import("jax")
MyHaikuModule(hk$Module) %py_class% {

  `__call__` <- \(self, x, training) {
    x <- hk$Conv2D(32L, tuple(3L, 3L))(x)
    x <- jax$nn$relu(x)
    x <- hk$AvgPool(tuple(1L, 2L, 2L, 1L),
                    tuple(1L, 2L, 2L, 1L), "VALID")(x)
    x <- hk$Flatten()(x)
    x <- hk$Linear(200L)(x)
    if (training)
      x <- hk$dropout(rng = hk$next_rng_key(), rate = 0.3, x = x)
    x <- jax$nn$relu(x)
    x <- hk$Linear(10L)(x)
    x <- jax$nn$softmax(x)
    x
  }

}

my_haiku_module_fn <- function(inputs, training) {
  module <- MyHaikuModule()
  module(inputs, training)
}

transformed_module <- hk$transform(my_haiku_module_fn)

keras_layer <-
  layer_jax_model_wrapper(call_fn = transformed_module$apply,
                          init_fn = transformed_module$init)
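Since a sequential model can also be passed as object (see the Value section), an equivalent composition is (a sketch; the input shape is hypothetical):

model <- keras_model_sequential(input_shape = c(28, 28, 1)) |>
  layer_jax_model_wrapper(call_fn = transformed_module$apply,
                          init_fn = transformed_module$init)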

See Also

Other wrapping layers:
layer_flax_module_wrapper()
layer_torch_module_wrapper()

Other layers:
Layer()
layer_activation()
layer_activation_elu()
layer_activation_leaky_relu()
layer_activation_parametric_relu()
layer_activation_relu()
layer_activation_softmax()
layer_activity_regularization()
layer_add()
layer_additive_attention()
layer_alpha_dropout()
layer_attention()
layer_average()
layer_average_pooling_1d()
layer_average_pooling_2d()
layer_average_pooling_3d()
layer_batch_normalization()
layer_bidirectional()
layer_category_encoding()
layer_center_crop()
layer_concatenate()
layer_conv_1d()
layer_conv_1d_transpose()
layer_conv_2d()
layer_conv_2d_transpose()
layer_conv_3d()
layer_conv_3d_transpose()
layer_conv_lstm_1d()
layer_conv_lstm_2d()
layer_conv_lstm_3d()
layer_cropping_1d()
layer_cropping_2d()
layer_cropping_3d()
layer_dense()
layer_depthwise_conv_1d()
layer_depthwise_conv_2d()
layer_discretization()
layer_dot()
layer_dropout()
layer_einsum_dense()
layer_embedding()
layer_feature_space()
layer_flatten()
layer_flax_module_wrapper()
layer_gaussian_dropout()
layer_gaussian_noise()
layer_global_average_pooling_1d()
layer_global_average_pooling_2d()
layer_global_average_pooling_3d()
layer_global_max_pooling_1d()
layer_global_max_pooling_2d()
layer_global_max_pooling_3d()
layer_group_normalization()
layer_group_query_attention()
layer_gru()
layer_hashed_crossing()
layer_hashing()
layer_identity()
layer_integer_lookup()
layer_lambda()
layer_layer_normalization()
layer_lstm()
layer_masking()
layer_max_pooling_1d()
layer_max_pooling_2d()
layer_max_pooling_3d()
layer_maximum()
layer_mel_spectrogram()
layer_minimum()
layer_multi_head_attention()
layer_multiply()
layer_normalization()
layer_permute()
layer_random_brightness()
layer_random_contrast()
layer_random_crop()
layer_random_flip()
layer_random_rotation()
layer_random_translation()
layer_random_zoom()
layer_repeat_vector()
layer_rescaling()
layer_reshape()
layer_resizing()
layer_rnn()
layer_separable_conv_1d()
layer_separable_conv_2d()
layer_simple_rnn()
layer_spatial_dropout_1d()
layer_spatial_dropout_2d()
layer_spatial_dropout_3d()
layer_spectral_normalization()
layer_string_lookup()
layer_subtract()
layer_text_vectorization()
layer_tfsm()
layer_time_distributed()
layer_torch_module_wrapper()
layer_unit_normalization()
layer_upsampling_1d()
layer_upsampling_2d()
layer_upsampling_3d()
layer_zero_padding_1d()
layer_zero_padding_2d()
layer_zero_padding_3d()
rnn_cell_gru()
rnn_cell_lstm()
rnn_cell_simple()
rnn_cells_stack()


[Package keras3 version 1.1.0 Index]