How to Use tf.tensordot(): A Complete Guide – TensorFlow Tutorial

By | April 1, 2021

TensorFlow tf.tensordot() is a powerful function for multiplying tensors. It sums the products of elements over chosen axes of two tensors, and the two tensors may have different ranks. In this tutorial, we will use some examples to show you how to use this function.

tf.tensordot()

It is defined as:

tf.tensordot(
    a,
    b,
    axes,
    name=None
)

Here we should notice:

  • axes can be an integer, a list or tuple of two axis lists, or an int32 tensor of shape [2, k].
  • axes can be divided into [a_axes, b_axes], which means axes = [a_axes, b_axes].
  • a_axes selects the axes of tensor a to sum over, and b_axes selects the matching axes of tensor b.

For example:

y = tf.tensordot(a, b, 1)
y = tf.tensordot(a, b, [1,1])
y = tf.tensordot(a, b, [(1, 2), (2, 1)])

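However, if axes is an integer, it must not be negative. In the TensorFlow version examined in this post, it must be at least 1; current TensorFlow releases also accept axes=0, which computes the outer product of a and b.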
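The rules above can be collected into a small shape calculator. This is a minimal sketch in plain Python, not TensorFlow code: tensordot_shape is a hypothetical helper, it assumes static shapes and non-negative axis indices, and it follows the documented rule that an integer N pairs the last N axes of a with the first N axes of b.

```python
def tensordot_shape(shape_a, shape_b, axes):
    """Result shape of tensordot(a, b, axes) for static shapes.

    Hypothetical helper (not a TensorFlow API); assumes non-negative axes.
    """
    if isinstance(axes, int):
        if axes < 0 or axes > len(shape_a) or axes > len(shape_b):
            raise ValueError("integer axes must be between 0 and the ranks of a and b")
        # An integer N pairs the last N axes of a with the first N axes of b.
        a_axes = list(range(len(shape_a) - axes, len(shape_a)))
        b_axes = list(range(axes))
    else:
        a_axes, b_axes = axes
        # A pair of plain integers means one axis on each side: [1, 1] -> [[1], [1]].
        if isinstance(a_axes, int) and isinstance(b_axes, int):
            a_axes, b_axes = [a_axes], [b_axes]
        a_axes, b_axes = list(a_axes), list(b_axes)
    if len(a_axes) != len(b_axes):
        raise ValueError("a_axes and b_axes must have the same length")
    # Each contracted pair of axes must have the same size.
    for i, j in zip(a_axes, b_axes):
        if shape_a[i] != shape_b[j]:
            raise ValueError(f"contracted sizes differ: {shape_a[i]} vs {shape_b[j]}")
    # The result keeps the free (non-contracted) dims of a, then those of b.
    a_free = [d for i, d in enumerate(shape_a) if i not in a_axes]
    b_free = [d for j, d in enumerate(shape_b) if j not in b_axes]
    return a_free + b_free


print(tensordot_shape([5, 4, 2, 3], [3, 2, 6], 1))                 # [5, 4, 2, 2, 6]
print(tensordot_shape([5, 4, 2, 3], [3, 5, 2], [[0, 2], [1, 2]]))  # [4, 3, 3]
```

The per-pair size check reflects the definition of a tensor contraction; in the examples later in this post, TensorFlow reports the mismatch from its internal MatMul step instead.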

How to multiply tensors in tf.tensordot()?

We will use some examples to explain.

If axes is an integer

Look at this example:

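import tensorflow as tf  # TensorFlow 1.x API (tf.Session)
a = tf.ones(shape=[5, 4, 2, 3])
b = tf.ones(shape=[3, 2, 6])
# axes=1: contract the last axis of a (size 3) with the first axis of b (size 3)
c = tf.tensordot(a, b, axes=1)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(tf.shape(c)))  # [5 4 2 2 6]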

Here axes = 1 is an integer.
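Numerically, axes=1 contracts the last axis of a with the first axis of b: each output element c[i, j, k, l, m] is the sum over t of a[i, j, k, t] * b[t, l, m]. The sketch below spells out that definition with explicit loops. It is a hypothetical reference helper for illustration, not TensorFlow code, and it assumes tensors are stored as flat row-major Python lists plus a shape.

```python
from itertools import product
from math import prod


def flat_index(index, shape):
    """Row-major offset of a multi-index inside a tensor of the given shape."""
    offset = 0
    for i, dim in zip(index, shape):
        offset = offset * dim + i
    return offset


def tensordot_last_first(a, shape_a, b, shape_b):
    """Contract the last axis of a with the first axis of b (axes=1)."""
    if shape_a[-1] != shape_b[0]:
        raise ValueError("last dim of a must equal first dim of b")
    out_shape = shape_a[:-1] + shape_b[1:]
    out = []
    # Visit output indices in row-major order.
    for idx in product(*(range(d) for d in out_shape)):
        a_idx, b_idx = idx[: len(shape_a) - 1], idx[len(shape_a) - 1 :]
        total = 0
        for t in range(shape_a[-1]):  # sum over the contracted axis
            total += a[flat_index(a_idx + (t,), shape_a)] * b[flat_index((t,) + b_idx, shape_b)]
        out.append(total)
    return out, out_shape


# Same shapes as the TensorFlow example above, filled with ones.
shape_a, shape_b = [5, 4, 2, 3], [3, 2, 6]
a = [1.0] * prod(shape_a)
b = [1.0] * prod(shape_b)
c, shape_c = tensordot_last_first(a, shape_a, b, shape_b)
print(shape_c)  # [5, 4, 2, 2, 6]
print(c[0])     # 3.0, since each entry sums three products 1 * 1
```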

How do we get the shape of the result c?

Look at the tf.tensordot() source code:

https://github.com/tensorflow/tensorflow/blob/23c218785eac5bfe737eec4f8081fd0ef8e0684d/tensorflow/python/ops/math_ops.py#L2899

The shape of tensor c is computed from a_axes and b_axes.

Because axes = 1 is an integer, the source code computes the shape of tensor c as:

product = array_ops.reshape(
          ab_matmul, array_ops.concat([a_free_dims, b_free_dims], 0), name=name)

[Figure: tf.tensordot() source code, free dims of tensors a and b]

How to compute a_free_dims and b_free_dims?

Since axes = 1 is an integer, the source code splits it into a_axes and b_axes:

a_axes, b_axes = _tensordot_axes(a, axes)

Here is how the source code computes them for an integer axes:

[Figure: tf.tensordot() source code for getting the axes of tensors a and b]

In this example:

The shape of a is [5, 4, 2, 3]

rank_a = 4

a_axes = range(4-1,4) = [3]

b_axes = range(1) = [0]
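The integer case boils down to two ranges. Below is a minimal sketch of that split in plain Python; integer_axes is a hypothetical helper, not TensorFlow code, and its range checks are assumptions added for safety. Only the rank of a is needed to compute the axes.

```python
def integer_axes(rank_a, n):
    """Split an integer axes value n into (a_axes, b_axes).

    The last n axes of a are paired with the first n axes of b.
    """
    if n < 0:
        raise ValueError("axes must be non-negative")
    if n > rank_a:
        raise ValueError("axes must not exceed the rank of a")
    a_axes = list(range(rank_a - n, rank_a))  # e.g. range(4 - 1, 4) -> [3]
    b_axes = list(range(n))                   # e.g. range(1) -> [0]
    return a_axes, b_axes


print(integer_axes(4, 1))  # ([3], [0])
print(integer_axes(4, 2))  # ([2, 3], [0, 1])
```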

To get a_free_dims and b_free_dims, look at the source code:

[Figure: tf.tensordot() source code computing the free dims when axes is an integer]

In this example:

The shape of a is [5, 4, 2, 3], and the shape of b is [3, 2, 6].

rank_a = 4, rank_b = 3

a_axes = [3]

a_free = setdiff1d([0, 1, 2, 3],[3]) = [0, 1, 2]

a_free_dims = [5, 4, 2]

b_free = setdiff1d([0, 1, 2], [0]) = [1,2]

b_free_dims = [2, 6]

The shape of c = concat([5,  4,  2], [2,  6]) = [5, 4, 2, 2, 6]
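The set-difference step can be checked with a few lines of plain Python. In this sketch, free_dims is a hypothetical helper; a list comprehension plays the role of setdiff1d(range(rank), axes), which keeps the remaining axes in their original order.

```python
def free_dims(shape, axes):
    """Return (free axes, free dims): the axes of `shape` that are not contracted."""
    # Same result as setdiff1d(range(len(shape)), axes): order is preserved.
    free = [i for i in range(len(shape)) if i not in axes]
    return free, [shape[i] for i in free]


a_free, a_free_dims = free_dims([5, 4, 2, 3], [3])
b_free, b_free_dims = free_dims([3, 2, 6], [0])
print(a_free, a_free_dims)        # [0, 1, 2] [5, 4, 2]
print(b_free, b_free_dims)        # [1, 2] [2, 6]
print(a_free_dims + b_free_dims)  # [5, 4, 2, 2, 6], the shape of c
```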

If axes = [a_axes, b_axes]

For example:

axes = [1, 1]

axes will be converted to:

axes = [[1], [1]]

Here is the source code:

      a_axes = axes[0]
      b_axes = axes[1]
      if isinstance(a_axes, compat.integral_types) and \
          isinstance(b_axes, compat.integral_types):
        a_axes = [a_axes]
        b_axes = [b_axes]

After this step, a_axes and b_axes are both lists.

We must make sure len(a_axes) = len(b_axes).

Here is the source code:

[Figure: tf.tensordot() source code checking a_axes and b_axes when axes is a list]
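The list branch can be sketched in plain Python as follows. list_axes is a hypothetical helper that mirrors the source snippet shown earlier (wrapping a pair of integers into one-element lists) plus the length check; converting both sides to lists is an assumption made so that tuples and lists behave the same.

```python
def list_axes(axes):
    """Normalize axes = [a_axes, b_axes] into two equal-length lists."""
    if len(axes) != 2:
        raise ValueError("axes must have length 2")
    a_axes, b_axes = axes
    # A pair of plain integers becomes one-element lists: [1, 1] -> [[1], [1]].
    if isinstance(a_axes, int) and isinstance(b_axes, int):
        a_axes, b_axes = [a_axes], [b_axes]
    if len(a_axes) != len(b_axes):
        raise ValueError(f"different number of contraction axes: {len(a_axes)} != {len(b_axes)}")
    return list(a_axes), list(b_axes)


print(list_axes([1, 1]))            # ([1], [1])
print(list_axes([(1, 2), (2, 1)]))  # ([1, 2], [2, 1])
```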

How to compute tensor c when axes is list?

Unlike the integer case, when axes is a list the shape of tensor c is:

a_free_dims + b_free_dims

[Figure: tf.tensordot() source code computing the shape of tensor c when axes is a list]

In order to compute c, there are three main steps.

Look at the source code:

[Figure: tf.tensordot() source code computing the free dims and new shapes of tensors a and b when axes is a list]

Step 1:

tensor a shape: [5, 4, 2, 3]

a_axes = [1]

a_free = [i for i in xrange(4) if i not in [1]] = [0, 2, 3]

a_free_dims = [5, 2, 3]

Similarly, for tensor b with shape [3, 2, 6] and b_axes = [1], b_free_dims = [3, 6].

Step 2:

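Tensor a is transposed so that its free axes come first, then reshaped to new_shape (5*2*3) * 4 = 30 * 4.

Tensor b is transposed so that its contracted axis comes first, then reshaped to new_shape 2 * (3*6) = 2 * 18.

Because 2 ≠ 4, the matrix multiplication of the two reshaped tensors fails.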
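The reshape in Step 2 only depends on the shapes, so it can be sketched without TensorFlow. reshape_for_matmul is a hypothetical helper: a (flipped=False) becomes [product of free dims, product of contracted dims], and b (flipped=True) becomes [product of contracted dims, product of free dims].

```python
from math import prod


def reshape_for_matmul(shape, axes, flipped=False):
    """Return (new_shape, free_dims) for one operand of the inner matmul."""
    free = [i for i in range(len(shape)) if i not in axes]
    free_dims = [shape[i] for i in free]
    prod_free = prod(free_dims)
    prod_axes = prod(shape[i] for i in axes)
    # a keeps free dims as rows; b (flipped) keeps contracted dims as rows.
    new_shape = [prod_axes, prod_free] if flipped else [prod_free, prod_axes]
    return new_shape, free_dims


a_2d, _ = reshape_for_matmul([5, 4, 2, 3], [1])
b_2d, _ = reshape_for_matmul([3, 2, 6], [1], flipped=True)
print(a_2d, b_2d)          # [30, 4] [2, 18]
print(a_2d[1] == b_2d[0])  # False: 4 != 2, so the matmul cannot run
```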

For example:

import tensorflow as tf  # TensorFlow 1.x API (tf.Session)
a = tf.ones(shape=[5, 4, 2, 3])
r = tf.rank(a)
b = tf.ones(shape=[3, 2, 6])
# Fails: axis 1 of a has size 4, but axis 1 of b has size 2
c = tf.tensordot(a, b, axes=[1, 1])
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(r))
    print(sess.run(tf.shape(c)))

Running this code raises an error:

ValueError: Dimensions must be equal, but are 4 and 2 for ‘Tensordot/MatMul’ (op: ‘MatMul’) with input shapes: [30,4], [2,18].

If we change b to:

b = tf.ones(shape=[3,4,6])

the inner dimensions match (4 and 4), and the matrix multiplication produces a result with shape 30 * 18.

Step 3:

We should reshape the 30 * 18 result into tensor c.

The shape of tensor c is:

a_free_dims + b_free_dims

In this example, it is:

[5, 2, 3] + [3, 6] = [5, 2, 3, 3, 6]
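Step 3 works because the matmul rows enumerate the free dims of a in row-major order and the columns enumerate the free dims of b, so a plain reshape puts every value in the right place. The sketch below checks this for the shapes in this example; unravel is a hypothetical helper written for illustration.

```python
from math import prod


def unravel(flat, dims):
    """Row-major multi-index of `flat` within `dims` (inverse of flattening)."""
    index = []
    for dim in reversed(dims):
        flat, rem = divmod(flat, dim)
        index.append(rem)
    return tuple(reversed(index))


a_free_dims, b_free_dims = [5, 2, 3], [3, 6]
matmul_shape = [prod(a_free_dims), prod(b_free_dims)]  # [30, 18]
final_shape = a_free_dims + b_free_dims                # [5, 2, 3, 3, 6]

# Reshaping keeps the element count: 30 * 18 == 5 * 2 * 3 * 3 * 6.
assert prod(matmul_shape) == prod(final_shape)

# Entry (row, col) of the matmul result lands at unravel(row) + unravel(col) in c.
row, col = 17, 5
print(unravel(row, a_free_dims) + unravel(col, b_free_dims))  # (2, 1, 2, 0, 5)
```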

Let's look at more examples:

Example 1:

a = tf.ones(shape=[5,4,2,3])
b = tf.ones(shape=[3,4,6])
c = tf.tensordot(a,b, axes=[[0, 2],[1, 2]])

Here a_axes = [0, 2], so tensor a is converted to shape (4*3) * (5*2) = 12 * 10.

b_axes = [1, 2], so tensor b is converted to shape (4*6) * 3 = 24 * 3.

Because 10 ≠ 24, it reports an error:

ValueError: Dimensions must be equal, but are 10 and 24 for ‘Tensordot/MatMul’ (op: ‘MatMul’) with input shapes: [12,10], [24,3].

Example 2:

a = tf.ones(shape=[5,4,2,3])
b = tf.ones(shape=[3,5,2])
c = tf.tensordot(a,b, axes=[[0, 2],[1, 2]])

Here a is converted to 12 * 10 and b is converted to 10 * 3, so the matrix multiplication produces a result with shape 12 * 3.

The free dims of a are [4, 3], and the free dims of b are [3].

So tensor c is reshaped to [4, 3, 3].
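Putting the three steps together, the sketch below reimplements the transpose, reshape, matmul and reshape strategy in plain Python so that Example 2 can be checked without TensorFlow. It is a hypothetical illustration, not TensorFlow's code, under these assumptions: tensors are flat row-major lists plus a shape, axis indices are non-negative, and, like the MatMul step, it only compares the products of the contracted dimensions.

```python
from itertools import product
from math import prod


def strides(shape):
    """Row-major strides, e.g. [5, 4, 2, 3] -> [24, 6, 3, 1]."""
    out, step = [], 1
    for dim in reversed(shape):
        out.append(step)
        step *= dim
    return out[::-1]


def to_matrix(data, shape, row_axes, col_axes):
    """View a flat tensor as a 2-D list (the transpose + reshape step).

    Rows enumerate row_axes and columns enumerate col_axes, both in row-major order.
    """
    st = strides(shape)

    def offsets(axes):
        ranges = (range(shape[ax]) for ax in axes)
        return [sum(i * st[ax] for i, ax in zip(idx, axes)) for idx in product(*ranges)]

    cols = offsets(col_axes)
    return [[data[r + c] for c in cols] for r in offsets(row_axes)]


def tensordot(a, shape_a, b, shape_b, a_axes, b_axes):
    """Hypothetical sketch: tensordot as transpose + reshape + matmul + reshape."""
    inner_a = prod(shape_a[i] for i in a_axes)
    inner_b = prod(shape_b[i] for i in b_axes)
    if inner_a != inner_b:
        # Same condition that makes the MatMul step fail in Example 1.
        raise ValueError(f"Dimensions must be equal, but are {inner_a} and {inner_b}")
    a_free = [i for i in range(len(shape_a)) if i not in a_axes]
    b_free = [i for i in range(len(shape_b)) if i not in b_axes]
    mat_a = to_matrix(a, shape_a, a_free, a_axes)  # [prod(free), prod(contracted)]
    mat_b = to_matrix(b, shape_b, b_axes, b_free)  # [prod(contracted), prod(free)]
    # Plain matrix multiplication, flattened in row-major order.
    flat = [sum(x * y for x, y in zip(row, col)) for row in mat_a for col in zip(*mat_b)]
    out_shape = [shape_a[i] for i in a_free] + [shape_b[i] for i in b_free]
    return flat, out_shape


# Example 2: the contracted dims give 5 * 2 = 10 on both sides.
shape_a, shape_b = [5, 4, 2, 3], [3, 5, 2]
a = [1.0] * prod(shape_a)
b = [1.0] * prod(shape_b)
c, shape_c = tensordot(a, shape_a, b, shape_b, [0, 2], [1, 2])
print(shape_c)  # [4, 3, 3]
print(c[0])     # 10.0
```

Calling it with Example 1's shapes (b with shape [3, 4, 6]) raises a ValueError instead, because the contracted sizes are 10 and 24.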