callback_lambda {keras3}    R Documentation

Callback for creating simple, custom callbacks on-the-fly.

Description

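This callback is constructed with anonymous functions that will be called at the appropriate time (during Model.{fit | evaluate | predict}). Note that the callback expects positional arguments, as follows:

on_epoch_begin and on_epoch_end expect two positional arguments: epoch, logs

on_train_begin and on_train_end expect one positional argument: logs

on_train_batch_begin and on_train_batch_end expect two positional arguments: batch, logs

See the Callback() class definition for the full list of functions and their expected arguments.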

Usage

callback_lambda(
  on_epoch_begin = NULL,
  on_epoch_end = NULL,
  on_train_begin = NULL,
  on_train_end = NULL,
  on_train_batch_begin = NULL,
  on_train_batch_end = NULL,
  ...
)

Arguments

on_epoch_begin

called at the beginning of every epoch.

on_epoch_end

called at the end of every epoch.

on_train_begin

called at the beginning of model training.

on_train_end

called at the end of model training.

on_train_batch_begin

called at the beginning of every train batch.

on_train_batch_end

called at the end of every train batch.

...

Any function in Callback() that you want to override, passed as function_name = function. For example, callback_lambda(..., on_train_end = train_end_fn). The custom function must have the same arguments as the corresponding function defined in Callback().
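
As a minimal sketch of the ... mechanism, the handlers below target on_test_begin and on_test_end, which are assumed here to be among the functions defined in Callback() and to take a single logs argument. The names eval_state, test_begin_fn, and test_end_fn are illustrative and not part of keras3.

```r
# Illustrative state holder; not part of keras3.
eval_state <- new.env()

# Assumed signature: on_test_begin(logs)
test_begin_fn <- function(logs) {
  eval_state$started <- Sys.time()
}

# Assumed signature: on_test_end(logs)
test_end_fn <- function(logs) {
  eval_state$seconds <- as.numeric(
    difftime(Sys.time(), eval_state$started, units = "secs")
  )
}

# Wiring (requires keras3 and a compiled `model`, so left commented out):
# eval_timer_callback <- keras3::callback_lambda(
#   on_test_begin = test_begin_fn,
#   on_test_end = test_end_fn
# )
# model %>% evaluate(x, y, callbacks = list(eval_timer_callback))
```

Defining the handlers as named functions keeps them callable on their own, which makes it easy to check their behaviour before attaching them to a model.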

Value

A Callback instance that can be passed to fit.keras.src.models.model.Model() through its callbacks argument.

Examples

# Print the batch number at the beginning of every batch.
batch_print_callback <- callback_lambda(
  on_train_batch_begin = function(batch, logs) {
    print(batch)
  }
)

# Stream the epoch loss to a file in new-line delimited JSON format
# (one valid JSON object per line)
json_log <- file('loss_log.json', open = 'wt')
json_logging_callback <- callback_lambda(
  on_epoch_end = function(epoch, logs) {
    jsonlite::write_json(
      list(epoch = epoch, loss = logs$loss),
      json_log,
      append = TRUE
    )
  },
  on_train_end = function(logs) {
    close(json_log)
  }
)

# Terminate some processes after having finished model training.
processes <- ...
cleanup_callback <- callback_lambda(
  on_train_end = function(logs) {
    for (p in processes) {
      if (is_alive(p)) {
        terminate(p)
      }
    }
  }
)

model %>% fit(
  ...,
  callbacks = list(
    batch_print_callback,
    json_logging_callback,
    cleanup_callback
  )
)
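
Because the handlers are ordinary R functions, they can close over state. The following is a minimal sketch, assuming logs arrives as a named list with a loss entry, as in the JSON example above. make_loss_recorder is a hypothetical helper, not part of keras3, and the handler is invoked by hand here to simulate two epochs.

```r
# Hypothetical helper: returns an on_epoch_end handler together with an
# accessor for the losses the handler has collected so far.
make_loss_recorder <- function() {
  losses <- numeric()
  list(
    on_epoch_end = function(epoch, logs) {
      # `<<-` updates `losses` in the enclosing environment.
      losses <<- c(losses, logs$loss)
    },
    get_losses = function() losses
  )
}

recorder <- make_loss_recorder()

# Simulate two epochs by calling the handler directly.
recorder$on_epoch_end(1, list(loss = 0.9))
recorder$on_epoch_end(2, list(loss = 0.7))
recorder$get_losses()

# With keras3, the handler would be passed as:
# callback_lambda(on_epoch_end = recorder$on_epoch_end)
```

Keeping the state inside the closure avoids a global variable and lets several recorders coexist in one session.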

See Also

Other callbacks:
Callback()
callback_backup_and_restore()
callback_csv_logger()
callback_early_stopping()
callback_learning_rate_scheduler()
callback_model_checkpoint()
callback_reduce_lr_on_plateau()
callback_remote_monitor()
callback_swap_ema_weights()
callback_tensorboard()
callback_terminate_on_nan()


[Package keras3 version 1.1.0 Index]