Fix TensorFlow tf.where() Gradient NAN Error – TensorFlow Tutorial

By | November 4, 2020

The TensorFlow tf.where() function lets us filter or replace some values in a tensor.

For example, we can remove nan or inf values from a tensor, as described in TensorFlow Remove NaN in Tensor: A Beginner Guide.

In order to understand how to use tf.where(), you can view:

Understand TensorFlow tf.where() with Examples – TensorFlow Tutorial

However, when we use tf.where() to filter a tensor, the gradient may still contain nan values.

In this tutorial, we will discuss this topic and introduce how to fix it.

NaN gradient

For example, TensorFlow returns inf when we divide a nonzero number by 0, such as 15/0 below.

import tensorflow as tf 
  
# Initializing the input tensor 
a = tf.constant([6, 8, 0, 15], dtype = tf.float64) 
b = tf.constant([2, 3, 4, 0],  dtype = tf.float64) 
  
# Printing the input tensor 
print('a: ', a) 
print('b: ', b) 
r = a / b
  
# Printing the result 
init = tf.global_variables_initializer() 
init_local = tf.local_variables_initializer()
with tf.Session() as sess:
    sess.run([init, init_local])
    print(sess.run(r))

Running this code prints the value of r:

[3.         2.66666667 0.                inf]

There is an inf value in r. We can use tf.where() to replace it with 0:

r = tf.where(tf.is_inf(r), tf.zeros_like(r), r)

Now every value in r is finite:

[3.         2.66666667 0.         0.        ]
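If you do not have a TensorFlow 1.x session at hand, the same forward filtering can be reproduced in NumPy. This is a minimal sketch that assumes np.where and np.isinf as stand-ins for tf.where and tf.is_inf; both follow IEEE 754 float rules, so 15/0 gives inf in either library.

```python
import numpy as np

# Same inputs as the TensorFlow example above
a = np.array([6, 8, 0, 15], dtype=np.float64)
b = np.array([2, 3, 4, 0], dtype=np.float64)

# 15 / 0 gives inf; errstate only silences the "divide by zero" RuntimeWarning
with np.errstate(divide="ignore"):
    r = a / b

# np.where(cond, x, y) plays the role of tf.where(cond, x, y):
# take 0 where r is inf, otherwise keep r
r_clean = np.where(np.isinf(r), np.zeros_like(r), r)
print(r)
print(r_clean)
```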

Compute the gradient \(\frac{\partial r}{\partial b}\)

Although every value of r is now finite, its gradient is not.
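To see which values to expect, differentiate r = a / b element by element, using the inputs a = [6, 8, 0, 15] and b = [2, 3, 4, 0] from above:

\[
r_i = \frac{a_i}{b_i}, \qquad \frac{\partial r_i}{\partial b_i} = -\frac{a_i}{b_i^2}
\]

\[
\left[-\frac{6}{2^2},\; -\frac{8}{3^2},\; -\frac{0}{4^2},\; -\frac{15}{0^2}\right] = \left[-1.5,\; -0.8889,\; 0,\; -\infty\right]
\]

For the last element, tf.where() outputs the constant 0, so the gradient it sends back to a_4 / b_4 is 0, and the chain rule gives

\[
0 \cdot \left(-\frac{15}{0^2}\right) = 0 \cdot (-\infty) = \text{NaN}
\]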

t = tf.gradients(r, b)
# Printing the result 
init = tf.global_variables_initializer() 
init_local = tf.local_variables_initializer()
with tf.Session() as sess:
    sess.run([init, init_local])
    print(sess.run(r))
    print(sess.run(t))

Running this code, we get:

[3.         2.66666667 0.         0.        ]
[array([-1.5       , -0.88888889,  0.        ,         nan])]

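There is a nan value in \(\frac{\partial r}{\partial b}\) at the position where b is 0. tf.where() only chooses values in the forward pass. In the backward pass it sends a gradient of 0 to the branch it did not choose, but TensorFlow still multiplies that 0 by the derivative of a / b, which is -inf when b is 0. Since 0 × inf is nan, the gradient becomes nan.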
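The backward pass of this example can be mimicked by hand in NumPy. This is a sketch, not TensorFlow's actual gradient code: it assumes that tf.gradients() starts from an upstream gradient of ones (it differentiates the sum of a non-scalar output) and that tf.where() passes the gradient only to the selected branch, with 0 elsewhere.

```python
import numpy as np

a = np.array([6, 8, 0, 15], dtype=np.float64)
b = np.array([2, 3, 4, 0], dtype=np.float64)

# Silence the divide-by-zero and 0 * inf warnings; the values are what matter
with np.errstate(divide="ignore", invalid="ignore"):
    # Forward pass
    r = a / b                        # last element is inf
    cond = np.isinf(r)               # like tf.is_inf(r)
    out = np.where(cond, 0.0, r)     # like tf.where(cond, zeros, r)

    # Backward pass: upstream gradient of ones, as tf.gradients() uses
    upstream = np.ones_like(out)
    # The where() gradient reaches r only where r was selected, 0 elsewhere
    grad_r = np.where(cond, 0.0, upstream)
    # Local derivative of a / b with respect to b is -a / b**2
    local = -a / b**2                # last element: -15 / 0 = -inf
    # Chain rule: the last element is 0 * -inf = nan
    grad_b = grad_r * local

print(out)
print(grad_b)
```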


How to fix tf.where() gradient nan value?

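To keep the gradient finite, we must make sure that neither branch of tf.where() produces inf or nan, even the branch that is not selected. For division, this means replacing every 0 in the denominator before dividing.

We will create a function to compute \(\frac{a}{b}\) safely: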

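def safediv(a, b):
    # Replace every 0 in b with 1, so the division below never yields inf
    safe_x = tf.where(tf.not_equal(b, 0.), b, tf.ones_like(b))
    # Divide by the safe denominator, then put 0 back where b was 0
    resx = tf.where(tf.not_equal(b, 0.), tf.div(x = a, y = safe_x), tf.zeros_like(safe_x))
    return resx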

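The safediv() function uses tf.where() twice. The first call builds safe_x, a copy of b in which every 0 is replaced by 1, so a / safe_x never contains inf. The second call returns a / safe_x where b is nonzero and 0 where b is 0. Because neither branch contains inf, the zero gradient sent to the unselected branch is multiplied by a finite number and stays 0.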
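Here is a NumPy sketch of the forward and backward pass of safediv(), under the same assumptions as before (upstream gradient of ones; the where() gradient goes only to the selected branch). It is a hand-written illustration of why the nan disappears, not TensorFlow's internal code.

```python
import numpy as np

a = np.array([6, 8, 0, 15], dtype=np.float64)
b = np.array([2, 3, 4, 0], dtype=np.float64)

# Forward pass
mask = b != 0                                 # like tf.not_equal(b, 0.)
safe_b = np.where(mask, b, np.ones_like(b))   # first where: no zeros left
q = a / safe_b                                # every element is finite
out = np.where(mask, q, np.zeros_like(q))     # second where

# Backward pass
upstream = np.ones_like(out)
# Second where: the gradient reaches q only where mask is True
grad_q = np.where(mask, upstream, 0.0)
# d(a / safe_b) / d(safe_b) = -a / safe_b**2, finite because safe_b has no zeros
grad_safe_b = grad_q * (-a / safe_b**2)
# First where: the gradient reaches b only where mask is True
grad_b = np.where(mask, grad_safe_b, 0.0)

print(out)
print(grad_b)
```

For the last element, the local derivative is -15 / 1² = -15 instead of -inf, so multiplying it by 0 gives 0 rather than nan.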

Now modify the example code above to use safediv():

import tensorflow as tf 
  
# Initializing the input tensor 
a = tf.constant([6, 8, 0, 15], dtype = tf.float64) 
b = tf.constant([2, 3, 4, 0],  dtype = tf.float64) 
  
# Printing the input tensor 
print('a: ', a) 
print('b: ', b) 
def safediv(a, b):
    # Calculating safe division
    safe_x = tf.where(tf.not_equal(b, 0.), b, tf.ones_like(b))
    
    resx = tf.where(tf.not_equal(b, 0.), tf.div(x = a, y = safe_x), tf.zeros_like(safe_x))
    return resx
r = safediv(a, b)

t = tf.gradients(r, b)
  
# Printing the result 
init = tf.global_variables_initializer() 
init_local = tf.local_variables_initializer()
with tf.Session() as sess:
    sess.run([init, init_local])
    print(sess.run(r))
    print(sess.run(t))

The values of r and the gradient t are:

[3.         2.66666667 0.         0.        ]
[array([-1.5       , -0.88888889,  0.        ,  0.        ])]

The gradient no longer contains a nan value.