from numbers import Real
from typing import Optional
import numpy as np
import mygrad._utils.graph_tracking as _tracking
from mygrad.nnet.activations import softmax
from mygrad.operation_base import Operation
from mygrad.tensor_base import Tensor, asarray
from mygrad.typing import ArrayLike
from ._utils import check_loss_inputs
__all__ = ["softmax_focal_loss", "focal_loss"]
class FocalLoss(Operation):
r"""Returns the per-datum focal loss as described in https://arxiv.org/abs/1708.02002
which is given by -ɑ(1-p)ˠlog(p).
Extended Description
--------------------
The focal loss is given by
.. math::
\frac{1}{N}\sum\limits_{i=1}^{N}-\alpha \hat{y}_i(1-p_i)^\gamma\log(p_i)
where :math:`N` is the number of data in `class_probs` and `targets`, and :math:`\hat{y}_i`
is one at the index given by the label :math:`y_i` and 0 elsewhere. That is,
if the label :math:`y_k` is 1 and there are four possible label values, then
:math:`\hat{y}_k = (0, 1, 0, 0)`.
It is recommended in the paper that you normalize by the number of foreground samples.
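For example (illustrative arithmetic only): with :math:`\alpha=1`, :math:`\gamma=2`,
and a predicted probability of :math:`p_i=0.9` for a datum's true class, that datum
contributes :math:`-1\cdot(0.1)^2\log(0.9) \approx 0.0011` to the loss.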
"""
def __call__(self, class_probs, targets, alpha, gamma):
"""
Parameters
----------
class_probs : mygrad.Tensor, shape=(N, C)
The C class probabilities for each of the N pieces of data.
Each value is expected to lie in (0, 1].
targets : Union[mygrad.Tensor, ArrayLike], shape=(N,)
The correct class indices, in [0, C), for each datum.
alpha : Real
The ɑ weighting factor in the loss formulation.
gamma : Real
The ɣ focusing parameter.
Returns
-------
numpy.ndarray, shape=(N,)
The per-datum focal loss.
"""
if isinstance(targets, Tensor): # pragma: nocover
targets = targets.data
check_loss_inputs(class_probs, targets)
self.variables = (class_probs,)
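# fancy-index pairs (row i, column targets[i]): selects each datum's predicted
# probability for its true class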
self.label_locs = (range(len(class_probs)), targets)
class_probs = asarray(class_probs)
pc = class_probs[self.label_locs]
one_m_pc = np.clip(1 - pc, a_min=0, a_max=1)
log_pc = np.log(pc)
one_m_pc_gamma = one_m_pc**gamma
loss = -(alpha * one_m_pc_gamma * log_pc)
if not _tracking.TRACK_GRAPH:
return loss
self.back = np.zeros(class_probs.shape, dtype=np.float64)
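# ɣ ≈ 0: (1 - p)**ɣ ≈ 1, so the loss reduces to cross-entropy and
# dL/dp at the label locations simplifies to -ɑ / p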
if np.isclose(gamma, 0, atol=1e-15):
self.back[self.label_locs] -= alpha / pc
return loss
# dL/dp = -alpha * ( (1 - p)**g / p - g * (1 - p)**(g - 1) * log(p) )
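#   (via the product rule applied to -alpha * (1 - p)**g * log(p))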
#
# term 1: (1 - p)**g / p
term1 = one_m_pc_gamma / pc # (1 - p)**g / p
# term 2: - g * (1 - p)**(g - 1) * log(p)
if np.isclose(gamma, 1, rtol=1e-15):
term2 = -log_pc
elif gamma < 1:
# For g < 1 and p -> 1, the 2nd term -> 0 via L'Hôpital's rule
term2 = np.zeros(pc.shape, dtype=class_probs.dtype)
pc_not_1 = ~np.isclose(one_m_pc, 0, atol=1e-25)
term2[pc_not_1] = (
-gamma * one_m_pc[pc_not_1] ** (gamma - 1) * log_pc[pc_not_1]
)
else:
term2 = -gamma * one_m_pc ** (gamma - 1) * log_pc
self.back[self.label_locs] -= alpha * (term1 + term2)
return loss
def backward_var(self, grad, index, **kwargs):
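# `self.back` holds dL/d(class_probs), nonzero only at the label locations;
# scale those entries by the incoming gradient before returning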
self.back[self.label_locs] *= grad
return self.back
def focal_loss(
class_probs: ArrayLike,
targets: ArrayLike,
*,
alpha: float = 1,
gamma: float = 0,
constant: Optional[bool] = None,
) -> Tensor:
r"""Return the per-datum focal loss.
Parameters
----------
class_probs : ArrayLike, shape=(N, C)
The C class probabilities for each of the N pieces of data.
Each value is expected to lie in (0, 1].
targets : ArrayLike, shape=(N,)
The correct class indices, in [0, C), for each datum.
alpha : Real, optional (default=1)
The ɑ weighting factor in the loss formulation.
gamma : Real, optional (default=0)
The ɣ focusing parameter. Note that for ɣ=0 and ɑ=1, this is cross-entropy loss.
Must be a non-negative value.
constant : Optional[bool]
If ``True``, the returned tensor is a constant (it
does not back-propagate a gradient)
Returns
-------
mygrad.Tensor, shape=(N,)
The per-datum focal loss.
Notes
-----
The focal loss was introduced in https://arxiv.org/abs/1708.02002;
it is given by -ɑ(1-p)ˠlog(p).
The focal loss for datum-:math:`i` is given by
.. math::
-\alpha \hat{y}_i(1-p_i)^\gamma\log(p_i)
where :math:`\hat{y}_i` is one at the index given by the datum's label and 0
elsewhere. That is, if the label :math:`y_k` is 2 and
there are four possible label values, then :math:`\hat{y}_k = (0, 0, 1, 0)`.
It is recommended in the paper that you normalize by the number of foreground samples.
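Examples
--------
A small, illustrative sketch (arbitrary values; assumes the standard ``mygrad``
API for tensor creation and back-propagation). With ɣ=0 and ɑ=1 this reduces to
the per-datum cross-entropy loss:
>>> import mygrad as mg
>>> from mygrad.nnet.losses import focal_loss
>>> class_probs = mg.tensor([[0.9, 0.1],
...                          [0.3, 0.7]])
>>> targets = [0, 1]  # correct class index for each datum
>>> loss = focal_loss(class_probs, targets, alpha=1, gamma=2)  # shape-(2,)
>>> loss.sum().backward()  # gradient flows back to `class_probs`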
"""
if not isinstance(gamma, Real) or gamma < 0:
raise ValueError(f"`gamma` must be a non-negative number, got: {gamma}")
return Tensor._op(
FocalLoss, class_probs, op_args=(targets, alpha, gamma), constant=constant
)
def softmax_focal_loss(
scores: ArrayLike,
targets: ArrayLike,
*,
alpha: float = 1,
gamma: float = 0,
constant: Optional[bool] = None,
) -> Tensor:
r"""
Applies the softmax normalization to the input scores before computing the
per-datum focal loss.
Parameters
----------
scores : ArrayLike, shape=(N, C)
The C class scores for each of the N pieces of data.
targets : ArrayLike, shape=(N,)
The correct class indices, in [0, C), for each datum.
alpha : Real, optional (default=1)
The ɑ weighting factor in the loss formulation.
gamma : Real, optional (default=0)
The ɣ focusing parameter. Note that for ɣ=0 and ɑ=1, this is cross-entropy loss.
Must be a non-negative value.
constant : Optional[bool]
If ``True``, the returned tensor is a constant (it
does not back-propagate a gradient)
Returns
-------
mygrad.Tensor, shape=(N,)
The per-datum focal loss.
Notes
-----
The focal loss was introduced in https://arxiv.org/abs/1708.02002;
it is given by -ɑ(1-p)ˠlog(p).
The focal loss for datum-:math:`i` is given by
.. math::
-\alpha \hat{y}_i(1-p_i)^\gamma\log(p_i)
where :math:`\hat{y}_i` is one at the index given by the datum's label and 0
elsewhere. That is, if the label :math:`y_k` is 2 and
there are four possible label values, then :math:`\hat{y}_k = (0, 0, 1, 0)`.
It is recommended in the paper that you normalize by the number of foreground samples.
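Examples
--------
An illustrative sketch (arbitrary values; assumes the standard ``mygrad`` API for
tensor creation and back-propagation). Raw, unnormalized scores are passed in;
the softmax normalization is applied internally:
>>> import mygrad as mg
>>> from mygrad.nnet.losses import softmax_focal_loss
>>> scores = mg.tensor([[ 2.0, -1.0,  0.5],
...                     [ 0.1,  3.0, -0.2]])
>>> targets = [0, 1]  # correct class index for each datum
>>> loss = softmax_focal_loss(scores, targets, alpha=0.25, gamma=2)
>>> loss.sum().backward()  # gradient flows back to `scores`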
"""
return focal_loss(
softmax(scores), targets=targets, alpha=alpha, gamma=gamma, constant=constant
)