loss_categorical_focal_crossentropy {keras3}    R Documentation

Computes the alpha-balanced focal crossentropy loss.

Description

Use this crossentropy loss function when there are two or more label classes and you want to handle class imbalance without using class_weights. Labels are expected to be provided in a one-hot representation.

According to Lin et al., 2018, it helps to apply a focal factor to down-weight easy examples and focus more on hard examples. The general formula for the focal loss (FL) is as follows:

FL(p_t) = -(1 - p_t)^gamma * log(p_t)

where p_t is defined as follows: p_t = y_pred if y_true == 1, else 1 - y_pred.

(1 - p_t)^gamma is the modulating factor, where gamma is a focusing parameter. When gamma = 0, there is no focal effect on the crossentropy. gamma reduces the importance given to simple examples in a smooth manner.

The authors use the alpha-balanced variant of focal loss (FL) in the paper: FL(p_t) = -alpha * (1 - p_t)^gamma * log(p_t)

where alpha is a weighting factor for the classes. If alpha = 1, the loss won't be able to handle class imbalance properly, as all classes will have the same weight. alpha can be a constant or a list of constants; if it is a list, it must have the same length as the number of classes.

The formula above can be generalized to: FL(p_t) = alpha * (1 - p_t)^gamma * CrossEntropy(y_true, y_pred)

where the minus sign has been absorbed into CrossEntropy(y_true, y_pred) (CE), which is itself defined with a leading minus.

Extending this to the multi-class case is straightforward: FL(p_t) = alpha * (1 - p_t)^gamma * CategoricalCE(y_true, y_pred)
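
As a concrete check, the per-example values shown in the Examples section can be reproduced with a minimal base-R sketch of this formula (assuming alpha = 0.25, gamma = 2, and probabilities that already sum to 1 along the class axis; the built-in loss additionally clips probabilities away from 0 and 1):

alpha <- 0.25; gamma <- 2
y_true <- rbind(c(0, 1, 0), c(0, 0, 1))
y_pred <- rbind(c(0.05, 0.95, 0), c(0.1, 0.8, 0.1))
eps <- 1e-7
p <- pmin(pmax(y_pred, eps), 1 - eps)  # clip to keep log() finite
rowSums(-alpha * (1 - p)^gamma * y_true * log(p))
## [1] 3.205831e-05 4.662735e-01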

In the snippets below, there are num_classes floating point values per example. The shape of both y_pred and y_true is (batch_size, num_classes).

Usage

loss_categorical_focal_crossentropy(
  y_true,
  y_pred,
  alpha = 0.25,
  gamma = 2,
  from_logits = FALSE,
  label_smoothing = 0,
  axis = -1L,
  ...,
  reduction = "sum_over_batch_size",
  name = "categorical_focal_crossentropy",
  dtype = NULL
)

Arguments

y_true

Tensor of one-hot true targets.

y_pred

Tensor of predicted targets.

alpha

A weight balancing factor for all classes; default is 0.25, as mentioned in the reference. It can be a list of floats or a scalar. In the multi-class case, alpha may be set by inverse class frequency, e.g. using compute_class_weight from sklearn.utils (see the sketch below).
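
For illustration only, a minimal base-R sketch of inverse-class-frequency weights analogous to sklearn's balanced compute_class_weight (the integer labels in class_ids are made up for the example):

class_ids <- c(0, 0, 0, 1, 2, 2)  # hypothetical labels for 3 classes
counts <- table(factor(class_ids, levels = 0:2))
alpha <- as.numeric(length(class_ids) / (length(counts) * counts))
alpha
## [1] 0.6666667 2.0000000 1.0000000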

gamma

A focusing parameter; default is 2.0, as mentioned in the reference. It gradually reduces the importance given to simple examples in a smooth manner. When gamma = 0, there is no focal effect on the categorical crossentropy.

from_logits

Whether y_pred is expected to be a logits tensor. By default, we assume that y_pred encodes a probability distribution.
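
For example, a hypothetical call with raw scores instead of probabilities (the logit values here are made up):

y_true <- rbind(c(0, 1, 0), c(0, 0, 1))
logits <- rbind(c(1.0, 4.0, 0.5), c(0.2, 1.1, 2.3))  # unnormalized scores
loss_categorical_focal_crossentropy(y_true, logits, from_logits = TRUE)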

label_smoothing

Float in [0, 1]. When > 0, label values are smoothed, meaning the confidence in label values is relaxed. For example, if 0.1, use 0.1 / num_classes for non-target labels and 0.9 + 0.1 / num_classes for target labels.
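
A quick base-R illustration of that arithmetic with 3 classes and label_smoothing = 0.1:

smoothing <- 0.1; num_classes <- 3
y <- c(0, 1, 0)  # one-hot target
y * (1 - smoothing) + smoothing / num_classes
## [1] 0.03333333 0.93333333 0.03333333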

axis

The axis along which to compute crossentropy (the features axis). Defaults to -1.

...

For forward/backward compatibility.

reduction

Type of reduction to apply to the loss. In almost all cases this should be "sum_over_batch_size". Supported options are "sum", "sum_over_batch_size", or NULL (no reduction).

name

Optional name for the loss instance.

dtype

The dtype of the loss's computations. Defaults to NULL, which means using config_floatx(). config_floatx() returns "float32" unless it has been set to a different value (via config_set_floatx()).
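
For instance, to run the loss's computations in double precision (note that this changes the floatx setting globally):

config_set_floatx("float64")
config_floatx()
## [1] "float64"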

Value

Categorical focal crossentropy loss value.

Examples

y_true <- rbind(c(0, 1, 0), c(0, 0, 1))
y_pred <- rbind(c(0.05, 0.95, 0), c(0.1, 0.8, 0.1))
loss <- loss_categorical_focal_crossentropy(y_true, y_pred)
loss
## tf.Tensor([3.20583090e-05 4.66273481e-01], shape=(2,), dtype=float64)

Standalone usage:

y_true <- rbind(c(0, 1, 0), c(0, 0, 1))
y_pred <- rbind(c(0.05, 0.95, 0), c(0.1, 0.8, 0.1))
# Using the default 'sum_over_batch_size' reduction type.
cce <- loss_categorical_focal_crossentropy()
cce(y_true, y_pred)
## tf.Tensor(0.23315276, shape=(), dtype=float32)

# Calling with 'sample_weight'.
cce(y_true, y_pred, sample_weight = op_array(c(0.3, 0.7)))
## tf.Tensor(0.16320053, shape=(), dtype=float32)

# Using 'sum' reduction type.
cce <- loss_categorical_focal_crossentropy(reduction = "sum")
cce(y_true, y_pred)
## tf.Tensor(0.46630552, shape=(), dtype=float32)

# Using no reduction (reduction = NULL).
cce <- loss_categorical_focal_crossentropy(reduction = NULL)
cce(y_true, y_pred)
## tf.Tensor([3.2058331e-05 4.6627346e-01], shape=(2,), dtype=float32)
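
The reduced values above follow directly from the per-example losses (a quick base-R check):

per_example <- c(3.2058331e-05, 4.6627346e-01)
sum(per_example)                    # "sum" reduction: 0.46630552
sum(per_example) / 2                # "sum_over_batch_size": 0.23315276
sum(c(0.3, 0.7) * per_example) / 2  # with sample_weight: ~0.16320053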

Usage with the compile() API:

model %>% compile(
  optimizer = 'adam',
  loss = loss_categorical_focal_crossentropy())

See Also

Other losses:
Loss()
loss_binary_crossentropy()
loss_binary_focal_crossentropy()
loss_categorical_crossentropy()
loss_categorical_hinge()
loss_cosine_similarity()
loss_ctc()
loss_dice()
loss_hinge()
loss_huber()
loss_kl_divergence()
loss_log_cosh()
loss_mean_absolute_error()
loss_mean_absolute_percentage_error()
loss_mean_squared_error()
loss_mean_squared_logarithmic_error()
loss_poisson()
loss_sparse_categorical_crossentropy()
loss_squared_hinge()
loss_tversky()
metric_binary_crossentropy()
metric_binary_focal_crossentropy()
metric_categorical_crossentropy()
metric_categorical_focal_crossentropy()
metric_categorical_hinge()
metric_hinge()
metric_huber()
metric_kl_divergence()
metric_log_cosh()
metric_mean_absolute_error()
metric_mean_absolute_percentage_error()
metric_mean_squared_error()
metric_mean_squared_logarithmic_error()
metric_poisson()
metric_sparse_categorical_crossentropy()
metric_squared_hinge()


[Package keras3 version 1.1.0]