TensorFlow's tf.tensordot() multiplies tensors by contracting them along the axes you specify, so it can multiply tensors of different ranks. In this tutorial, we will use some examples to show you how to use this function.

## tf.tensordot()

It is defined as:

tf.tensordot( a, b, axes, name=None )

Here we should notice:

- axes can be an integer, or a list or tuple of the form [a_axes, b_axes].
- In the list form, a_axes and b_axes can each be a single integer or a list of integers.
- a_axes selects the axes of tensor a to contract, and b_axes selects the axes of tensor b to contract.

For example:

    y = tf.tensordot(a, b, 1)
    y = tf.tensordot(a, b, [1, 1])
    y = tf.tensordot(a, b, [(1, 2), (2, 1)])

However, *if axes is an integer, it must not be negative*. With axes = 0, no axes are contracted and the result is the outer product of a and b.

## How to multiply tensors in tf.tensordot()?

We will use some examples to explain.

## If axes is an integer

Look at this example, which uses the TensorFlow 1.x session API:

    import tensorflow as tf

    a = tf.ones(shape=[5, 4, 2, 3])
    b = tf.ones(shape=[3, 2, 6])
    c = tf.tensordot(a, b, axes=1)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # print(sess.run(c))
        print(sess.run(tf.shape(c)))

Here axes = 1, which is an integer.

## How to get the shape of result c?

Look at tf.tensordot() source code:

https://github.com/tensorflow/tensorflow/blob/23c218785eac5bfe737eec4f8081fd0ef8e0684d/tensorflow/python/ops/math_ops.py#L2899

The shape of tensor c is computed based on a_axes and b_axes.

Because axes = 1 is an integer, the source code shows that the shape of tensor c comes from this reshape:

    product = array_ops.reshape(
        ab_matmul, array_ops.concat([a_free_dims, b_free_dims], 0), name=name)

## How to compute a_free_dims and b_free_dims?

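Here axes = 1, which is an integer. Look at the source code:

    a_axes, b_axes = _tensordot_axes(a, axes)

For an integer axes, a_axes is the last axes dimensions of a, that is range(rank_a - axes, rank_a), and b_axes is the first axes dimensions of b, that is range(axes).

In this example: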

The shape of a is [5, 4, 2, 3]

rank_a = 4

a_axes = range(4-1,4) = [3]

b_axes = range(1) = [0]
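The index arithmetic above can be sketched in plain Python. The helper name `integer_axes_to_lists` is hypothetical; it only mirrors the two range() expressions shown here and assumes a valid integer axes, so it is not TensorFlow's actual implementation.

```python
def integer_axes_to_lists(rank_a, axes):
    """Expand an integer `axes` into explicit axis lists (illustrative helper)."""
    a_axes = list(range(rank_a - axes, rank_a))  # the last `axes` axes of a
    b_axes = list(range(axes))                   # the first `axes` axes of b
    return a_axes, b_axes


# a has shape [5, 4, 2, 3], so rank_a = 4
print(integer_axes_to_lists(4, 1))  # ([3], [0])
print(integer_axes_to_lists(4, 2))  # ([2, 3], [0, 1])
```

With axes = 2, the last two axes of a would be paired with the first two axes of b.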

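To get a_free_dims and b_free_dims, the source code removes the contracted axes from the list of all axes. The axes that remain are the free axes, and their sizes are the free dims.

In this example:

The shape of a is [5, 4, 2, 3] and the shape of b is [3, 2, 6]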

rank_a = 4, rank_b = 3

a_axes = [3]

a_free = setdiff1d([0, 1, 2, 3],[3]) = [0, 1, 2]

a_free_dims = [5, 4, 2]

b_free = setdiff1d([0, 1, 2], [0]) = [1,2]

b_free_dims = [2, 6]

The shape of c = concat([5, 4, 2], [2, 6]) = [5, 4, 2, 2, 6]
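The same bookkeeping can be reproduced with a few lines of plain Python. This is a minimal sketch of the shape calculation only; `free_dims` is a hypothetical helper name, and no tensors are actually multiplied.

```python
def free_dims(shape, contracted_axes):
    """Return the sizes of the axes that are NOT contracted."""
    return [dim for i, dim in enumerate(shape) if i not in contracted_axes]


a_shape, b_shape = [5, 4, 2, 3], [3, 2, 6]
a_axes, b_axes = [3], [0]  # what axes=1 expands to when a has rank 4

a_free_dims = free_dims(a_shape, a_axes)
b_free_dims = free_dims(b_shape, b_axes)
c_shape = a_free_dims + b_free_dims  # concat of the two lists

print(a_free_dims, b_free_dims, c_shape)
# [5, 4, 2] [2, 6] [5, 4, 2, 2, 6]
```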

## If axes = [a_axes, b_axes]

For example:

axes = [1, 1]

axes will be converted to:

axes = [[1], [1]]

Here is the source code:

    a_axes = axes[0]
    b_axes = axes[1]
    if isinstance(a_axes, compat.integral_types) and \
        isinstance(b_axes, compat.integral_types):
      a_axes = [a_axes]
      b_axes = [b_axes]

Here a_axes and b_axes are lists.

We must make sure len(a_axes) = len(b_axes). The source code raises a ValueError if the two lengths differ.

## How to compute tensor c when axes is list?

When axes is a list, a_axes and b_axes are given explicitly instead of being derived from an integer. The shape of tensor c is still:

a_free_dims + b_free_dims

In the source code, there are three main steps to compute c: find the free dims, reshape both tensors to matrices and multiply them, then reshape the result.

*Step 1:*

tensor a shape: [5, 4, 2, 3]

a_axes = [1]

a_free = [i for i in xrange(4) if i not in [1]] = [0, 2, 3]

a_free_dims = [5, 2, 3]

Similarly, the shape of b is [3, 2, 6] and b_axes = [1], so b_free_dims = [3, 6]

*Step 2:*

Tensor a is reshaped to a matrix of shape (5 * 2 * 3) * 4 = 30 * 4.

Tensor b is reshaped to a matrix of shape 2 * (3 * 6) = 2 * 18.

Here 4 ≠ 2, so the two matrices cannot be multiplied.

For example:

    import tensorflow as tf

    a = tf.ones(shape=[5, 4, 2, 3])
    r = tf.rank(a)
    b = tf.ones(shape=[3, 2, 6])
    c = tf.tensordot(a, b, axes=[1, 1])

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(r))
        print(sess.run(tf.shape(c)))

Running this code raises an error:

    ValueError: Dimensions must be equal, but are 4 and 2 for 'Tensordot/MatMul' (op: 'MatMul') with input shapes: [30,4], [2,18].

If we change b to:

    b = tf.ones(shape=[3, 4, 6])

tensor b is reshaped to 4 * 18, and the matrix multiplication gives a tensor c with the shape 30 * 18.
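The matrix shapes in this step can be checked without running TensorFlow. The sketch below uses two hypothetical helpers, `matrix_shape_a` and `matrix_shape_b`, which only compute the shapes the tensors are flattened to. It assumes Python 3.8 or newer for math.prod.

```python
from math import prod


def matrix_shape_a(shape, axes):
    """Shape of a after flattening: [product of free dims, product of contracted dims]."""
    free = prod(d for i, d in enumerate(shape) if i not in axes)
    contracted = prod(shape[i] for i in axes)
    return (free, contracted)


def matrix_shape_b(shape, axes):
    """Shape of b after flattening: [product of contracted dims, product of free dims]."""
    free = prod(d for i, d in enumerate(shape) if i not in axes)
    contracted = prod(shape[i] for i in axes)
    return (contracted, free)


a_mat = matrix_shape_a([5, 4, 2, 3], [1])
b_bad = matrix_shape_b([3, 2, 6], [1])
b_ok = matrix_shape_b([3, 4, 6], [1])

print(a_mat, b_bad, a_mat[1] == b_bad[0])  # (30, 4) (2, 18) False
print(a_mat, b_ok, a_mat[1] == b_ok[0])    # (30, 4) (4, 18) True
```

The multiplication is only possible when the second dimension of the first matrix equals the first dimension of the second one.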

*Step 3:*

We should reshape tensor c.

The shape of tensor c is:

a_free_dims + b_free_dims

In this example, it is:

[5, 2, 3] + [3, 6] = [5, 2, 3, 3, 6]
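The three steps can be reproduced by hand and compared against a library result. This sketch uses NumPy instead of TensorFlow, on the assumption that np.tensordot follows the same axes convention. It illustrates the transpose, reshape and matmul idea and is not the TensorFlow source code.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.random((5, 4, 2, 3))
b = rng.random((3, 4, 6))
a_axes, b_axes = [1], [1]

# Step 1: free axes and their sizes
a_free = [i for i in range(a.ndim) if i not in a_axes]  # [0, 2, 3]
b_free = [i for i in range(b.ndim) if i not in b_axes]  # [0, 2]
a_free_dims = [a.shape[i] for i in a_free]              # [5, 2, 3]
b_free_dims = [b.shape[i] for i in b_free]              # [3, 6]

# Step 2: move contracted axes last in a and first in b,
# flatten both to matrices, then multiply them
k = int(np.prod([a.shape[i] for i in a_axes]))          # contracted size, 4
a_mat = a.transpose(a_free + a_axes).reshape(-1, k)     # shape (30, 4)
b_mat = b.transpose(b_axes + b_free).reshape(k, -1)     # shape (4, 18)
ab = a_mat @ b_mat                                      # shape (30, 18)

# Step 3: reshape the matrix product to a_free_dims + b_free_dims
c = ab.reshape(a_free_dims + b_free_dims)

print(c.shape)                                                  # (5, 2, 3, 3, 6)
print(np.allclose(c, np.tensordot(a, b, axes=([1], [1]))))      # True
```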

Look at more examples:

Example 1:

    a = tf.ones(shape=[5, 4, 2, 3])
    b = tf.ones(shape=[3, 4, 6])
    c = tf.tensordot(a, b, axes=[[0, 2], [1, 2]])

Here a_axes = [0, 2], so tensor a is reshaped to (4 * 3) * (5 * 2) = 12 * 10.

b_axes = [1, 2], so tensor b is reshaped to (4 * 6) * 3 = 24 * 3.

10 ≠ 24, so it reports an error:

    ValueError: Dimensions must be equal, but are 10 and 24 for 'Tensordot/MatMul' (op: 'MatMul') with input shapes: [12,10], [24,3].

Example 2:

    a = tf.ones(shape=[5, 4, 2, 3])
    b = tf.ones(shape=[3, 5, 2])
    c = tf.tensordot(a, b, axes=[[0, 2], [1, 2]])

Here a is reshaped to 12 * 10 and b is reshaped to 10 * 3, so we get a tensor c with the shape 12 * 3.

a_free_dims is [4, 3] and b_free_dims is [3].

Tensor c is then reshaped to [4, 3, 3].
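Both examples can be predicted from the shapes alone. The function below is a hypothetical helper, `tensordot_shape`, written for this tutorial. It mirrors only the matrix-multiplication check described above (the products of the contracted dims must match), which is an assumption and may be looser than the checks a real library performs.

```python
from math import prod


def tensordot_shape(a_shape, b_shape, a_axes, b_axes):
    """Predict the output shape of a tensordot from shapes alone (illustrative)."""
    if len(a_axes) != len(b_axes):
        raise ValueError("a_axes and b_axes must have the same length")
    a_free_dims = [d for i, d in enumerate(a_shape) if i not in a_axes]
    b_free_dims = [d for i, d in enumerate(b_shape) if i not in b_axes]
    a_inner = prod(a_shape[i] for i in a_axes)  # columns of the flattened a
    b_inner = prod(b_shape[i] for i in b_axes)  # rows of the flattened b
    if a_inner != b_inner:
        raise ValueError(
            f"Dimensions must be equal, but are {a_inner} and {b_inner}")
    return a_free_dims + b_free_dims


# Example 1: 10 != 24, so the check fails
try:
    tensordot_shape([5, 4, 2, 3], [3, 4, 6], [0, 2], [1, 2])
except ValueError as err:
    print("Example 1 fails:", err)

# Example 2: both inner sizes are 10
print(tensordot_shape([5, 4, 2, 3], [3, 5, 2], [0, 2], [1, 2]))  # [4, 3, 3]
```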