Loss {keras3}    R Documentation

Subclass the base Loss class

Description

Use this to define a custom loss class. Note that in most cases you do not need to subclass Loss to define a custom loss: you can also pass a bare R function, or a named R function defined with custom_metric(), as the loss argument to compile().
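For comparison, here is a minimal sketch of the non-subclassing route, assuming keras3 is installed with a working backend. `my_mae` is a hypothetical name; the function receives y_true and y_pred and returns one loss value per sample.

```r
library(keras3)

# A bare R function used directly as a loss: one value per sample
my_mae <- function(y_true, y_pred) {
  op_mean(op_abs(y_pred - y_true), axis = -1)
}

model <- keras_model_sequential(input_shape = 10) |> layer_dense(1)
model |> compile(optimizer = "adam", loss = my_mae)

# The same function wrapped with custom_metric() to give it an explicit name
named_mae <- custom_metric("my_mae", my_mae)
model |> compile(optimizer = "adam", loss = named_mae)

# Calling the bare function directly returns the per-sample losses
y_true <- op_zeros(c(2, 3))
y_pred <- op_ones(c(2, 3))
per_sample <- my_mae(y_true, y_pred)
```

Subclassing Loss becomes worthwhile mainly when the loss needs its own constructor arguments or serialization logic.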

Usage

Loss(
  classname,
  call = NULL,
  ...,
  public = list(),
  private = list(),
  inherit = NULL,
  parent_env = parent.frame()
)

Arguments

classname

String, the name of the custom class (conventionally CamelCase).

call
function(y_true, y_pred)

Method to be implemented by subclasses: a function that computes the loss from y_true and y_pred. It typically returns one loss value per sample; the base class then applies any sample_weight and the reduction.

..., public

Additional methods or public members of the custom class.
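A minimal sketch of passing extra methods through `...`, assuming the default base class. `ScaledMeanAbsoluteError`, `loss_scaled_mae`, and its `scale` argument are hypothetical. The `initialize` method forwards `name` and any remaining arguments to the base constructor via `super$initialize()`, and `get_config()` adds `scale` so the loss can be re-created from its config.

```r
library(keras3)

loss_scaled_mae <- Loss(
  classname = "ScaledMeanAbsoluteError",
  # Public method passed through `...`: stores the hypothetical `scale` argument
  initialize = function(scale = 1, ..., name = "scaled_mae") {
    super$initialize(name = name, ...)
    self$scale <- scale
  },
  call = function(y_true, y_pred) {
    # Per-sample mean absolute error, multiplied by the stored scale
    op_multiply(op_mean(op_abs(y_pred - y_true), axis = -1), self$scale)
  },
  # Include `scale` in the serialized config alongside the base fields
  get_config = function() {
    config <- super$get_config()
    config$scale <- self$scale
    config
  }
)

smae <- loss_scaled_mae(scale = 2)
value <- smae(op_zeros(c(2, 3)), op_ones(c(2, 3)))
```

Every per-sample error is 1, so with `scale = 2` the reduced loss is 2.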

private

Named list of R objects (typically, functions) to include in instance private environments. private methods will have all the same symbols in scope as public methods (see section "Symbols in Scope"). Each instance will have its own private environment. Any objects in private will be invisible to the Keras framework and the Python runtime.
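A minimal sketch of a private helper, assuming private members are reached through the `private` environment as described above. `ClippedMeanSquaredError`, `clip_error`, and `max_abs_error` are hypothetical names; neither private member is visible to Keras or Python.

```r
library(keras3)

loss_clipped_mse <- Loss(
  classname = "ClippedMeanSquaredError",
  call = function(y_true, y_pred) {
    # Private helper, reached through the `private` environment
    err <- private$clip_error(y_pred - y_true)
    op_mean(op_square(err), axis = -1)
  },
  private = list(
    # Hypothetical clipping bound, kept out of the Python object
    max_abs_error = 1,
    clip_error = function(err) {
      op_clip(err, -private$max_abs_error, private$max_abs_error)
    }
  )
)

clipped <- loss_clipped_mse()
value <- clipped(op_zeros(c(2, 3)), op_full(c(2, 3), 3))
```

Each raw error of 3 is clipped to 1 before squaring, so the reduced loss is 1.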

inherit

What the custom class will subclass. By default, the base Keras Loss class.

parent_env

The R environment that all class methods will have as a grandparent.
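Because the default parent_env is the caller's frame, class methods can see variables from the function that called Loss(). A minimal sketch, where the factory `make_power_loss()` and the class `MeanPowerError` are hypothetical:

```r
library(keras3)

# Hypothetical factory: `power` lives in this function's frame, which becomes
# parent_env (the default parent.frame()) for the class methods
make_power_loss <- function(power) {
  Loss(
    classname = "MeanPowerError",
    call = function(y_true, y_pred) {
      op_mean(op_power(op_abs(y_pred - y_true), power), axis = -1)
    }
  )
}

loss_cubed <- make_power_loss(3)
cubed <- loss_cubed()
value <- cubed(op_zeros(c(2, 3)), op_full(c(2, 3), 2))
```

Every absolute error is 2, and 2^3 = 8, so the reduced loss is 8.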

Details

Example subclass implementation:

library(keras3)

loss_custom_mse <- Loss(
  classname = "CustomMeanSquaredError",
  call = function(y_true, y_pred) {
    op_mean(op_square(y_pred - y_true), axis = -1)
  }
)

# Usage in compile()
model <- keras_model_sequential(input_shape = 10) |> layer_dense(10)
model |> compile(loss = loss_custom_mse())

# Standalone usage
mse <- loss_custom_mse(name = "my_custom_mse_instance")

y_true <- op_arange(20) |> op_reshape(c(4, 5))
y_pred <- op_arange(20) |> op_reshape(c(4, 5)) * 2
(loss <- mse(y_true, y_pred))
## tf.Tensor(123.5, shape=(), dtype=float32)

loss2 <- (y_pred - y_true)^2 |>
  op_mean(axis = -1) |>
  op_mean()

stopifnot(all.equal(as.array(loss), as.array(loss2)))

sample_weight <- array(c(.25, .25, 1, 1))
(weighted_loss <- mse(y_true, y_pred, sample_weight = sample_weight))
## tf.Tensor(112.8125, shape=(), dtype=float32)

weighted_loss2 <- (y_true - y_pred)^2 |>
  op_mean(axis = -1) |>
  op_multiply(sample_weight) |>
  op_mean()

stopifnot(all.equal(as.array(weighted_loss),
                    as.array(weighted_loss2)))
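The two printed values can be reproduced with base R arithmetic. This sketch assumes the default reduction, "sum_over_batch_size", which sums the (weighted) per-sample losses and divides by the number of elements in the call() output (here, the 4 samples), not by the sum of the weights.

```r
# Same data as above, built with base R (row-major, like op_reshape())
y_true <- matrix(0:19, nrow = 4, byrow = TRUE)
y_pred <- y_true * 2

# call() output: one mean squared error per sample (row)
per_sample <- rowMeans((y_pred - y_true)^2)   # 6, 51, 146, 291

# Default reduction: sum of (weighted) per-sample losses / number of samples
unweighted <- sum(per_sample) / length(per_sample)
sample_weight <- c(.25, .25, 1, 1)
weighted <- sum(per_sample * sample_weight) / length(per_sample)
```

The weighted total is 1.5 + 12.75 + 146 + 291 = 451.25, and 451.25 / 4 = 112.8125.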

Value

A function that returns Loss instances, similar to the builtin loss functions.

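Methods defined by base Loss class:

initialize(name=NULL, reduction="sum_over_batch_size", dtype=NULL)
__call__(y_true, y_pred, sample_weight=NULL)
get_config()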

Symbols in scope

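All R function custom methods (public and private) will have the following symbols in scope:

self: The custom class instance.
super: The custom class superclass.
private: An R environment specific to the class instance. Any objects assigned here are invisible to the Keras framework.
__class__ and as.symbol(classname): The custom class type object.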

See Also

Other losses:
loss_binary_crossentropy()
loss_binary_focal_crossentropy()
loss_categorical_crossentropy()
loss_categorical_focal_crossentropy()
loss_categorical_hinge()
loss_cosine_similarity()
loss_ctc()
loss_dice()
loss_hinge()
loss_huber()
loss_kl_divergence()
loss_log_cosh()
loss_mean_absolute_error()
loss_mean_absolute_percentage_error()
loss_mean_squared_error()
loss_mean_squared_logarithmic_error()
loss_poisson()
loss_sparse_categorical_crossentropy()
loss_squared_hinge()
loss_tversky()
metric_binary_crossentropy()
metric_binary_focal_crossentropy()
metric_categorical_crossentropy()
metric_categorical_focal_crossentropy()
metric_categorical_hinge()
metric_hinge()
metric_huber()
metric_kl_divergence()
metric_log_cosh()
metric_mean_absolute_error()
metric_mean_absolute_percentage_error()
metric_mean_squared_error()
metric_mean_squared_logarithmic_error()
metric_poisson()
metric_sparse_categorical_crossentropy()
metric_squared_hinge()


[Package keras3 version 1.1.0 Index]