October 19, 2024

About

Focal Loss for Dense Object Detection: https://arxiv.org/abs/1708.02002

Focal Loss is a loss function designed to address class imbalance, especially when some classes appear far more frequently than others. It shrinks the penalty for samples that are already easy to classify correctly, so that hard or easily confused samples contribute relatively more to the total loss. This makes the model pay more attention to difficult cases.

Focal Loss

Focal Loss was introduced as an extension of the binary cross-entropy (BCE) loss. It is defined as follows:

\text{Focal Loss} = - \alpha_t (1 - p_t)^\gamma \log(p_t)

Here p_t is the model's estimated probability for the true class (p_t = p if the label is 1, and p_t = 1 - p otherwise), and alpha_t is the corresponding class weight. With gamma = 0 and alpha_t = 1 this reduces to the ordinary cross-entropy -log(p_t); increasing gamma pushes the loss of already well-classified samples (large p_t) toward zero.
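
As a concrete instance (the numbers here are chosen for illustration, not taken from the paper), with alpha_t = 0.25 and gamma = 2, a well-classified sample with p_t = 0.9 contributes

\text{FL} = -0.25 \cdot (1 - 0.9)^2 \cdot \log(0.9) \approx 0.25 \cdot 0.01 \cdot 0.105 \approx 2.6 \times 10^{-4}

which is about 400 times smaller than the plain cross-entropy value -\log(0.9) \approx 0.105.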

According to the Keras documentation (https://keras.io/api/keras_cv/losses/focal_loss/):

  • alpha: a float between 0 and 1, a weighting factor used to deal with class imbalance. Positive classes are weighted by alpha and negative classes by (1 - alpha). Defaults to 0.25.
  • gamma: a positive float, the tunable focusing parameter (see the short sketch after this list for its effect). Defaults to 2.
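
To see what gamma does in isolation, here is a minimal sketch (the pt values are illustrative, not from the paper) that prints the modulating factor (1 - pt)^gamma for an easy, a moderate, and a hard sample:

import torch

pt = torch.tensor([0.9, 0.6, 0.1])  # easy, moderate, hard samples
for gamma in (0.0, 2.0, 5.0):
    factor = (1 - pt) ** gamma  # gamma = 0 recovers plain cross-entropy weighting
    print(f"gamma={gamma}: {factor.tolist()}")

With gamma = 2, the easy sample (pt = 0.9) is down-weighted by a factor of 100, while the hard sample (pt = 0.1) still keeps 81% of its original loss.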

An implementation in PyTorch is shown below.

import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        # Per-sample binary cross-entropy computed on raw logits
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)  # pt is the probability assigned to the true class
        # alpha_t: alpha for positive targets, (1 - alpha) for negative targets,
        # matching the weighting described above
        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        F_loss = alpha_t * (1 - pt) ** self.gamma * BCE_loss
        return torch.mean(F_loss)
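
The criterion can then be dropped into a standard training loop; model and train_loader below are assumed to be defined elsewhere: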

criterion = FocalLoss(alpha=0.25, gamma=2.0)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

model.train()
for data, target in train_loader:
    optimizer.zero_grad()
    output = model(data)  # raw logits; FocalLoss applies the sigmoid internally
    loss = criterion(output, target)  # target: float tensor of 0s and 1s
    loss.backward()
    optimizer.step()
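
As a quick sanity check (the logits and labels here are made up for illustration), the focal loss ends up dominated by the one hard, misclassified sample, while the easy samples are suppressed:

logits = torch.tensor([4.0, 3.0, -4.0, -0.5])  # three confident correct predictions, one hard mistake
labels = torch.tensor([1.0, 1.0, 0.0, 1.0])

bce = F.binary_cross_entropy_with_logits(logits, labels)
focal = FocalLoss(alpha=0.25, gamma=2.0)(logits, labels)
print(f"BCE:   {bce.item():.4f}")    # roughly 0.26
print(f"Focal: {focal.item():.4f}")  # roughly 0.02, almost all from the last sample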