Implement Focal Loss for Multi Label Classification in TensorFlow

By | June 30, 2021

Focal loss can improve model performance when the classes are imbalanced, for example in multi-label classification. In this tutorial, we will implement it in TensorFlow.

Focal Loss

Focal loss was introduced by Tsung-Yi Lin et al. in the paper "Focal Loss for Dense Object Detection". It was designed to handle class imbalance in object detection, where easy background examples far outnumber the objects of interest.

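Focal loss can be defined as:

\[ FL(p_t) = -(1 - p_t)^{\gamma} \log(p_t) \]

Here \(p_t\) is the predicted probability of the true class and \(\gamma \ge 0\) is a focusing hyperparameter. The factor \((1 - p_t)^{\gamma}\) shrinks the loss of well-classified examples (\(p_t\) close to 1), and with \(\gamma = 0\) focal loss reduces to cross entropy. The paper also uses an \(\alpha\)-balanced form, \(FL(p_t) = -\alpha_t (1 - p_t)^{\gamma} \log(p_t)\).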
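As a worked example, take \(FL(p_t) = -(1 - p_t)^{\gamma}\log(p_t)\) with \(\gamma = 2\) and natural logarithms, and compare a well-classified example (\(p_t = 0.9\)) with a hard one (\(p_t = 0.1\)):

\[
\begin{aligned}
p_t = 0.9:&\quad CE = -\log(0.9) \approx 0.1054, \quad FL = (0.1)^2 \times 0.1054 \approx 0.00105 \\
p_t = 0.1:&\quad CE = -\log(0.1) \approx 2.3026, \quad FL = (0.9)^2 \times 2.3026 \approx 1.8651
\end{aligned}
\]

The easy example's loss drops by a factor of 100, while the hard example keeps about 81% of its cross entropy loss, so training concentrates on the hard examples.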

The following figure compares focal loss with cross entropy loss.

[Figure: a comparison between focal loss and cross entropy loss]
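To see the same comparison numerically, here is a minimal plain-Python sketch (no TensorFlow needed) that evaluates cross entropy \(-\log(p_t)\) and focal loss \(-(1 - p_t)^{\gamma}\log(p_t)\) for a few probabilities of the true class. The \(\gamma\) values and the probability grid are illustrative choices, not values read off the figure.

```python
import math

def cross_entropy(p_t):
    """CE(p_t) = -log(p_t), where p_t is the probability of the true class."""
    return -math.log(p_t)

def focal(p_t, gamma):
    """FL(p_t) = -(1 - p_t)**gamma * log(p_t); gamma = 0 recovers cross entropy."""
    return (1.0 - p_t) ** gamma * cross_entropy(p_t)

gammas = [0.0, 0.5, 1.0, 2.0, 5.0]   # illustrative focusing parameters
probs = [0.1, 0.3, 0.5, 0.7, 0.9]    # illustrative probabilities of the true class

# table[p] lists the loss for each gamma; the first column is plain cross entropy
table = {p: [focal(p, g) for g in gammas] for p in probs}
for p, losses in table.items():
    print(f"p_t={p:.1f}  " + "  ".join(f"{loss:.4f}" for loss in losses))
```

Reading across a row shows that a larger \(\gamma\) always lowers the loss, and the reduction is strongest for the rows with large \(p_t\).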

How to implement focal loss in TensorFlow?

Focal loss can be used in multi-label classification, and we can implement it with TensorFlow. Here is an example:

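    import tensorflow as tf

    def focal_loss(logits, labels, gamma=2.0, alpha=4.0):
        # logits: raw scores, shape (batch_size, class_num); labels: float32, same shape
        epsilon = 1.e-9
        probs = tf.nn.softmax(logits)
        # Clip so that log() never receives exactly 0
        y_pred = tf.clip_by_value(probs, epsilon, 1. - epsilon)

        # Cross entropy term -log(p), kept only where labels is 1
        ce = tf.multiply(labels, -tf.math.log(y_pred))  # tf.log in TF 1.x
        # Modulating factor (1 - p)^gamma down-weights well-classified examples
        weight = tf.multiply(labels, tf.pow(tf.subtract(1., y_pred), gamma))
        fl = tf.multiply(alpha, tf.multiply(weight, ce))
        # With one-hot labels each row has a single non-zero entry,
        # so reduce_max and reduce_sum give the same result
        reduced_fl = tf.reduce_max(fl, axis=1)
        # reduced_fl = tf.reduce_sum(fl, axis=1)  # same as reduce_max for one-hot labels only
        return tf.reduce_mean(reduced_fl)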

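About this code, note the following:

labels is the ground truth, with shape batch_size × class_num. Because the function applies softmax and treats reduce_max and reduce_sum as interchangeable, it assumes each row of labels is one-hot, that is, one true class per example.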
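A minimal sketch of preparing labels in that shape with `tf.one_hot`, and of checking when the reduce_max/reduce_sum equivalence holds. The batch size, class count, class ids and the `scores` tensor are hypothetical values chosen only for illustration; `scores` stands in for the non-negative per-class terms `fl` inside the loss.

```python
import tensorflow as tf

batch_size, class_num = 3, 4                      # hypothetical sizes
class_ids = tf.constant([2, 0, 3])                # one integer class id per example
labels = tf.one_hot(class_ids, depth=class_num)   # float32, shape (batch_size, class_num)

# Hypothetical non-negative per-class terms, masked by labels like `fl` in focal_loss
scores = tf.constant([[0.1, 0.2, 0.3, 0.4]] * batch_size)
fl_one_hot = labels * scores
one_hot_same = bool(tf.reduce_all(
    tf.equal(tf.reduce_max(fl_one_hot, axis=1), tf.reduce_sum(fl_one_hot, axis=1))))

# Multi-hot rows (several true classes per example) make the two reductions differ
multi_hot = tf.constant([[1., 0., 1., 0.],
                         [0., 1., 1., 1.],
                         [1., 0., 0., 0.]])
fl_multi_hot = multi_hot * scores
multi_hot_same = bool(tf.reduce_all(
    tf.equal(tf.reduce_max(fl_multi_hot, axis=1), tf.reduce_sum(fl_multi_hot, axis=1))))

print(tuple(labels.shape), one_hot_same, multi_hot_same)   # (3, 4) True False
```

If your labels can contain several 1s per row, reduce_max keeps only the largest term, so the choice of reduction matters.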

You can use this loss as follows:

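    # Inside the model class: scores_intent holds the logits, shape (batch_size, class_num)
    self.scores_intent = tf.matmul(self.doc_output, self.weights['softmax_intent']) + self.biases['softmax_intent']
    # self.input_y is one-hot with shape (64, 10): batch_size = 64, class_num = 10
    loss_mean = focal_loss(logits=self.scores_intent, labels=self.input_y)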
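The focal_loss function above applies softmax, so it treats the classes as mutually exclusive. If each example can genuinely carry several labels at once, one common alternative is a per-class sigmoid focal loss, the form the paper uses for detection, with \(\alpha = 0.25\) and \(\gamma = 2\) as defaults. The sketch below is that swapped-in variant, not the code above; `sigmoid_focal_loss` is a hypothetical name, not a TensorFlow API, and the random logits and labels are placeholders.

```python
import tensorflow as tf

def sigmoid_focal_loss(logits, labels, gamma=2.0, alpha=0.25):
    """Hypothetical per-class (sigmoid) focal loss for multi-hot labels.

    logits, labels: shape (batch_size, class_num); labels hold 0/1 per class.
    """
    labels = tf.cast(labels, logits.dtype)
    # Numerically stable binary cross entropy per class, equal to -log(p_t)
    ce = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
    probs = tf.sigmoid(logits)
    # p_t: probability the model assigns to the correct 0/1 outcome of each class
    p_t = labels * probs + (1.0 - labels) * (1.0 - probs)
    # alpha_t: weight alpha for positive classes, 1 - alpha for negative ones
    alpha_t = labels * alpha + (1.0 - labels) * (1.0 - alpha)
    fl = alpha_t * tf.pow(1.0 - p_t, gamma) * ce
    # Sum over classes, then average over the batch
    return tf.reduce_mean(tf.reduce_sum(fl, axis=1))

# Placeholder usage with batch_size = 64 and class_num = 10
logits = tf.random.normal([64, 10])
multi_hot = tf.cast(tf.random.uniform([64, 10]) < 0.2, tf.float32)  # random multi-hot labels
loss_mean = sigmoid_focal_loss(logits, multi_hot)
```

With \(\gamma = 0\) and \(\alpha = 0.5\) this reduces to half of the usual sigmoid cross entropy, which is a quick sanity check when switching between the two.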
