# tf.map_fn(): Processing Multiple Input and Output Tensors – TensorFlow Tutorial

January 25, 2022

The TensorFlow tf.map_fn() method applies a function to each element of a tensor on axis = 0 and returns the results stacked into a tensor. Here is a tutorial:

Understand TensorFlow tf.map_fn(): A Beginner Guide – TensorFlow Tutorial

However, this function can also process multiple input tensors and return multiple tensors. In this tutorial, we will discuss this topic.

## How to process multiple input tensors in tf.map_fn()

To make tf.map_fn() process multiple input tensors, we should pack them in a tuple.

For example:

# TensorFlow 1.x API (tf.Session)
import tensorflow as tf
import numpy as np

# Two 2x2 float32 input tensors
data = np.array([[1, 2], [4, 5]], dtype=np.float32)
v1 = tf.convert_to_tensor(data, dtype=tf.float32)

data = np.array([[2, 2], [4, 4]], dtype=np.float32)
v2 = tf.convert_to_tensor(data, dtype=tf.float32)

def xa(x):
    # x[0]: one row of v1
    print(x[0])
    # x[1]: the matching row of v2
    print(x[1])
    return x[0] + x[1]

# Pack the inputs in a tuple; dtype describes the single output tensor
t = tf.map_fn(xa, (v1, v2), dtype=tf.float32)

init = tf.global_variables_initializer()
init_local = tf.local_variables_initializer()
with tf.Session() as sess:
    sess.run([init, init_local])
    np.set_printoptions(precision=4, suppress=True)
    _w = sess.run(t)
    print(_w)

In this example, we pass two tensors, v1 and v2, into tf.map_fn(). All tensors passed this way must have the same size on axis = 0.

We pack them into the tuple (v1, v2), and tf.map_fn() calls the xa() function once for each index on axis = 0.

Inside xa(), x is a tuple: x[0] is the current row of v1 and x[1] is the matching row of v2. It means:

(v1, v2)
x[0], x[1]

Running this code, we will get:

Tensor("map/while/TensorArrayReadV3:0", shape=(2,), dtype=float32)
[[3. 4.]
 [8. 9.]]

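If you get ValueError: The two structures don’t have the same nested structure, the dtype argument usually does not match the structure of the value returned by the function, for example when dtype is omitted while the input is a tuple and the function returns a single tensor. You can read this tutorial to fix it: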

Fix tf.map_fn() ValueError: The two structures don’t have the same nested structure – TensorFlow Tutorial
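
The way x[0] and x[1] pair up the rows of v1 and v2 can be mimicked in plain Python. The sketch below is only a model of the behavior, not TensorFlow's implementation: map_fn_sketch is a hypothetical helper name, and nested lists stand in for tensors.

```python
# Plain-Python model of tf.map_fn(fn, (v1, v2)).
# Assumption: nested lists stand in for tensors; this is an illustration only.

def map_fn_sketch(fn, elems):
    """Call fn once per index on axis 0, passing a tuple of slices."""
    # zip(*elems) pairs row i of every input: (v1[i], v2[i])
    return [fn(rows) for rows in zip(*elems)]

def xa(x):
    # x[0] is one row of v1, x[1] is the matching row of v2
    return [a + b for a, b in zip(x[0], x[1])]

v1 = [[1.0, 2.0], [4.0, 5.0]]
v2 = [[2.0, 2.0], [4.0, 4.0]]

result = map_fn_sketch(xa, (v1, v2))
print(result)  # [[3.0, 4.0], [8.0, 9.0]]
```

The function never sees the whole tuple of tensors at once. It only receives one slice from each input per call, which is why x[0] has shape (2,) in the printed output of the TensorFlow example.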

## How to return multiple tensors in tf.map_fn()

tf.map_fn() can also return multiple tensors. For example:

# TensorFlow 1.x API (tf.Session)
import tensorflow as tf
import numpy as np

# Two 2x2 float32 input tensors
data = np.array([[1, 2], [4, 5]], dtype=np.float32)
v1 = tf.convert_to_tensor(data, dtype=tf.float32)

data = np.array([[2, 2], [4, 4]], dtype=np.float32)
v2 = tf.convert_to_tensor(data, dtype=tf.float32)

def xa(x):
    # x[0]: one row of v1
    print(x[0])
    # x[1]: the matching row of v2
    print(x[1])
    # Return three tensors: element-wise sum, sum of the v1 row, element-wise product
    return x[0] + x[1], tf.reduce_sum(x[0]), x[0] * x[1]

# dtype must have the same structure as the value returned by xa()
t1, t2, t3 = tf.map_fn(xa, (v1, v2), dtype=(tf.float32, tf.float32, tf.float32))

init = tf.global_variables_initializer()
init_local = tf.local_variables_initializer()
with tf.Session() as sess:
    sess.run([init, init_local])
    np.set_printoptions(precision=4, suppress=True)
    _w = sess.run([t1, t2, t3])
    print(_w)

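In this example, xa() returns three tensors for each row, and tf.map_fn() stacks them into t1, t2 and t3. Because the return value is a tuple of three tensors, we should set dtype = (tf.float32, tf.float32, tf.float32), one entry for each output tensor.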

Running this code, we will get:

[array([[3., 4.],
       [8., 9.]], dtype=float32), array([3., 9.], dtype=float32), array([[ 2.,  4.],
       [16., 20.]], dtype=float32)]
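
To see why the result comes back as three separate tensors, it may help to model the stacking step in plain Python. This is a minimal sketch under the assumption that nested lists stand in for tensors; map_fn_multi_sketch is a hypothetical helper, not part of TensorFlow.

```python
# Plain-Python model of tf.map_fn() with a function that returns a tuple.
# Assumption: nested lists stand in for tensors; this is an illustration only.

def map_fn_multi_sketch(fn, elems):
    """Apply fn per row, then regroup the per-row tuples by output position."""
    per_row = [fn(rows) for rows in zip(*elems)]
    # per_row is [(a0, b0, c0), (a1, b1, c1)]; zip(*per_row) regroups it
    # into ([a0, a1], [b0, b1], [c0, c1]), one stacked list per output.
    return tuple(list(column) for column in zip(*per_row))

def xa(x):
    row_sum = [a + b for a, b in zip(x[0], x[1])]      # like x[0] + x[1]
    total = sum(x[0])                                  # like tf.reduce_sum(x[0])
    product = [a * b for a, b in zip(x[0], x[1])]      # like x[0] * x[1]
    return row_sum, total, product

v1 = [[1.0, 2.0], [4.0, 5.0]]
v2 = [[2.0, 2.0], [4.0, 4.0]]

t1, t2, t3 = map_fn_multi_sketch(xa, (v1, v2))
print(t1)  # [[3.0, 4.0], [8.0, 9.0]]
print(t2)  # [3.0, 9.0]
print(t3)  # [[2.0, 4.0], [16.0, 20.0]]
```

Each output position is stacked on its own, so the three results can have different shapes. This is also why the dtype tuple in the TensorFlow example needs one entry per returned tensor.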