TensorFlow Remove NaN in Tensor: A Beginner Guide

By | January 18, 2020

When you train deep learning models with TensorFlow, you may get tensors that contain NaN values. For example, the tf.svd() function may produce NaN values. How can you remove these NaN values from a tensor? We will use a simple example to show you how.

Here is an example:

import tensorflow as tf
import numpy as np

x = tf.Variable(np.array([[1, -1.2, 1.3],[1, 1.5, -1.2]]), dtype = tf.float32)
z = tf.log(x)

init = tf.global_variables_initializer() 
init_local = tf.local_variables_initializer()

with tf.Session() as sess:
    sess.run([init, init_local])
    # Evaluate z so that its values are printed
    print(sess.run([z]))

Run this code, and we find that z is:

[array([[0.        ,        nan, 0.26236424],
       [0.        , 0.4054651 ,        nan]], dtype=float32)]

There are two NaN values in tensor z, one for each negative entry of x, so we should remove them.
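
To see where the NaN values come from, the same numbers can be checked outside a TensorFlow session. The sketch below uses NumPy rather than TensorFlow, on the assumption that np.log treats negative inputs the same way tf.log does (both return NaN); the names values, nan_mask and nan_positions are only illustrative.

```python
import numpy as np

# Same values as the tensor x in the example above.
values = np.array([[1, -1.2, 1.3], [1, 1.5, -1.2]], dtype=np.float32)

# The log of a negative number is undefined, so NumPy returns NaN there.
# errstate only silences the RuntimeWarning NumPy would otherwise emit.
with np.errstate(invalid="ignore"):
    logs = np.log(values)

# Boolean mask of the NaN entries and their (row, column) positions.
nan_mask = np.isnan(logs)
nan_positions = np.argwhere(nan_mask).tolist()

print(nan_positions)
```

The reported positions should line up with the negative entries of x, which is a quick way to confirm that the NaN values come from the log and not from some other step.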


How to remove NaN values from a TensorFlow tensor?

You can replace each NaN value in the tensor with a valid number, for example 0.

We have created a function that replaces NaN with 0:

def replacenan(t):
    return tf.where(tf.is_nan(t), tf.zeros_like(t), t)

Then we can change the example above to:

z = tf.log(x)
def replacenan(t):
    return tf.where(tf.is_nan(t), tf.zeros_like(t), t)
z = replacenan(z)

Run this code again, and the value of z is:

[array([[0.        , 0.        , 0.26236424],
       [0.        , 0.4054651 , 0.        ]], dtype=float32)]

We have replaced all NaN values in tensor z with 0.
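
The code above is written for TensorFlow 1.x. As a rough sketch of the same replacement in TensorFlow 2.x, assuming eager execution is enabled (the default there), the functions are reached through tf.math and no Session is needed. The name replace_nan is only illustrative.

```python
import tensorflow as tf

def replace_nan(t):
    """Return a tensor like t with every NaN entry replaced by 0."""
    # tf.math.is_nan builds a boolean mask; tf.where picks 0 where the
    # mask is True and keeps the original value everywhere else.
    return tf.where(tf.math.is_nan(t), tf.zeros_like(t), t)

# Same values as x in the example above, as a constant for brevity.
x = tf.constant([[1.0, -1.2, 1.3], [1.0, 1.5, -1.2]], dtype=tf.float32)

# tf.math.log returns NaN for the two negative entries.
z = replace_nan(tf.math.log(x))

print(z.numpy())
```

Here tf.zeros_like(t) could be swapped for any tensor of the same shape and dtype if a value other than 0 is a better substitute.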
