loss_categorical_focal_crossentropy {keras3} | R Documentation |
Computes the alpha balanced focal crossentropy loss.
Description
Use this crossentropy loss function when there are two or more label classes and you want to handle class imbalance without using class_weights. We expect labels to be provided in a one_hot representation.
According to Lin et al., 2018, it helps to apply a focal factor to down-weight easy examples and focus more on hard examples. The general formula for the focal loss (FL) is as follows:
FL(p_t) = -(1 - p_t)^gamma * log(p_t)
where p_t is defined as follows:
p_t = output if y_true == 1, else 1 - output
(1 - p_t)^gamma is the modulating_factor, where gamma is a focusing parameter. When gamma = 0, there is no focal effect on the cross entropy. Increasing gamma smoothly reduces the importance given to easy examples.
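To make the effect of gamma concrete, here is a minimal base R sketch of the modulating factor alone. The helper name modulating_factor and the two p_t values are illustrative assumptions, not part of the keras3 API.

```r
# Hypothetical helper (not a keras3 function): the modulating factor (1 - p_t)^gamma.
modulating_factor <- function(p_t, gamma) (1 - p_t)^gamma

# Illustrative probabilities assigned to the true class:
# an easy example (p_t = 0.9) and a hard one (p_t = 0.1).
p_t <- c(easy = 0.9, hard = 0.1)

mf_gamma0 <- modulating_factor(p_t, gamma = 0)  # all ones: no focal effect
mf_gamma2 <- modulating_factor(p_t, gamma = 2)  # easy: 0.01, hard: 0.81
```

With gamma = 2, the easy example's loss is scaled by 0.01 while the hard example keeps 0.81 of its loss, which is how the focal factor shifts emphasis toward hard examples.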
The authors use an alpha-balanced variant of the focal loss (FL) in the paper:
FL(p_t) = -alpha * (1 - p_t)^gamma * log(p_t)
where alpha is the weight factor for the classes. If alpha = 1, the loss won't be able to handle class imbalance properly, as all classes will have the same weight. alpha can be a constant or a list of constants. If alpha is a list, it must have the same length as the number of classes.
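The difference between a constant alpha and a per-class alpha can be sketched in base R. This is an illustration only: alpha_weighted_focal is a hypothetical helper, and it assumes that element k of a list-valued alpha scales the terms belonging to class k. It also omits any clipping, so the predictions must be strictly positive.

```r
# Hypothetical helper, base R only (not the keras3 implementation).
# Assumption: when alpha is a vector, alpha[k] scales the terms of class k.
alpha_weighted_focal <- function(y_true, y_pred, alpha, gamma = 2) {
  ce <- -y_true * log(y_pred)               # per-class crossentropy terms
  focal <- (1 - y_pred)^gamma * ce          # apply the modulating factor
  if (length(alpha) > 1) {
    stopifnot(length(alpha) == ncol(y_true))
    focal <- sweep(focal, 2, alpha, `*`)    # scale column k by alpha[k]
  } else {
    focal <- alpha * focal                  # same weight for every class
  }
  rowSums(focal)                            # one loss value per example
}

# Illustrative data with strictly positive predictions.
y_true <- rbind(c(0, 1, 0), c(0, 0, 1))
y_pred <- rbind(c(0.05, 0.9, 0.05), c(0.1, 0.8, 0.1))

scalar_alpha <- alpha_weighted_focal(y_true, y_pred, alpha = 0.25)
vector_alpha <- alpha_weighted_focal(y_true, y_pred, alpha = c(0.25, 0.25, 0.5))
```

Here the second example belongs to the third class, so doubling that class's weight doubles its loss, while the first example is unchanged.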
The formula above can be generalized to:
FL(p_t) = alpha * (1 - p_t)^gamma * CrossEntropy(y_true, y_pred)
where the minus sign comes from CrossEntropy(y_true, y_pred) (CE).
Extending this to the multi-class case is straightforward:
FL(p_t) = alpha * (1 - p_t)^gamma * CategoricalCE(y_true, y_pred)
In the examples below, there are num_classes floating-point values per example. The shape of both y_pred and y_true is (batch_size, num_classes).
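The multi-class formula can be checked by hand with a small base R sketch. The function focal_cce is a hypothetical helper, not the keras3 implementation. Normalizing each row and clipping probabilities with eps = 1e-7 are assumptions made so that log(0) never occurs.

```r
# Hypothetical base R sketch of the multi-class focal crossentropy.
# Assumptions: rows of y_pred are renormalized to sum to 1, and
# probabilities are clipped to [eps, 1 - eps] before taking the log.
focal_cce <- function(y_true, y_pred, alpha = 0.25, gamma = 2, eps = 1e-7) {
  y_pred <- y_pred / rowSums(y_pred)          # normalize each row
  y_pred <- pmin(pmax(y_pred, eps), 1 - eps)  # avoid log(0)
  ce <- -y_true * log(y_pred)                 # categorical crossentropy terms
  rowSums(alpha * (1 - y_pred)^gamma * ce)    # focal weighting, sum over classes
}

y_true <- rbind(c(0, 1, 0), c(0, 0, 1))
y_pred <- rbind(c(0.05, 0.95, 0), c(0.1, 0.8, 0.1))

per_example <- focal_cce(y_true, y_pred)
per_example
```

For this input the sketch gives roughly 3.2058e-05 and 0.46627, which agrees with the per-example values printed in the Examples section.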
Usage
loss_categorical_focal_crossentropy(
  y_true,
  y_pred,
  alpha = 0.25,
  gamma = 2,
  from_logits = FALSE,
  label_smoothing = 0,
  axis = -1L,
  ...,
  reduction = "sum_over_batch_size",
  name = "categorical_focal_crossentropy",
  dtype = NULL
)
Arguments
y_true |
Tensor of one-hot true targets. |
y_pred |
Tensor of predicted targets. |
alpha |
A weight balancing factor for all classes, default is 0.25. |
gamma |
A focusing parameter, default is 2. |
from_logits |
Whether |
label_smoothing |
Float in |
axis |
The axis along which to compute crossentropy (the features
axis). Defaults to -1. |
... |
For forward/backward compatibility. |
reduction |
Type of reduction to apply to the loss. In almost all cases
this should be |
name |
Optional name for the loss instance. |
dtype |
The dtype of the loss's computations. Defaults to |
Value
Categorical focal crossentropy loss value.
Examples
y_true <- rbind(c(0, 1, 0), c(0, 0, 1))
y_pred <- rbind(c(0.05, 0.95, 0), c(0.1, 0.8, 0.1))
loss <- loss_categorical_focal_crossentropy(y_true, y_pred)
loss
## tf.Tensor([3.20583090e-05 4.66273481e-01], shape=(2), dtype=float64)
Standalone usage:
y_true <- rbind(c(0, 1, 0), c(0, 0, 1))
y_pred <- rbind(c(0.05, 0.95, 0), c(0.1, 0.8, 0.1))
# Using 'auto'/'sum_over_batch_size' reduction type.
cce <- loss_categorical_focal_crossentropy()
cce(y_true, y_pred)
## tf.Tensor(0.23315276, shape=(), dtype=float32)
# Calling with 'sample_weight'.
cce(y_true, y_pred, sample_weight = op_array(c(0.3, 0.7)))
## tf.Tensor(0.16320053, shape=(), dtype=float32)
# Using 'sum' reduction type.
cce <- loss_categorical_focal_crossentropy(reduction = "sum")
cce(y_true, y_pred)
## tf.Tensor(0.46630552, shape=(), dtype=float32)
# Using 'none' reduction type.
cce <- loss_categorical_focal_crossentropy(reduction = NULL)
cce(y_true, y_pred)
## tf.Tensor([3.2058331e-05 4.6627346e-01], shape=(2), dtype=float32)
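The reduced values above can be related to the unreduced per-example losses with plain arithmetic. The base R sketch below starts from the two values printed for the 'none' reduction. The assumption that the weighted loss is divided by the batch size, rather than by the sum of the weights, is inferred from the printed outputs.

```r
# Per-example losses, copied from the 'none' reduction output above.
per_example <- c(3.2058331e-05, 4.6627346e-01)
sample_weight <- c(0.3, 0.7)

# "sum_over_batch_size": sum of the losses divided by the batch size.
loss_mean <- mean(per_example)

# "sum": plain sum of the losses.
loss_sum <- sum(per_example)

# With 'sample_weight' (assumed): weight each loss, then divide by batch size.
loss_weighted <- mean(per_example * sample_weight)

c(mean = loss_mean, sum = loss_sum, weighted = loss_weighted)
```

These reproduce the printed values of about 0.2332, 0.4663, and 0.1632 respectively.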
Usage with the compile() API:
model %>% compile(
  optimizer = 'adam',
  loss = loss_categorical_focal_crossentropy()
)
See Also
Other losses:
Loss()
loss_binary_crossentropy()
loss_binary_focal_crossentropy()
loss_categorical_crossentropy()
loss_categorical_hinge()
loss_cosine_similarity()
loss_ctc()
loss_dice()
loss_hinge()
loss_huber()
loss_kl_divergence()
loss_log_cosh()
loss_mean_absolute_error()
loss_mean_absolute_percentage_error()
loss_mean_squared_error()
loss_mean_squared_logarithmic_error()
loss_poisson()
loss_sparse_categorical_crossentropy()
loss_squared_hinge()
loss_tversky()
metric_binary_crossentropy()
metric_binary_focal_crossentropy()
metric_categorical_crossentropy()
metric_categorical_focal_crossentropy()
metric_categorical_hinge()
metric_hinge()
metric_huber()
metric_kl_divergence()
metric_log_cosh()
metric_mean_absolute_error()
metric_mean_absolute_percentage_error()
metric_mean_squared_error()
metric_mean_squared_logarithmic_error()
metric_poisson()
metric_sparse_categorical_crossentropy()
metric_squared_hinge()