Understand TensorFlow tf.where() with Examples – TensorFlow Tutorial

July 20, 2020

The TensorFlow tf.where() function selects elements from tensors based on a boolean condition. In this tutorial, we will discuss how to use this function correctly, with some examples.

Syntax

tf.where(
    condition,
    x=None,
    y=None,
    name=None
)

tf.where() returns elements chosen from either x or y, depending on condition. Where an element of condition is True, the result takes the element of x at the same position; otherwise, it takes the element of y.
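The syntax above makes x and y optional. A short sketch, assuming TensorFlow 2.x eager execution (so that .numpy() is available), of what happens when both are omitted: tf.where() then returns the coordinates of the True elements of condition, one row per element, as an int64 tensor.

```python
import tensorflow as tf

# A boolean mask; x and y are deliberately omitted below.
condition = tf.constant([[True, False, False],
                         [False, True, False],
                         [True, True, True]])

# With only `condition`, tf.where() returns the (row, column)
# coordinates of every True element, in row-major order.
indices = tf.where(condition)

print(indices.numpy())
print(indices.dtype)
```

Here indices has shape (5, 2): five True elements, each described by a row and a column index.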

Parameters explained

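condition: A Tensor of type bool. For each element, it decides whether the result takes the element from x or from y.

x: A Tensor with the same shape as condition. Its elements are selected where condition is True.

y: A Tensor with the same shape and dtype as x. Its elements are selected where condition is False.

name: An optional name for the operation.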

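To use this function correctly, note that:

  • condition must be a tensor of type bool
  • x and y must have the same dtype
  • condition, x and y normally have the same shape. TensorFlow 1.x also accepts a 1-D condition whose length equals the first dimension of x, while TensorFlow 2.x only requires the three tensors to be broadcastable to a common shape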
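In TensorFlow 2.x, tf.where() broadcasts condition, x and y against each other, so they do not all have to be exactly the same shape. A minimal sketch, assuming TF 2.x eager execution: the condition is built with a comparison (which always yields a bool tensor), and y is a scalar that is broadcast to the shape of x.

```python
import tensorflow as tf

x = tf.constant([[1, 2, 3],
                 [4, 5, 6],
                 [7, 8, 9]], dtype=tf.float32)

# An element-wise comparison produces a bool tensor with the shape of x.
condition = x > 4

# y is a scalar of the same dtype as x; TF 2.x broadcasts it, so every
# position where condition is False becomes 0.0.
r = tf.where(condition, x, tf.constant(0.0))

print(r.numpy())
```

This pattern (keep values that pass a test, replace the rest with a constant) is a common use of tf.where().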

We will use some examples to explain them.

tf.where() example

Create the condition, x and y tensors

import tensorflow as tf
import numpy as np

# A 3x3 boolean mask that decides, position by position, where to take values from.
condition = tf.Variable(np.array([[True, False, False], [False, True, False], [True, True, True]]), dtype=tf.bool, name='condition')
# Values used where condition is True.
x = tf.Variable(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), dtype=tf.float32, name='x')
# Values used where condition is False.
y = tf.Variable(np.array([[11, 12, 13], [14, 15, 16], [17, 18, 19]]), dtype=tf.float32, name='y')

Here condition has type bool, and condition, x and y all have the same shape (3, 3).

r = tf.where(condition, x, y)
print(r.numpy())  # TensorFlow 2.x eager mode; in TensorFlow 1.x, initialize the variables and evaluate r in a tf.Session

The r tensor will be:

[[ 1. 12. 13.]
 [14.  5. 16.]
 [ 7.  8.  9.]]

From the result we can see:

Where an element of condition is True, r takes the element of x at the same position; otherwise, it takes the element of y. This element-by-element pairing is why condition, x and y normally have the same shape.
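To make the rule concrete, the sketch below applies it one position at a time in plain Python and compares the result with tf.where(). It assumes TensorFlow 2.x eager execution and uses tf.constant for brevity; where_by_hand is a hypothetical helper written only for illustration, not a TensorFlow API.

```python
import tensorflow as tf

condition = tf.constant([[True, False, False],
                         [False, True, False],
                         [True, True, True]])
x = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=tf.float32)
y = tf.constant([[11, 12, 13], [14, 15, 16], [17, 18, 19]], dtype=tf.float32)

def where_by_hand(cond, a, b):
    """Hypothetical helper: pick a[i][j] if cond[i][j] is True, else b[i][j]."""
    return [[a_ij if c_ij else b_ij
             for c_ij, a_ij, b_ij in zip(c_row, a_row, b_row)]
            for c_row, a_row, b_row in zip(cond, a, b)]

manual = where_by_hand(condition.numpy().tolist(),
                       x.numpy().tolist(),
                       y.numpy().tolist())
r = tf.where(condition, x, y)

print(manual)
print(r.numpy().tolist() == manual)  # True
```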

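If condition is a 1-D tensor instead:

condition = tf.Variable(np.array([True, False, True]), dtype=tf.bool, name='condition')

then in TensorFlow 1.x each element of condition selects a whole row: row i of r comes from x when condition[i] is True, otherwise from y. r will be:

[[ 1.  2.  3.]
 [14. 15. 16.]
 [ 7.  8.  9.]]

TensorFlow 2.x handles this case differently: tf.where() broadcasts the shape (3,) condition across the columns, so the same call returns [[1, 12, 3], [4, 15, 6], [7, 18, 9]].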
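In TensorFlow 2.x, a 1-D condition of shape (3,) is broadcast across columns rather than rows. A minimal sketch, assuming TF 2.x eager execution, of reproducing the row-wise selection shown above: reshaping the flags to shape (3, 1) makes broadcasting repeat each flag along its row. tf.compat.v1.where, which keeps the TF 1.x rule, is included for comparison.

```python
import tensorflow as tf

x = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=tf.float32)
y = tf.constant([[11, 12, 13], [14, 15, 16], [17, 18, 19]], dtype=tf.float32)

# One flag per row of x and y.
row_flags = tf.constant([True, False, True])

# Shape (3,): TF 2.x broadcasts the flags across the columns.
by_column = tf.where(row_flags, x, y)

# Shape (3, 1): broadcasting now repeats each flag along its row.
by_row = tf.where(tf.reshape(row_flags, [-1, 1]), x, y)

# The TF 1.x op, still reachable in TF 2.x, selects whole rows directly.
by_row_v1 = tf.compat.v1.where(row_flags, x, y)

print(by_column.numpy())
print(by_row.numpy())
print(by_row_v1.numpy())
```

Reshaping the condition is usually preferable to calling tf.compat.v1.where, since it keeps the code on the TF 2.x API.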
