Loss {keras3}    R Documentation
Subclass the base Loss class
Description
Use this to define a custom loss class. Note that in most cases you do not need to subclass Loss to define a custom loss: you can also pass a bare R function, or a named R function defined with custom_metric(), as a loss function to compile().
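The two lighter-weight alternatives mentioned above can be sketched as follows. This is a minimal illustration that assumes keras3 is installed with a working backend; `my_mse` and `named_mse` are hypothetical names, and the model mirrors the one in the Details example.

```r
library(keras3)

# Hypothetical bare R function: returns one loss value per sample
# by averaging the squared error over the last axis.
my_mse <- function(y_true, y_pred) {
  op_mean(op_square(y_pred - y_true), axis = -1)
}

model <- keras_model_sequential(input_shape = 10) |> layer_dense(10)

# Option 1: pass the bare function directly as the loss.
model |> compile(optimizer = "adam", loss = my_mse)

# Option 2: wrap the same function with custom_metric() to attach
# an explicit name to it, then pass that as the loss.
named_mse <- custom_metric("my_mse", my_mse)
model |> compile(optimizer = "adam", loss = named_mse)
```

Either form is enough when the loss has no configuration of its own; subclassing Loss becomes useful once the loss needs constructor arguments or a serializable config.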
Usage
Loss(
classname,
call = NULL,
...,
public = list(),
private = list(),
inherit = NULL,
parent_env = parent.frame()
)
Arguments
classname: String, the name of the custom class (conventionally, CamelCase).
call: function(y_true, y_pred). Method to be implemented by subclasses: a function that contains the logic for the loss calculation using y_true and y_pred.
..., public: Additional methods or public members of the custom class.
private: Named list of R objects (typically, functions) to include in instance private environments.
inherit: What the custom class will subclass. By default, the base Keras class.
parent_env: The R environment that all class methods will have as a grandparent.
Details
Example subclass implementation:
loss_custom_mse <- Loss(
  classname = "CustomMeanSquaredError",
  call = function(y_true, y_pred) {
    op_mean(op_square(y_pred - y_true), axis = -1)
  }
)

# Usage in compile()
model <- keras_model_sequential(input_shape = 10) |> layer_dense(10)
model |> compile(loss = loss_custom_mse())

# Standalone usage
mse <- loss_custom_mse(name = "my_custom_mse_instance")

y_true <- op_arange(20) |> op_reshape(c(4, 5))
y_pred <- op_arange(20) |> op_reshape(c(4, 5)) * 2
(loss <- mse(y_true, y_pred))
## tf.Tensor(123.5, shape=(), dtype=float32)

loss2 <- (y_pred - y_true)^2 |>
  op_mean(axis = -1) |>
  op_mean()
stopifnot(all.equal(as.array(loss), as.array(loss2)))

sample_weight <- array(c(.25, .25, 1, 1))
(weighted_loss <- mse(y_true, y_pred, sample_weight = sample_weight))
## tf.Tensor(112.8125, shape=(), dtype=float32)

weighted_loss2 <- (y_true - y_pred)^2 |>
  op_mean(axis = -1) |>
  op_multiply(sample_weight) |>
  op_mean()
stopifnot(all.equal(as.array(weighted_loss), as.array(weighted_loss2)))
Value
A function that returns Loss instances, similar to the built-in loss functions.
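Methods defined by base Loss class:
- initialize(name = NULL, reduction = "sum_over_batch_size", dtype = NULL)
  Args:
  - name: Optional name for the loss instance.
  - reduction: Type of reduction to apply to the loss. Valid values are one of {"sum_over_batch_size", "sum", NULL, "none"}. With "sum_over_batch_size" (the default), the (weighted) per-sample losses are summed and divided by the batch size; "sum" returns the sum; "none" or NULL returns the unreduced per-sample losses.
  - dtype: The dtype of the loss's computations.
- __call__(y_true, y_pred, sample_weight = NULL)
  Call the loss instance as a function, optionally with sample_weight.
- get_config()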
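The reduction options listed above can be illustrated without Keras. The sketch below uses a hypothetical plain-R helper, reduce_loss(), that is not part of keras3; it assumes call() returns one loss value per sample (as in the Details example) and that sample weights multiply those per-sample values before the reduction is applied.

```r
# Hypothetical helper, not part of keras3: mimics how the documented
# `reduction` values combine the per-sample losses returned by `call()`.
reduce_loss <- function(values, sample_weight = NULL,
                        reduction = "sum_over_batch_size") {
  # Sample weights scale each per-sample loss before reduction.
  if (!is.null(sample_weight)) values <- values * sample_weight
  # NULL and "none" both return the unreduced per-sample losses.
  if (is.null(reduction) || identical(reduction, "none")) return(values)
  switch(reduction,
    sum = sum(values),
    # Divide by the number of samples, not by the sum of the weights.
    sum_over_batch_size = sum(values) / length(values),
    stop("unknown reduction: ", reduction)
  )
}

# Per-sample squared-error means for the Details example data
# (op_reshape() fills row-major, hence byrow = TRUE).
y_true <- matrix(0:19, nrow = 4, byrow = TRUE)
y_pred <- y_true * 2
per_sample <- rowMeans((y_pred - y_true)^2)

reduce_loss(per_sample)                                      ## [1] 123.5
reduce_loss(per_sample, reduction = "sum")                   ## [1] 494
reduce_loss(per_sample, reduction = "none")                  ## [1]   6  51 146 291
reduce_loss(per_sample, sample_weight = c(.25, .25, 1, 1))   ## [1] 112.8125
```

The first and last results match the tensors printed in the Details example, which is consistent with the weighted sum being divided by the batch size (4) rather than by the total weight (2.5).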
Symbols in scope
All R function custom methods (public and private) will have the following symbols in scope:
- self: The custom class instance.
- super: The custom class superclass.
- private: An R environment specific to the class instance. Any objects assigned here are invisible to the Keras framework.
- __class__ and as.symbol(classname): The custom class type object.
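These symbols are what make a configurable loss possible. The following is a minimal sketch, assuming keras3 is installed; the class name ScaledMeanSquaredError, the constructor loss_scaled_mse(), the scale argument, and the private note field are all hypothetical. It also assumes, as with other keras3 subclassing helpers, that an initialize() method maps to the Python constructor and that super$initialize(...) forwards the base arguments (name, reduction, dtype).

```r
library(keras3)

loss_scaled_mse <- Loss(
  classname = "ScaledMeanSquaredError",

  # Hypothetical constructor: `scale` is new; everything else in `...`
  # (name, reduction, dtype) is forwarded to the base Loss class.
  initialize = function(scale = 1, ...) {
    super$initialize(...)
    self$scale <- scale        # public attribute, visible to Keras
    private$note <- "R-only"   # lives in the private env, invisible to Keras
  },

  # Per-sample losses: mean squared error over the last axis, scaled.
  call = function(y_true, y_pred) {
    op_mean(op_square(y_pred - y_true), axis = -1) * self$scale
  },

  # Add the extra constructor argument so the loss can be re-created
  # from its config.
  get_config = function() {
    config <- super$get_config()
    config$scale <- self$scale
    config
  }
)

mse2 <- loss_scaled_mse(scale = 2, name = "scaled_mse")
```

With the Details example data, this instance should return twice the unscaled value, since scaling each per-sample loss by a constant scales the default "sum_over_batch_size" reduction by the same constant.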
See Also
Other losses:
loss_binary_crossentropy()
loss_binary_focal_crossentropy()
loss_categorical_crossentropy()
loss_categorical_focal_crossentropy()
loss_categorical_hinge()
loss_cosine_similarity()
loss_ctc()
loss_dice()
loss_hinge()
loss_huber()
loss_kl_divergence()
loss_log_cosh()
loss_mean_absolute_error()
loss_mean_absolute_percentage_error()
loss_mean_squared_error()
loss_mean_squared_logarithmic_error()
loss_poisson()
loss_sparse_categorical_crossentropy()
loss_squared_hinge()
loss_tversky()
metric_binary_crossentropy()
metric_binary_focal_crossentropy()
metric_categorical_crossentropy()
metric_categorical_focal_crossentropy()
metric_categorical_hinge()
metric_hinge()
metric_huber()
metric_kl_divergence()
metric_log_cosh()
metric_mean_absolute_error()
metric_mean_absolute_percentage_error()
metric_mean_squared_error()
metric_mean_squared_logarithmic_error()
metric_poisson()
metric_sparse_categorical_crossentropy()
metric_squared_hinge()