# Understand TensorFlow tf.map_fn(): A Beginner's Guide – TensorFlow Tutorial

By | May 11, 2020

The TensorFlow tf.map_fn() function calls a function on each element of a tensor, where the elements are the slices of the tensor along axis 0, and then stacks the results. In this tutorial, we will use some simple examples to help you understand and use this function.
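Conceptually, tf.map_fn(fn, elems) behaves like a loop that applies fn to elems[0], elems[1], ... and stacks the outputs along a new axis 0. The NumPy sketch below only models that behavior. `map_fn_reference` is a hypothetical helper, not a TensorFlow API, and it ignores details such as parallel_iterations, back_prop, and nested inputs.

```python
import numpy as np

def map_fn_reference(fn, elems):
    """Hypothetical reference model of tf.map_fn for a single tensor.

    Applies fn to each slice elems[i] along axis 0 and stacks the
    outputs along a new axis 0.
    """
    elems = np.asarray(elems)
    outputs = [fn(elems[i]) for i in range(elems.shape[0])]
    return np.stack(outputs, axis=0)

x = np.array([[1, 2, 2, 1], [2, 1, 3, 4], [4, 3, 1, 1]], dtype=np.int32)
z = np.array([1, 2, 2, 1], dtype=np.int32)

# Each row of x (shape (4,)) is multiplied element-wise by z.
result = map_fn_reference(lambda row: row * z, x)
print(result)
# [[1 4 4 1]
#  [2 2 6 4]
#  [4 6 2 1]]
```

Because fn returns a tensor of shape (4,) for each of the 3 rows, the stacked result has shape (3, 4).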

## Syntax

```python
tf.map_fn(
    fn,
    elems,
    dtype=None,
    parallel_iterations=10,
    back_prop=True,
    swap_memory=False,
    infer_shape=True,
    name=None
)
```

## Parameters explained

fn: the function to call. It receives one element of elems at a time, that is, one slice of elems along axis 0.

elems: the tensor to map over. It is unpacked along axis 0, and each slice is passed to fn.

back_prop: whether to support backpropagation through the computation (True by default). This matters when map_fn is part of a deep learning model that is trained with gradients.
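Here is a minimal sketch of gradients flowing through tf.map_fn with the default back_prop=True. It assumes TensorFlow 2.x with eager execution, not the 1.x Session API used in the rest of this tutorial. Each row is squared, so the gradient of the summed result with respect to x is 2 * x. In recent 2.x releases the back_prop argument itself is deprecated in favor of tf.stop_gradient.

```python
import tensorflow as tf

x = tf.Variable([[1.0, 2.0], [3.0, 4.0]])

with tf.GradientTape() as tape:
    # Square each row of x, then sum everything into a scalar loss.
    y = tf.map_fn(lambda row: row * row, x)
    loss = tf.reduce_sum(y)

# The tape automatically watches the trainable variable x.
grad = tape.gradient(loss, x)
print(grad.numpy())  # gradient equals 2 * x
```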

We will use some simple examples to show how to use this function.

## Create two tensors

```python
import tensorflow as tf
import numpy as np

x = tf.Variable(np.array([[1, 2, 2, 1], [2, 1, 3, 4], [4, 3, 1, 1]]), dtype=tf.int32)  # shape (3, 4)
z = tf.Variable(np.array([1, 2, 2, 1]), dtype=tf.int32)  # shape (4,)
```

We have created two tensors, x and z. We will multiply each element of x along axis 0 (that is, each row of x) by z.

## Use tf.map_fn

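```python
def integrate(ix):
    # ix is one element of x along axis 0, i.e. one row of shape (4,).
    # In graph mode this prints the symbolic tensor, not its values.
    print(ix)
    x1 = tf.multiply(ix, z)  # element-wise product with z
    return x1

xx = tf.map_fn(integrate, x)
print(xx)
```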

Here we define a function named integrate that multiplies ix by z element-wise. ix is one element of x along axis 0, that is, one row of x.
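tf.map_fn is typically most useful when fn cannot be expressed as a single vectorized operation. Because integrate only performs an element-wise product, broadcasting gives the same result in one op. The sketch below assumes TensorFlow 2.x eager execution. It uses constants in place of the tutorial's variables, and `multiply_row` is a hypothetical stand-in for integrate without the print call.

```python
import tensorflow as tf

# Same values as the tutorial's x and z, as eager constants (TF 2.x assumed).
x = tf.constant([[1, 2, 2, 1], [2, 1, 3, 4], [4, 3, 1, 1]], dtype=tf.int32)
z = tf.constant([1, 2, 2, 1], dtype=tf.int32)

def multiply_row(ix):
    # Same computation as integrate: one (4,) row times z, element-wise.
    return tf.multiply(ix, z)

mapped = tf.map_fn(multiply_row, x)  # calls multiply_row once per row
broadcast = tf.multiply(x, z)        # z of shape (4,) is broadcast across all 3 rows

print(bool(tf.reduce_all(tf.equal(mapped, broadcast))))  # True
```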

## Print the result

```python
# TensorFlow 1.x graph mode: initialize the variables, then evaluate xx.
init = tf.global_variables_initializer()
init_local = tf.local_variables_initializer()

with tf.Session() as sess:
    sess.run([init, init_local])
    print(sess.run(xx))
```
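The example above uses the TensorFlow 1.x Session API. Here is a minimal sketch of the same computation, assuming TensorFlow 2.x, where eager execution is the default. In that setting no Session or variable initializers are needed, and the result can be read directly with .numpy().

```python
import tensorflow as tf

# TensorFlow 2.x (assumed): eager execution, so values are available immediately.
x = tf.Variable([[1, 2, 2, 1], [2, 1, 3, 4], [4, 3, 1, 1]], dtype=tf.int32)
z = tf.Variable([1, 2, 2, 1], dtype=tf.int32)

def integrate(ix):
    # ix is one row of x, with shape (4,).
    return tf.multiply(ix, z)

xx = tf.map_fn(integrate, x)
print(xx.numpy())
# [[1 4 4 1]
#  [2 2 6 4]
#  [4 6 2 1]]
```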

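Running this code prints the output below. The first two lines come from the print() calls made while the graph is being built, so they show symbolic tensors. The matrix is the value returned by sess.run(xx).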

```
Tensor("map/while/TensorArrayReadV3:0", shape=(4,), dtype=int32)
Tensor("map/TensorArrayStack/TensorArrayGatherV3:0", shape=(3, 4), dtype=int32)
[[1 4 4 1]
 [2 2 6 4]
 [4 6 2 1]]
```

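From the result, we can see that:

1. The shape of x is (3, 4), but the shape of ix is (4,). fn receives one row at a time with the leading axis removed, so every operation inside fn works on a tensor of shape (4,). If an operation needs a different shape, reshape ix first.

2. The shape of the result xx is (3, 4). map_fn stacks the outputs of fn along a new axis 0, so the result shape is the number of elements (3) followed by the shape of fn's output (4,).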
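Here is a minimal sketch that illustrates both points, assuming TensorFlow 2.x eager execution. The weight tensor `w` and the function `row_matmul` are hypothetical. tf.matmul expects rank-2 inputs, so each (4,) row is expanded to (1, 4) first. Because fn returns a scalar, the stacked result has shape (3,) rather than (3, 4).

```python
import tensorflow as tf

x = tf.constant([[1, 2, 2, 1], [2, 1, 3, 4], [4, 3, 1, 1]], dtype=tf.int32)
w = tf.constant([[1], [0], [1], [0]], dtype=tf.int32)  # hypothetical weights, shape (4, 1)

def row_matmul(ix):
    # ix has shape (4,): map_fn removes the leading axis of x.
    row = tf.expand_dims(ix, axis=0)  # (1, 4), since tf.matmul needs rank-2 inputs
    out = tf.matmul(row, w)           # (1, 1)
    return tf.reshape(out, [])        # scalar

result = tf.map_fn(row_matmul, x)  # one scalar per row -> shape (3,)
print(result.numpy())  # [3 5 5]
```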

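The operation multiplies each row of x element-wise by z. For example, the first row gives [1, 2, 2, 1] * [1, 2, 2, 1] = [1, 4, 4, 1], and the second gives [2, 1, 3, 4] * [1, 2, 2, 1] = [2, 2, 6, 4].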

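You can also pass a lambda instead of a named function:

```python
xx = tf.map_fn(lambda ix: tf.multiply(ix, z), x)
```

or wrap the existing function, which is equivalent to passing integrate directly:

```python
xx = tf.map_fn(lambda ix: integrate(ix), x)
```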
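A lambda is also handy when fn returns a different dtype than elems. In that case map_fn needs the dtype argument from the syntax above. Here is a minimal sketch, assuming TensorFlow 2.x eager execution, that computes each row's mean as float32. In newer 2.x releases dtype is deprecated in favor of fn_output_signature but is still accepted.

```python
import tensorflow as tf

x = tf.constant([[1, 2, 2, 1], [2, 1, 3, 4], [4, 3, 1, 1]], dtype=tf.int32)

# fn returns float32 while elems is int32, so the output dtype must be given.
row_means = tf.map_fn(
    lambda ix: tf.reduce_mean(tf.cast(ix, tf.float32)),
    x,
    dtype=tf.float32,
)
print(row_means.numpy())  # row means: 1.5, 2.5, 2.25
```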

To understand lambda, you can read:

Understand Python Lambda Function for Beginners – Python Tutorial