Implement GE2E Loss in TensorFlow – TensorFlow Tutorial

March 22, 2022

In this tutorial, we will introduce how to implement the GE2E loss function in TensorFlow. GE2E loss is very useful when you are building a voiceprint (speaker) recognition model.

What is GE2E loss?

To understand the GE2E loss function, you can read this tutorial:

An Introduction to GE2E Loss for Beginners – Deep Learning Tutorial

How to implement GE2E loss in tensorflow?

Here is the example code:

import tensorflow as tf
import numpy as np

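def normalize(x):
    """ L2-normalize the last dimension of x (1e-6 avoids division by zero)
    :return: normalized x
    """
    return x / tf.sqrt(tf.reduce_sum(x**2, axis=-1, keepdims=True) + 1e-6)

def similarity(embedded, w, b, N=4, M=5, P=64, center=None):
    """ Calculate similarity matrix from embedded utterance batch (NM x embed_dim) eq. (9)
        embedded must already be L2-normalized, so each dot product is a cosine similarity.
        At test time, pass the enrollment centers as center (embedded then holds the verification utterances).
    :return: tf similarity matrix (NM x N)
    """
    embedded_split = tf.reshape(embedded, shape=[N, M, P])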

    if center is None:
        center = normalize(tf.reduce_mean(embedded_split, axis=1))              # [N,P] normalized center vectors eq.(1)
        # leave-one-out centers: sum of the other M-1 utterances of the same speaker
        # (normalizing makes the 1/(M-1) factor irrelevant)
        center_except = normalize(tf.reshape(tf.reduce_sum(embedded_split, axis=1, keepdims=True)
                                             - embedded_split, shape=[N*M,P]))  # [NM,P] center vectors eq.(8)
        # make similarity matrix eq.(9): own speaker (i == j) uses the leave-one-out center
        S = tf.concat(
            [tf.concat([tf.reduce_sum(center_except[i*M:(i+1)*M,:]*embedded_split[j,:,:], axis=1, keepdims=True) if i==j
                        else tf.reduce_sum(center[i:(i+1),:]*embedded_split[j,:,:], axis=1, keepdims=True) for i in range(N)],
                       axis=1) for j in range(N)], axis=0)
    else:
        # If enrollment centers are given, compare every utterance with them.
        S = tf.concat(
            [tf.concat([tf.reduce_sum(center[i:(i + 1), :] * embedded_split[j, :, :], axis=1, keepdims=True) for i
                        in range(N)],
                       axis=1) for j in range(N)], axis=0)

    S = tf.abs(w)*S+b   # rescaling

    return S

def loss_cal(S, type="softmax", N=4, M=5):
    """ calculate loss with similarity matrix(S) eq.(6) (7)
    :type: "softmax" or "contrast"
    :return: loss
    """
    # similarity of each utterance to its own speaker (colored entries in Fig.1 of the paper)
    S_correct = tf.concat([S[i*M:(i+1)*M, i:(i+1)] for i in range(N)], axis=0)

    if type == "softmax":
        total = -tf.reduce_sum(S_correct-tf.math.log(tf.reduce_sum(tf.exp(S), axis=1, keepdims=True) + 1e-6))
    elif type == "contrast":
        S_sig = tf.sigmoid(S)
        S_sig = tf.concat([tf.concat([0*S_sig[i*M:(i+1)*M, j:(j+1)] if i==j
                              else S_sig[i*M:(i+1)*M, j:(j+1)] for j in range(N)], axis=1)
                             for i in range(N)], axis=0)
        total = tf.reduce_sum(1-tf.sigmoid(S_correct)+tf.reduce_max(S_sig, axis=1, keepdims=True))
    else:
        raise ValueError("loss type should be softmax or contrast !")

    return total

Here, the loss_cal() function computes the GE2E loss from the similarity matrix. It supports two variants: softmax and contrast.
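For reference, the quantities computed by similarity() and loss_cal() can be written as follows, using the GE2E paper's notation as we read it: e_{ji} is the L2-normalized embedding of utterance i of speaker j, c_k is the centroid of speaker k, and c_j^{(-i)} is the centroid of speaker j computed without utterance i. The code additionally uses |w| to keep the scale positive and adds 1e-6 inside the log; those details are left out here.

```latex
c_k = \frac{1}{M}\sum_{m=1}^{M} e_{km}, \qquad
c_j^{(-i)} = \frac{1}{M-1}\sum_{\substack{m=1 \\ m \neq i}}^{M} e_{jm}

S_{ji,k} =
\begin{cases}
  w \cdot \cos\left(e_{ji}, c_j^{(-i)}\right) + b, & k = j \\
  w \cdot \cos\left(e_{ji}, c_k\right) + b, & k \neq j
\end{cases}

L_{\text{softmax}}(e_{ji}) = -S_{ji,j} + \log \sum_{k=1}^{N} \exp\left(S_{ji,k}\right)

L_{\text{contrast}}(e_{ji}) = 1 - \sigma\left(S_{ji,j}\right) + \max_{1 \le k \le N,\ k \neq j} \sigma\left(S_{ji,k}\right)

L_G = \sum_{j=1}^{N} \sum_{i=1}^{M} L(e_{ji})
```

The S_correct tensor in loss_cal() holds exactly the S_{ji,j} terms, and the returned total is L_G.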

We can then evaluate the loss function with the code below:

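if __name__ == "__main__":
    # The functions above build a static graph, so run it in graph mode (TF 1.x style, also works in TF 2.x)
    tf.compat.v1.disable_eager_execution()
    w = tf.constant([1], dtype=tf.float32)
    b = tf.constant([0], dtype=tf.float32)
    # 6 unit-length embeddings: N = 3 speakers x M = 2 utterances, P = 3 dimensions
    embedded = tf.constant([[0,1,0],[0,0,1], [0,1,0], [0,1,0], [1,0,0], [1,0,0]], dtype=tf.float32)
    sim_matrix = similarity(embedded, w, b, N=3, M=2, P=3)
    loss1 = loss_cal(sim_matrix, type="softmax", N=3, M=2)
    loss2 = loss_cal(sim_matrix, type="contrast", N=3, M=2)
    with tf.compat.v1.Session() as sess:
        S, l1, l2 = sess.run([sim_matrix, loss1, loss2])
        print(S)
        print("softmax loss:", l1, "contrast loss:", l2)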

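Here the input embedding matrix has shape 6 x 3: N = 3 speakers, M = 2 utterances per speaker and P = 3 embedding dimensions. Every row is already a unit vector, as similarity() expects.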

Running this code prints the following similarity matrix S:

[[0.         0.9999995  0.        ]
 [0.         0.         0.        ]
 [0.70710605 0.9999995  0.        ]
 [0.70710605 0.9999995  0.        ]
 [0.         0.         0.9999995 ]
 [0.         0.         0.9999995 ]]
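If you want to check this matrix without TensorFlow, the NumPy sketch below re-implements similarity() and loss_cal() in vectorized form. It is our own cross-check, not part of the original code: similarity_np and ge2e_loss_np are hypothetical helper names, and only the training case (no enrollment centers) is covered. It uses the same normalize() formula, so it should match the matrix above up to float32 rounding.

```python
import numpy as np

def normalize(x, eps=1e-6):
    # Same formula as the TensorFlow helper: x / sqrt(sum(x^2) + eps)
    return x / np.sqrt(np.sum(x ** 2, axis=-1, keepdims=True) + eps)

def similarity_np(embedded, w, b, N, M, P):
    """Vectorized NumPy version of similarity() for the training case (center=None)."""
    e = embedded.reshape(N, M, P)                                # [N, M, P]
    center = normalize(e.mean(axis=1))                           # [N, P] speaker centroids
    center_except = normalize(e.sum(axis=1, keepdims=True) - e)  # [N, M, P] leave-one-out centroids
    S = np.einsum("jmp,kp->jmk", e, center)                      # every utterance vs every centroid
    idx = np.arange(N)
    # Own-speaker entries S[j, m, j] use the leave-one-out centroid instead
    S[idx, :, idx] = np.einsum("jmp,jmp->jm", e, center_except)
    return (abs(w) * S + b).reshape(N * M, N)

def ge2e_loss_np(S, N, M, kind="softmax"):
    """Sum of per-utterance GE2E losses, mirroring loss_cal()."""
    own = np.repeat(np.arange(N), M)   # true speaker index of each row
    rows = np.arange(N * M)
    s_correct = S[rows, own]
    if kind == "softmax":
        return np.sum(-s_correct + np.log(np.sum(np.exp(S), axis=1) + 1e-6))
    if kind == "contrast":
        sig = 1.0 / (1.0 + np.exp(-S))
        others = sig.copy()
        others[rows, own] = 0.0        # exclude the true speaker, as loss_cal() does
        return np.sum(1.0 - sig[rows, own] + others.max(axis=1))
    raise ValueError("kind must be 'softmax' or 'contrast'")

embedded = np.array([[0, 1, 0], [0, 0, 1], [0, 1, 0],
                     [0, 1, 0], [1, 0, 0], [1, 0, 0]], dtype=np.float64)
S = similarity_np(embedded, w=1.0, b=0.0, N=3, M=2, P=3)
print(np.round(S, 6))
print(ge2e_loss_np(S, 3, 2, "softmax"), ge2e_loss_np(S, 3, 2, "contrast"))
```

The second row is all zeros because the utterance [0, 0, 1] is orthogonal to its own leave-one-out centroid [0, 1, 0] and to the other two centroids, so it is not yet closer to its own speaker than to anyone else.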

How to train a model with GE2E loss?

According to the paper "Generalized End-to-End Loss for Speaker Verification", a model is trained with GE2E loss using the following settings:

Each batch contains N = 64 speakers and M = 10 utterances per speaker.

The learning rate starts at 0.01 and is halved every 30M steps.

Change Learning Rate By Global Step When Training Model in TensorFlow – TensorFlow Tutorial

The L2 norm of the gradient is clipped at 3.

Step Guide to Apply Gradient Clipping in TensorFlow – TensorFlow Tutorial

The gradient scale for the projection node in the LSTM is set to 0.5.
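The settings above can be sketched framework-independently in NumPy as follows. Several details are our assumptions, not taken from the paper: plain SGD as the update rule, clipping by the global norm over all gradients (rather than per tensor), applying the 0.5 projection scale before clipping, and the parameter name "lstm_projection", which is hypothetical. In TensorFlow, tf.clip_by_global_norm performs the same clipping, and tf.train.exponential_decay(..., staircase=True) in TF 1.x produces the same step-wise halving.

```python
import numpy as np

BASE_LR = 0.01
DECAY_EVERY = 30_000_000   # halve the learning rate every 30M steps
CLIP_NORM = 3.0            # maximum L2 norm of the gradient
PROJ_GRAD_SCALE = 0.5      # gradient scale for the LSTM projection weights

def learning_rate(step):
    """Step-wise schedule: 0.01, then 0.005 after 30M steps, 0.0025 after 60M, ..."""
    return BASE_LR * 0.5 ** (step // DECAY_EVERY)

def clip_by_global_norm(grads, clip_norm=CLIP_NORM):
    """Rescale all gradients together so that their joint L2 norm is at most clip_norm."""
    global_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads.values()))
    scale = min(1.0, clip_norm / (global_norm + 1e-12))
    return {name: g * scale for name, g in grads.items()}, global_norm

def sgd_step(params, grads, step, proj_keys=("lstm_projection",)):
    """One plain SGD update; proj_keys are hypothetical names of the projection weights."""
    grads = {k: g * PROJ_GRAD_SCALE if k in proj_keys else g for k, g in grads.items()}
    grads, _ = clip_by_global_norm(grads)
    lr = learning_rate(step)
    return {k: params[k] - lr * grads[k] for k in params}

# Usage with toy parameters and gradients
params = {"lstm_projection": np.ones((2, 2)), "dense": np.ones(3)}
grads = {"lstm_projection": np.full((2, 2), 4.0), "dense": np.full(3, 4.0)}
new_params = sgd_step(params, grads, step=0)
print(new_params)
```

Here the scaled gradients have a global norm of 8, so clipping multiplies them by 3/8 before the update with learning rate 0.01.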