TensorFlow's tf.map_fn() calls a function on each element of a tensor along axis 0. In this tutorial, we will use some simple examples to help you understand and use this function.

## Syntax

    tf.map_fn(
        fn,
        elems,
        dtype=None,
        parallel_iterations=10,
        back_prop=True,
        swap_memory=False,
        infer_shape=True,
        name=None
    )

## Parameters explained

fn: the function to be called. It receives each element of elems along axis 0 as its argument.

elems: a tensor whose elements along axis 0 are passed to fn one at a time.

back_prop: whether to support back propagation through the mapped function, which matters when building a deep learning model.
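The signature also lists dtype, which is not described above. According to the TensorFlow documentation, map_fn expects fn to return the same type as elems unless dtype says otherwise. The following is a minimal sketch of that case and is not part of the original example. It assumes a TensorFlow installation that provides tensorflow.compat.v1, so that session-style code also runs on TensorFlow 2.x, and the names elems and halved are chosen only for illustration.

```python
# Assumption: tensorflow.compat.v1 is available (late TensorFlow 1.x or 2.x).
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # build a graph and run it in a session

# Illustrative int32 input with three rows.
elems = tf.constant([[1, 2, 2, 1], [2, 1, 3, 4], [4, 3, 1, 1]], dtype=tf.int32)

# fn returns float32 while elems is int32, so the output type is
# declared explicitly through dtype.
halved = tf.map_fn(lambda ix: tf.cast(ix, tf.float32) / 2.0, elems, dtype=tf.float32)

with tf.Session() as sess:
    result = sess.run(halved)

print(result.dtype)  # float32
print(result)
```

On recent TensorFlow 2.x releases this call may print a deprecation warning that points to fn_output_signature as the newer name for the same idea.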

The following sections walk through a simple example step by step.

## Create two tensors

    import tensorflow as tf
    import numpy as np

    x = tf.Variable(np.array([[1, 2, 2, 1], [2, 1, 3, 4], [4, 3, 1, 1]]), dtype=tf.int32)
    z = tf.Variable(np.array([1, 2, 2, 1]), dtype=tf.int32)

We have created two tensors, x and z. We will multiply each element of x along axis 0 (that is, each row) by z.

## Use tf.map_fn

    def integrate(ix):
        print(ix)
        x1 = tf.multiply(ix, z)
        return x1

    xx = tf.map_fn(integrate, x)
    print(xx)

Here we have created a function named integrate, which multiplies ix by z. ix is one element of x along axis 0.

## Print the result

    init = tf.global_variables_initializer()
    init_local = tf.local_variables_initializer()
    with tf.Session() as sess:
        sess.run([init, init_local])
        print(sess.run(xx))

Running this code gives the following result:

    Tensor("map/while/TensorArrayReadV3:0", shape=(4,), dtype=int32)
    Tensor("map/TensorArrayStack/TensorArrayGatherV3:0", shape=(3, 4), dtype=int32)
    [[1 4 4 1]
     [2 2 6 4]
     [4 6 2 1]]

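From the result, we can see:

1. The shape of x is (3, 4), but the shape of ix is (4,). fn receives one row at a time with the leading axis removed, so keep this shape in mind (and reshape ix if needed) when operating on it.

2. The shape of the result xx is (3, 4), because the outputs of fn are stacked back along axis 0. You should pay attention to this shape as well.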

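The operation works as follows: each row of x is multiplied element-wise by z, and the three results are stacked to form xx.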
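To make the row-by-row behaviour concrete, here is a minimal NumPy sketch of the same computation. It only imitates what tf.map_fn does conceptually and is not how TensorFlow implements it. The helper name map_fn_sketch and the list row_shapes are hypothetical and introduced just for this illustration.

```python
import numpy as np

def map_fn_sketch(fn, elems):
    """Hypothetical NumPy stand-in for the behaviour of tf.map_fn."""
    # Iterating over an array yields its slices along axis 0;
    # the outputs of fn are then stacked back along axis 0.
    return np.stack([fn(elem) for elem in elems])

x = np.array([[1, 2, 2, 1], [2, 1, 3, 4], [4, 3, 1, 1]])
z = np.array([1, 2, 2, 1])

row_shapes = []

def integrate(ix):
    row_shapes.append(ix.shape)  # record the shape seen by each call
    return ix * z                # element-wise multiply, like tf.multiply

xx = map_fn_sketch(integrate, x)

print(row_shapes)  # [(4,), (4,), (4,)]
print(xx.shape)    # (3, 4)
print(xx)

# For this particular function, plain broadcasting gives the same values.
print(np.array_equal(xx, x * z))  # True
```

The last line suggests that a mapped function is not strictly needed for a simple element-wise product like this one. Mapping becomes more useful when the per-row computation cannot be written as a single broadcast operation.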

If you prefer to use a lambda, you can also write:

    xx = tf.map_fn(lambda ix: tf.multiply(ix, z), x)

or

    xx = tf.map_fn(lambda ix: integrate(ix), x)

To understand lambda, you can read:

Understand Python Lambda Function for Beginners – Python Tutorial