layer_tfsm {keras3} R Documentation

Reload a Keras model/layer that was saved via export_savedmodel().

Description

Reload a Keras model/layer that was saved via export_savedmodel().

Usage

layer_tfsm(
  object,
  filepath,
  call_endpoint = "serve",
  call_training_endpoint = NULL,
  trainable = TRUE,
  name = NULL,
  dtype = NULL
)

Arguments

object

Object to compose the layer with. A tensor, array, or sequential model.

filepath

String, the path to the SavedModel.

call_endpoint

Name of the endpoint to use as the call() method of the reloaded layer. If the SavedModel was created via export_savedmodel(), then the default endpoint name is 'serve'. In other cases it may be named 'serving_default'.

call_training_endpoint

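Name of the endpoint to use as the call() method of the reloaded layer when it is called with training = TRUE. If NULL (the default), call_endpoint is used regardless of the value of training.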

trainable

Logical, whether the variables of the reloaded layer are trainable.

name

String, name for the layer.

dtype

Data type for the layer (e.g., "float32").
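
A minimal sketch of how these arguments combine, assuming the TensorFlow backend: a small model is exported to a temporary directory, then reloaded with trainable = FALSE so its weights stay frozen while a new layer_dense() head is stacked on top. The directory name "tfsm_frozen" and the layer name "frozen_base" are arbitrary illustrative choices, not values the package requires.

```r
library(keras3)

# Export a small model to a temporary directory so there is an artifact to reload.
base <- keras_model_sequential(input_shape = c(784)) |> layer_dense(10)
artifact <- file.path(tempdir(), "tfsm_frozen")  # illustrative path
export_savedmodel(base, artifact)

# Reload it with frozen weights; "serve" is the endpoint name that
# export_savedmodel() writes by default.
frozen <- layer_tfsm(
  filepath = artifact,
  call_endpoint = "serve",
  trainable = FALSE,
  name = "frozen_base"  # illustrative name
)

# Use the frozen layer as a feature extractor under a new trainable head.
inputs <- keras_input(shape = c(784))
outputs <- frozen(inputs) |> layer_dense(1)
model <- keras_model(inputs, outputs)
```

Because trainable = FALSE, only the new dense head would be updated if this model were compiled and fit.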

Value

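The return value depends on the value provided for the first argument. If object is:
- a keras_model_sequential(), then the layer is added to the sequential model (which is modified in place). To enable piping, the sequential model is also returned, invisibly.
- a keras_input(), then the output tensor from calling layer(input) is returned.
- NULL or missing, then a Layer instance is returned.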
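
For illustration, the sketch below composes one exported artifact in each of these ways: with object missing, applied to a keras_input(), and appended to a keras_model_sequential(). It assumes the TensorFlow backend and a writable tempdir(); the directory name "tfsm_compose" is arbitrary.

```r
library(keras3)

# Export a small model so there is a SavedModel artifact to reload.
model <- keras_model_sequential(input_shape = c(784)) |> layer_dense(10)
artifact <- file.path(tempdir(), "tfsm_compose")  # illustrative path
export_savedmodel(model, artifact)

# object missing: a layer instance is returned.
layer <- layer_tfsm(filepath = artifact)

# object is a keras_input(): the layer's output tensor is returned,
# which can be used to build a functional model.
inputs <- keras_input(shape = c(784))
outputs <- inputs |> layer_tfsm(filepath = artifact)
functional <- keras_model(inputs, outputs)

# object is a sequential model: the layer is added to that model.
seq_model <- keras_model_sequential(input_shape = c(784)) |>
  layer_tfsm(filepath = artifact)
```

All three objects wrap the same saved function, so they should produce the same output as the original model for a given input.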

Examples

model <- keras_model_sequential(input_shape = c(784)) |> layer_dense(10)
model |> export_savedmodel("path/to/artifact")
## Saved artifact at 'path/to/artifact'. The following endpoints are available:
##
## * Endpoint 'serve'
##   args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 784), dtype=tf.float32, name='keras_tensor')
## Output Type:
##   TensorSpec(shape=(None, 10), dtype=tf.float32, name=None)
## Captures:
##   129076520131984: TensorSpec(shape=(), dtype=tf.resource, name=None)
##   129076520126416: TensorSpec(shape=(), dtype=tf.resource, name=None)

reloaded_layer <- layer_tfsm(filepath = "path/to/artifact")
input <- random_normal(c(2, 784))
output <- reloaded_layer(input)
stopifnot(all.equal(as.array(output), as.array(model(input))))

The reloaded object can be used like a regular Keras layer and supports training/fine-tuning of its trainable weights. Note that the reloaded object retains none of the internal structure or custom methods of the original object; it is a brand-new layer created around the saved function.
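
To illustrate fine-tuning, the following sketch (assuming the TensorFlow backend) wraps the reloaded layer in a functional model, compiles it, and runs one epoch of fit() on random placeholder data, recording the weights before and after. The data is meaningless; it only serves to show that training updates the reloaded weights. The directory name "tfsm_finetune" is arbitrary.

```r
library(keras3)

# Export a small model so there is a SavedModel artifact to reload.
model <- keras_model_sequential(input_shape = c(784)) |> layer_dense(10)
artifact <- file.path(tempdir(), "tfsm_finetune")  # illustrative path
export_savedmodel(model, artifact)

# Wrap the reloaded layer in a functional model so it can be compiled and fit.
inputs <- keras_input(shape = c(784))
outputs <- inputs |> layer_tfsm(filepath = artifact)
wrapped <- keras_model(inputs, outputs)
wrapped |> compile(optimizer = "adam", loss = "mse")

# Random placeholder data, used only to show that fit() changes the weights.
x <- random_normal(c(32, 784))
y <- random_normal(c(32, 10))
before <- get_weights(wrapped)
wrapped |> fit(x, y, epochs = 1, verbose = 0)
after <- get_weights(wrapped)
```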

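Limitations:
- Only call endpoints with a single inputs tensor argument (which may optionally be a list or named list of tensors) are supported. For endpoints with multiple separate input tensor arguments, consider subclassing the layer and implementing a call() method with a custom signature.
- If you need training-time behavior to differ from inference-time behavior (i.e., if the reloaded object must support a training = TRUE argument when called), make sure that the training-time call function is saved as a standalone endpoint in the artifact, and provide its name to layer_tfsm() via the call_training_endpoint argument.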

See Also

Other layers:
Layer()
layer_activation()
layer_activation_elu()
layer_activation_leaky_relu()
layer_activation_parametric_relu()
layer_activation_relu()
layer_activation_softmax()
layer_activity_regularization()
layer_add()
layer_additive_attention()
layer_alpha_dropout()
layer_attention()
layer_average()
layer_average_pooling_1d()
layer_average_pooling_2d()
layer_average_pooling_3d()
layer_batch_normalization()
layer_bidirectional()
layer_category_encoding()
layer_center_crop()
layer_concatenate()
layer_conv_1d()
layer_conv_1d_transpose()
layer_conv_2d()
layer_conv_2d_transpose()
layer_conv_3d()
layer_conv_3d_transpose()
layer_conv_lstm_1d()
layer_conv_lstm_2d()
layer_conv_lstm_3d()
layer_cropping_1d()
layer_cropping_2d()
layer_cropping_3d()
layer_dense()
layer_depthwise_conv_1d()
layer_depthwise_conv_2d()
layer_discretization()
layer_dot()
layer_dropout()
layer_einsum_dense()
layer_embedding()
layer_feature_space()
layer_flatten()
layer_flax_module_wrapper()
layer_gaussian_dropout()
layer_gaussian_noise()
layer_global_average_pooling_1d()
layer_global_average_pooling_2d()
layer_global_average_pooling_3d()
layer_global_max_pooling_1d()
layer_global_max_pooling_2d()
layer_global_max_pooling_3d()
layer_group_normalization()
layer_group_query_attention()
layer_gru()
layer_hashed_crossing()
layer_hashing()
layer_identity()
layer_integer_lookup()
layer_jax_model_wrapper()
layer_lambda()
layer_layer_normalization()
layer_lstm()
layer_masking()
layer_max_pooling_1d()
layer_max_pooling_2d()
layer_max_pooling_3d()
layer_maximum()
layer_mel_spectrogram()
layer_minimum()
layer_multi_head_attention()
layer_multiply()
layer_normalization()
layer_permute()
layer_random_brightness()
layer_random_contrast()
layer_random_crop()
layer_random_flip()
layer_random_rotation()
layer_random_translation()
layer_random_zoom()
layer_repeat_vector()
layer_rescaling()
layer_reshape()
layer_resizing()
layer_rnn()
layer_separable_conv_1d()
layer_separable_conv_2d()
layer_simple_rnn()
layer_spatial_dropout_1d()
layer_spatial_dropout_2d()
layer_spatial_dropout_3d()
layer_spectral_normalization()
layer_string_lookup()
layer_subtract()
layer_text_vectorization()
layer_time_distributed()
layer_torch_module_wrapper()
layer_unit_normalization()
layer_upsampling_1d()
layer_upsampling_2d()
layer_upsampling_3d()
layer_zero_padding_1d()
layer_zero_padding_2d()
layer_zero_padding_3d()
rnn_cell_gru()
rnn_cell_lstm()
rnn_cell_simple()
rnn_cells_stack()

Other saving and loading functions:
export_savedmodel.keras.src.models.model.Model()
load_model()
load_model_weights()
register_keras_serializable()
save_model()
save_model_config()
save_model_weights()
with_custom_object_scope()


[Package keras3 version 1.1.0 Index]