adapt {keras}    R Documentation

Fits the state of the preprocessing layer to the data being passed

Description

Fits the state of the preprocessing layer to the data being passed

Usage

adapt(object, data, ..., batch_size = NULL, steps = NULL)

Arguments

object

Preprocessing layer object

data

The data to train on. It can be passed either as a tf.data.Dataset or as an R array.

...

Used for forwards and backwards compatibility. Passed on to the underlying method.

batch_size

Integer or NULL. Number of samples per state update. If unspecified, batch_size defaults to 32. Do not specify batch_size if your data is in the form of datasets, generators, or keras.utils.Sequence instances, since they generate batches themselves.

steps

Integer or NULL. Total number of steps (batches of samples). When training with input tensors such as TensorFlow data tensors, the default NULL is equal to the number of samples in your dataset divided by the batch size, or 1 if that cannot be determined. If data is a tf.data.Dataset and steps is NULL, the epoch will run until the input dataset is exhausted. When passing an infinitely repeating dataset, you must specify the steps argument. This argument is not supported with array inputs.

Details

After calling adapt on a layer, a preprocessing layer's state will not update during training. In order to make preprocessing layers efficient in any distribution context, they are kept constant with respect to any compiled tf.Graphs that call the layer. This does not affect layer use when each layer is adapted only once, but if you adapt a layer multiple times you will need to re-compile any compiled functions, as follows:

keras_model example with multiple adapts:

layer <- layer_normalization(axis=NULL)
adapt(layer, c(0, 2))
model <- keras_model_sequential(layer)
predict(model, c(0, 1, 2)) # [1] -1  0  1

adapt(layer, c(-1, 1))
compile(model)  # This is needed to re-compile model.predict!
predict(model, c(0, 1, 2)) # [1] 0 1 2
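
The outputs in the example above can be checked by hand. The following is a minimal base-R sketch of the arithmetic, assuming the layer standardizes with the mean and population variance of the adapted data (and ignoring the small epsilon the layer uses to avoid division by zero). normalize() is a hypothetical helper written for illustration; it is not part of keras.

```r
# Hypothetical helper: standardize x using statistics of adapt_data.
normalize <- function(x, adapt_data) {
  m <- mean(adapt_data)
  v <- mean((adapt_data - m)^2)  # population variance (divide by n, not n - 1)
  (x - m) / sqrt(v)
}

# Statistics from c(0, 2): mean 1, variance 1.
first <- normalize(c(0, 1, 2), c(0, 2))
print(first)   # [1] -1  0  1

# Statistics from c(-1, 1): mean 0, variance 1.
second <- normalize(c(0, 1, 2), c(-1, 1))
print(second)  # [1] 0 1 2
```

The second call mirrors what predict() returns only after the model has been re-compiled; without compile(model), the compiled predict function may still use the statistics from the first adapt.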

tf.data.Dataset example with multiple adapts:

layer <- layer_normalization(axis=NULL)
adapt(layer, c(0, 2))
input_ds <- tfdatasets::range_dataset(0, 3)
normalized_ds <- input_ds %>%
  tfdatasets::dataset_map(layer)
str(reticulate::iterate(normalized_ds))
# List of 3
#  $ :tf.Tensor([-1.], shape=(1,), dtype=float32)
#  $ :tf.Tensor([0.], shape=(1,), dtype=float32)
#  $ :tf.Tensor([1.], shape=(1,), dtype=float32)
adapt(layer, c(-1, 1))
normalized_ds <- input_ds %>%
  tfdatasets::dataset_map(layer) # Re-map over the input dataset.
str(reticulate::iterate(normalized_ds$as_numpy_iterator()))
# List of 3
#  $ : num [1(1d)] 0
#  $ : num [1(1d)] 1
#  $ : num [1(1d)] 2
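
As an illustration of the batch_size and steps arguments, the sketch below adapts one layer first from an R array and then from an infinitely repeating dataset. The toy data x and the choice of 10 samples per batch are assumptions made for this example. Integer literals (10L) are used so that the values reach Python as integers.

```r
library(keras)
library(tfdatasets)

# Assumed toy data: 100 scalar samples in a one-column matrix.
x <- matrix(rnorm(100), ncol = 1)

layer <- layer_normalization(axis = NULL)

# Array input: batch_size applies; steps is left at its default (NULL).
adapt(layer, x, batch_size = 10L)

# Dataset input: the dataset is already batched, so batch_size is omitted.
# Because the dataset repeats forever, steps must bound the number of batches.
ds <- tensor_slices_dataset(x) %>%
  dataset_batch(10L) %>%
  dataset_repeat()

adapt(layer, ds, steps = 10L)  # consumes 10 batches of 10 samples each
```

Here 10 steps of 10 samples cover the 100 rows of x exactly once; a different steps value would simply consume more or fewer batches of the repeating dataset.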

[Package keras version 2.15.0 Index]