Metric {keras3}    R Documentation

Subclass the base Metric class

Description

A Metric object encapsulates metric logic and state, and can be used to track model performance during training. Metric objects are returned by the family of built-in metric functions whose names start with the prefix metric_, and by custom metrics defined with Metric().

Usage

Metric(
  classname,
  initialize = NULL,
  update_state = NULL,
  result = NULL,
  ...,
  public = list(),
  private = list(),
  inherit = NULL,
  parent_env = parent.frame()
)

Arguments

classname

String, the name of the custom class (conventionally CamelCase).

initialize, update_state, result

Recommended methods to implement. See the subclass example in the Full Examples section.

..., public

Additional methods or public members of the custom class.

private

Named list of R objects (typically, functions) to include in instance private environments. Private methods will have all the same symbols in scope as public methods (see section "Symbols in scope"). Each instance will have its own private environment. Any objects in private will be invisible to the Keras framework and the Python runtime.

inherit

What the custom class will subclass. By default, the base Keras Metric class.

parent_env

The R environment that all class methods will have as a grandparent.

Value

A function that returns Metric instances, similar to the built-in metric functions.

Examples

Usage with compile():

model |> compile(
  optimizer = 'sgd',
  loss = 'mse',
  metrics = c(metric_SOME_METRIC(), metric_SOME_OTHER_METRIC())
)

Standalone usage:

m <- metric_SOME_METRIC()
for (e in seq(epochs)) {
  for (i in seq(train_steps)) {
    c(y_true, y_pred, sample_weight = NULL) %<-% ...
    m$update_state(y_true, y_pred, sample_weight)
  }
  cat('Final epoch result: ', as.numeric(m$result()), "\n")
  m$reset_state()
}
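
The loop above uses placeholders. As a concrete sketch of the same update_state() / result() / reset_state() cycle, the following accumulates a built-in metric over two mini-batches. The batch values and the `batches` list are made up for illustration; only metric_binary_accuracy() and the three methods come from keras3.

```r
library(keras3)

# Two hypothetical mini-batches of (y_true, y_pred), one row per sample.
batches <- list(
  list(y_true = rbind(1, 1, 0, 0), y_pred = rbind(0.98, 1, 0, 0.6)),
  list(y_true = rbind(1, 0),       y_pred = rbind(0.2, 0.1))
)

m <- metric_binary_accuracy()
for (batch in batches) {
  # Each call accumulates into the metric's internal state.
  m$update_state(batch$y_true, batch$y_pred)
}

# Accuracy over all 6 samples seen so far; 4 of them are classified
# correctly at the default threshold of 0.5.
accuracy <- as.numeric(m$result())

# Clear the accumulated state before reusing the metric,
# for example at the start of the next epoch.
m$reset_state()
```

Because the state persists between calls, result() reflects every batch seen since the last reset_state(), not only the most recent one.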

Full Examples

Usage with compile():

model <- keras_model_sequential()
model |>
  layer_dense(64, activation = "relu") |>
  layer_dense(64, activation = "relu") |>
  layer_dense(10, activation = "softmax")
model |>
  compile(optimizer = optimizer_rmsprop(0.01),
          loss = loss_categorical_crossentropy(),
          metrics = metric_categorical_accuracy())

data <- random_uniform(c(1000, 32))
labels <- random_uniform(c(1000, 10))

model |> fit(data, labels, verbose = 0)

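To be implemented by subclasses (custom metrics):

initialize()

Creates the metric's state variables, for example with self$add_weight() as in the example below.

update_state()

Updates the state variables from a batch of data, for example by calling a variable's assign() method.

result()

Computes and returns the metric value from the state variables.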

Example subclass implementation:

metric_binary_true_positives <- Metric(
  classname = "BinaryTruePositives",

  initialize = function(name = 'binary_true_positives', ...) {
    super$initialize(name = name, ...)
    self$true_positives <-
      self$add_weight(shape = shape(),
                      initializer = 'zeros',
                      name = 'true_positives')
  },

  update_state = function(y_true, y_pred, sample_weight = NULL) {
    y_true <- op_cast(y_true, "bool")
    y_pred <- op_cast(y_pred, "bool")

    values <- y_true & y_pred # `&` calls op_logical_and()
    values <- op_cast(values, self$dtype)
    if (!is.null(sample_weight)) {
      sample_weight <- op_cast(sample_weight, self$dtype)
      sample_weight <- op_broadcast_to(sample_weight, shape(values))
      values <- values * sample_weight # `*` calls op_multiply()
    }
    self$true_positives$assign(self$true_positives + op_sum(values))
  },

  result = function() {
    self$true_positives
  }
)

# Reuses `data` and `labels` from the compile() example above.
model <- keras_model_sequential(input_shape = 32) |> layer_dense(10)
model |> compile(loss = loss_binary_crossentropy(),
                 metrics = list(metric_binary_true_positives()))
model |> fit(data, labels, verbose = 0)
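
A custom metric can also be exercised on its own, without a model, which is a convenient way to check its arithmetic on known inputs. The sketch below is an illustration rather than part of keras3: metric_running_mae, its class name, and the input values are hypothetical. It keeps two state variables and computes the result from both. It ignores sample_weight for brevity.

```r
library(keras3)

# Hypothetical custom metric (not shipped with keras3): mean absolute
# error tracked as a running total and a running count.
metric_running_mae <- Metric(
  classname = "RunningMAE",

  initialize = function(name = "running_mae", ...) {
    super$initialize(name = name, ...)
    self$total <- self$add_weight(shape = shape(),
                                  initializer = "zeros",
                                  name = "total")
    self$count <- self$add_weight(shape = shape(),
                                  initializer = "zeros",
                                  name = "count")
  },

  # sample_weight is accepted but deliberately ignored in this sketch.
  update_state = function(y_true, y_pred, sample_weight = NULL) {
    y_true <- op_cast(y_true, self$dtype)
    y_pred <- op_cast(y_pred, self$dtype)
    abs_err <- op_abs(y_true - y_pred)
    self$total$assign(self$total + op_sum(abs_err))
    # Count elements by summing a tensor of ones with the same shape.
    self$count$assign(self$count + op_sum(op_ones_like(abs_err)))
  },

  # Returns NaN until update_state() has been called at least once.
  result = function() {
    op_divide(self$total, self$count)
  }
)

m <- metric_running_mae()

# First batch: absolute errors are 0, 0, 2.
m$update_state(op_convert_to_tensor(c(1, 2, 3)),
               op_convert_to_tensor(c(1, 2, 5)))
first <- as.numeric(m$result())

# Second batch: absolute errors are 1, 1; state now covers 5 values.
m$update_state(op_convert_to_tensor(c(0, 0)),
               op_convert_to_tensor(c(1, 1)))
second <- as.numeric(m$result())
```

Keeping the total and the count as separate variables lets result() stay a pure function of the state, so it can be called at any point between updates.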

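Methods defined by the base Metric class:

These include add_weight(), update_state(), result(), and reset_state(), all of which are used in the examples above.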

Readonly properties

These include dtype (used in the example above) and variables.

Symbols in scope

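All R function custom methods (public and private) will have the following symbols in scope:

self

The custom class instance.

super

The custom class superclass.

private

An R environment specific to the class instance. Any objects assigned here are invisible to the Keras framework.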

See Also

Other metrics:
custom_metric()
metric_auc()
metric_binary_accuracy()
metric_binary_crossentropy()
metric_binary_focal_crossentropy()
metric_binary_iou()
metric_categorical_accuracy()
metric_categorical_crossentropy()
metric_categorical_focal_crossentropy()
metric_categorical_hinge()
metric_cosine_similarity()
metric_f1_score()
metric_false_negatives()
metric_false_positives()
metric_fbeta_score()
metric_hinge()
metric_huber()
metric_iou()
metric_kl_divergence()
metric_log_cosh()
metric_log_cosh_error()
metric_mean()
metric_mean_absolute_error()
metric_mean_absolute_percentage_error()
metric_mean_iou()
metric_mean_squared_error()
metric_mean_squared_logarithmic_error()
metric_mean_wrapper()
metric_one_hot_iou()
metric_one_hot_mean_iou()
metric_poisson()
metric_precision()
metric_precision_at_recall()
metric_r2_score()
metric_recall()
metric_recall_at_precision()
metric_root_mean_squared_error()
metric_sensitivity_at_specificity()
metric_sparse_categorical_accuracy()
metric_sparse_categorical_crossentropy()
metric_sparse_top_k_categorical_accuracy()
metric_specificity_at_sensitivity()
metric_squared_hinge()
metric_sum()
metric_top_k_categorical_accuracy()
metric_true_negatives()
metric_true_positives()


[Package keras3 version 1.1.0 Index]