Understand TensorFlow tf.map_fn(): A Beginner's Guide – TensorFlow Tutorial

By | May 11, 2020

The TensorFlow tf.map_fn() method calls a function on each element of a tensor along axis 0 (that is, on each slice of its first dimension) and stacks the results. In this tutorial, we will use some simple examples to help you understand and use this function.
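Conceptually, tf.map_fn() behaves like a loop over the first axis whose outputs are stacked back together. The sketch below illustrates that behavior in plain NumPy. It assumes fn returns arrays of the same shape for every element. map_fn_reference is a hypothetical helper for illustration only, not how TensorFlow implements map_fn internally.

```python
import numpy as np

def map_fn_reference(fn, elems):
    """Hypothetical NumPy stand-in for tf.map_fn: apply fn to each slice along axis 0."""
    # Iterating over a NumPy array yields its slices along axis 0 (the rows of a 2-D array).
    outputs = [fn(elem) for elem in elems]
    # Stack the per-element results along a new axis 0.
    return np.stack(outputs, axis=0)

x = np.array([[1, 2, 2, 1], [2, 1, 3, 4], [4, 3, 1, 1]], dtype=np.int32)
z = np.array([1, 2, 2, 1], dtype=np.int32)

result = map_fn_reference(lambda row: row * z, x)
print(result)
# [[1 4 4 1]
#  [2 2 6 4]
#  [4 6 2 1]]
```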

Syntax

tf.map_fn(
    fn,
    elems,
    dtype=None,
    parallel_iterations=10,
    back_prop=True,
    swap_memory=False,
    infer_shape=True,
    name=None
)
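By default, tf.map_fn() expects fn to return the same dtype as elems. When it returns a different type, you have to declare the output type. The sketch below assumes TensorFlow 2.3 or later with eager execution, where the output type is given with fn_output_signature. In TensorFlow 1.x you would pass dtype=tf.float32 instead. It computes the mean of each int32 row of the x tensor used in this tutorial as a float32 value.

```python
import tensorflow as tf

x = tf.constant([[1, 2, 2, 1], [2, 1, 3, 4], [4, 3, 1, 1]], dtype=tf.int32)

def row_mean(row):
    # row is one slice of x along axis 0: shape (4,), dtype int32.
    return tf.reduce_mean(tf.cast(row, tf.float32))

# The output dtype (float32) differs from the input dtype (int32),
# so it is declared explicitly.
means = tf.map_fn(row_mean, x, fn_output_signature=tf.float32)
print(means)  # values 1.5, 2.5, 2.25; shape (3,), dtype float32
```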

Parameters explained

fn: the function to be called. It receives one element of elems at a time, i.e., one slice of elems along axis 0.

elems: a tensor whose elements along axis 0 are passed to fn one by one.

back_prop: whether gradients can be computed through fn (True by default), which matters when training deep learning models.
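With back_prop=True (the default), gradients flow through the function that tf.map_fn() applies. A minimal sketch, assuming TensorFlow 2.x eager execution and tf.GradientTape: each row r is mapped to r * r, so the gradient of the summed result with respect to x should be 2 * x.

```python
import tensorflow as tf

x = tf.Variable([[1.0, 2.0], [3.0, 4.0]])

with tf.GradientTape() as tape:
    # Square each row of x; map_fn stacks the squared rows back into shape (2, 2).
    squared = tf.map_fn(lambda row: row * row, x)
    loss = tf.reduce_sum(squared)

# d(sum(x ** 2)) / dx = 2 * x
grad = tape.gradient(loss, x)
print(grad.numpy())  # expected values: [[2, 4], [6, 8]]
```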

We will use some simple examples to show how to use this function.

Create two tensors

import tensorflow as tf
import numpy as np

x = tf.Variable(np.array([[1, 2, 2, 1], [2, 1, 3, 4], [4, 3, 1, 1]]), dtype=tf.int32)  # shape (3, 4)
z = tf.Variable(np.array([1, 2, 2, 1]), dtype=tf.int32)  # shape (4,)

We have created two tensors: x with shape (3, 4) and z with shape (4,). We will multiply each element of x along axis 0 (each row) by z.

Use tf.map_fn

def integrate(ix):
    print(ix)  # ix is one row of x, shape (4,)
    x1 = tf.multiply(ix, z)  # element-wise product with z
    return x1

xx = tf.map_fn(integrate, x)  # apply integrate to each row of x
print(xx)

Here we define a function named integrate that multiplies ix by z element-wise, where ix is one element of x along axis 0, i.e., one row of x.

Print the result

# TensorFlow 1.x graph mode: initialize the variables, then evaluate xx in a session
init = tf.global_variables_initializer()
init_local = tf.local_variables_initializer()

with tf.Session() as sess:
    sess.run([init, init_local])
    print(sess.run(xx))

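Running this code gives the result below. The first two lines are printed while the graph is built, so print(ix) and print(xx) show symbolic tensors rather than values; the matrix is the value of xx computed by sess.run():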

Tensor("map/while/TensorArrayReadV3:0", shape=(4,), dtype=int32)
Tensor("map/TensorArrayStack/TensorArrayGatherV3:0", shape=(3, 4), dtype=int32)
[[1 4 4 1]
 [2 2 6 4]
 [4 6 2 1]]

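From the result, we can see that:

1. The shape of x is (3, 4), but the shape of ix is (4,): fn receives one row at a time with axis 0 removed. If an operation inside fn expects a 2-D input, you have to reshape ix first.

2. The shape of the result xx is (3, 4): map_fn stacks the outputs of fn along axis 0, so xx has one entry per element of x, each with the shape returned by fn.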
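To illustrate the first point: because fn sees a 1-D row of shape (4,), an operation that expects a 2-D input, such as tf.matmul(), needs a reshape inside fn. The sketch below assumes TensorFlow 2.x eager execution; the weight matrix w and the function project are made-up examples. Each row is reshaped to (1, 4), multiplied by w of shape (4, 2), and flattened back to (2,), so map_fn stacks the three results into shape (3, 2).

```python
import tensorflow as tf

x = tf.constant([[1, 2, 2, 1], [2, 1, 3, 4], [4, 3, 1, 1]], dtype=tf.float32)
# Hypothetical weight matrix, shape (4, 2).
w = tf.constant([[1, 0], [0, 1], [1, 0], [0, 1]], dtype=tf.float32)

def project(ix):
    # ix has shape (4,); tf.matmul needs rank-2 inputs, so make it (1, 4).
    row = tf.reshape(ix, [1, -1])
    out = tf.matmul(row, w)       # shape (1, 2)
    return tf.reshape(out, [-1])  # back to shape (2,)

result = tf.map_fn(project, x)
print(result.shape)  # (3, 2)
```

In this particular case a single tf.matmul(x, w) produces the same values without map_fn; map_fn is most useful when fn cannot be written as one batched operation.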

The operation performed by tf.map_fn(integrate, x) is illustrated below:

[Figure: tensorflow tf.map_fn tutorial and example]
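The session-based code in this tutorial uses the TensorFlow 1.x graph API (tf.Session, tf.global_variables_initializer). Assuming TensorFlow 2.x with eager execution enabled (the default), the same example needs no session or variable initializer. This sketch uses tf.constant instead of tf.Variable because the values never change.

```python
import tensorflow as tf

x = tf.constant([[1, 2, 2, 1], [2, 1, 3, 4], [4, 3, 1, 1]], dtype=tf.int32)
z = tf.constant([1, 2, 2, 1], dtype=tf.int32)

def integrate(ix):
    # In eager mode ix is a concrete tensor: one row of x with shape (4,).
    return tf.multiply(ix, z)

xx = tf.map_fn(integrate, x)  # computed immediately, shape (3, 4)
print(xx.numpy())
# [[1 4 4 1]
#  [2 2 6 4]
#  [4 6 2 1]]
```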

If you prefer a lambda function, you can also write:

xx = tf.map_fn(lambda ix: tf.multiply(ix, z), x)

or

xx = tf.map_fn(lambda ix: integrate(ix), x)
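These forms compute the same thing, and lambda ix: integrate(ix) only wraps integrate, which can also be passed directly. A quick check, assuming TensorFlow 2.x eager execution. For this particular element-wise product, plain broadcasting (x * z) also gives the same result without map_fn.

```python
import tensorflow as tf

x = tf.constant([[1, 2, 2, 1], [2, 1, 3, 4], [4, 3, 1, 1]], dtype=tf.int32)
z = tf.constant([1, 2, 2, 1], dtype=tf.int32)

def integrate(ix):
    return tf.multiply(ix, z)

a = tf.map_fn(integrate, x)                      # pass the function directly
b = tf.map_fn(lambda ix: tf.multiply(ix, z), x)  # inline lambda
c = tf.map_fn(lambda ix: integrate(ix), x)       # lambda wrapping integrate
d = x * z                                        # broadcasting, no map_fn

for t in (b, c, d):
    # Every variant should match the first one element by element.
    assert bool(tf.reduce_all(tf.equal(a, t)))
print(a.numpy())
```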

To learn more about lambda functions, you can read:

Understand Python Lambda Function for Beginners – Python Tutorial