The TensorFlow tf.split() function splits a tensor into a list of sub-tensors. Here are some examples.

Split a Tensor into Sub-Tensors with tf.split()

In this tutorial, we discuss some tips on using tf.split() so that you can learn to use this function correctly.

## Syntax of tf.split()

```python
tf.split(value, num_or_size_splits, axis=0, num=None, name='split')
```

tf.split() has a few important parameters you should pay attention to.

## Important parameters

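value: the tensor you want to split.

num_or_size_splits: determines how the tensor is divided. It is often a list of sizes, such as [1, 3, 5], giving the size of each sub-tensor along axis; these sizes must add up to the size of value along that axis. It can also be an integer n, in which case value is split into n equally sized sub-tensors along axis, so the size of that axis must be divisible by n.

axis: the dimension along which the tensor is split. It defaults to 0.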
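As a minimal sketch of the two forms num_or_size_splits can take, the snippet below splits a 2 * 6 tensor both ways. It assumes TensorFlow 2.x with eager execution, so shapes print as concrete values rather than the graph-mode tensor names shown later in this tutorial. The tensor x and the names equal_parts and sized_parts are illustrative only.

```python
import tensorflow as tf

# Illustrative 2 x 6 tensor holding the values 0..11
x = tf.reshape(tf.range(12), [2, 6])

# Integer form: 3 equal pieces along axis 1 (6 is divisible by 3)
equal_parts = tf.split(x, num_or_size_splits=3, axis=1)
print([p.shape.as_list() for p in equal_parts])  # [[2, 2], [2, 2], [2, 2]]

# List form: pieces of width 1, 2 and 3 along axis 1 (1 + 2 + 3 == 6)
sized_parts = tf.split(x, num_or_size_splits=[1, 2, 3], axis=1)
print([p.shape.as_list() for p in sized_parts])  # [[2, 1], [2, 2], [2, 3]]
```

The integer form is convenient when all pieces have the same size; the list form is needed when they differ.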

## Return

tf.split() returns a Python list containing the sub-tensors.
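Because the result is a list whose length is known in advance, it can be indexed or unpacked directly. A small sketch, assuming TensorFlow 2.x with eager execution; the names pieces, top and bottom are illustrative.

```python
import tensorflow as tf

w = tf.random.uniform([2, 3, 4], -1, 1)

# The result is a Python list of tensors
pieces = tf.split(w, num_or_size_splits=[1, 1])
print(len(pieces))  # 2

# The number of pieces is fixed by num_or_size_splits, so unpacking works
top, bottom = tf.split(w, num_or_size_splits=[1, 1])
print(top.shape.as_list(), bottom.shape.as_list())  # [1, 3, 4] [1, 3, 4]
```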

Below, we use some examples to explain how to use this function correctly.

## Create a tensor with shape 2 * 3 * 4

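```python
#coding=utf-8
import tensorflow as tf
# Shape (2, 3, 4), values drawn uniformly from [-1, 1).
# tf.random.uniform replaces tf.random_uniform, which is not available in TensorFlow 2.x.
w = tf.Variable(tf.random.uniform([2, 3, 4], -1, 1))
```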

Note that this tensor has 2 elements along axis = 0, 3 elements along axis = 1 and 4 elements along axis = 2.
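To check these sizes yourself, you can read them from the tensor's shape. A minimal sketch, assuming TensorFlow 2.x, where TensorShape.as_list() returns plain Python integers.

```python
import tensorflow as tf

w = tf.Variable(tf.random.uniform([2, 3, 4], -1, 1))

# Number of elements along each axis of w
for axis, size in enumerate(w.shape.as_list()):
    print(f"axis = {axis}: {size} elements")
```

This prints 2, 3 and 4 elements for axes 0, 1 and 2, which are the totals num_or_size_splits must add up to when splitting along each axis.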

To understand the relationship between tensor axes and shape, you can refer to this tutorial:

Understand Tensor Axis and Shape with Examples: A Beginner Guide

## Split a tensor into 2 sub-tensors along axis = 0

There are only 2 elements along axis = 0, which means the sum of num_or_size_splits must be 2. Since axis defaults to 0, we do not need to pass it here.
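The sum rule can be modeled in plain Python. The function split_shapes below is a hypothetical helper written for this tutorial, not part of TensorFlow: it predicts the sub-tensor shapes for a list of non-negative sizes and rejects sizes that do not add up to the length of the split axis.

```python
def split_shapes(shape, sizes, axis=0):
    """Hypothetical helper (not a TensorFlow API): predict the shapes that
    tf.split would return for a list-valued num_or_size_splits."""
    shape = tuple(shape)
    axis = axis % len(shape)  # allow negative axes such as -1
    if sum(sizes) != shape[axis]:
        raise ValueError(
            f"sizes {sizes} sum to {sum(sizes)}, "
            f"but axis {axis} has {shape[axis]} elements"
        )
    # Each piece keeps the original shape except along the split axis
    return [shape[:axis] + (s,) + shape[axis + 1:] for s in sizes]


print(split_shapes((2, 3, 4), [1, 1]))          # [(1, 3, 4), (1, 3, 4)]
print(split_shapes((2, 3, 4), [1, 2], axis=1))  # [(2, 1, 4), (2, 2, 4)]
```

Only the size along the split axis changes; every other dimension is copied unchanged into each piece.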

```python
sub_w = tf.split(w, num_or_size_splits=[1, 1])
print(type(sub_w))
print(sub_w)
```

The result is:

```
<class 'list'>
[<tf.Tensor 'split:0' shape=(1, 3, 4) dtype=float32>, <tf.Tensor 'split:1' shape=(1, 3, 4) dtype=float32>]
```

From the result, we can see that:

1. The returned value sub_w is a Python list.

2. sub_w contains 2 tensors, and the shape of each sub-tensor is 1 * 3 * 4.
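With eager execution (assumed here: TensorFlow 2.x), the same split can be inspected directly. The sketch also concatenates the pieces back along axis 0 with tf.concat to show that the sub-tensors together hold exactly the original values; the name restored is illustrative.

```python
import tensorflow as tf

w = tf.Variable(tf.random.uniform([2, 3, 4], -1, 1))

sub_w = tf.split(w, num_or_size_splits=[1, 1])  # axis defaults to 0
print([t.shape.as_list() for t in sub_w])  # [[1, 3, 4], [1, 3, 4]]

# Concatenating the pieces along the same axis rebuilds the original values
restored = tf.concat(sub_w, axis=0)
print(bool(tf.reduce_all(tf.equal(restored, w))))  # True
```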

What happens if the sum of num_or_size_splits is not equal to 2?

```python
sub_w = tf.split(w, num_or_size_splits=[1, 2])
```

Then you will get an error: `ValueError: Sum of output sizes must match the size of the original Tensor along the split dimension`
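Under TensorFlow 2.x eager execution, the same mistake may surface as tf.errors.InvalidArgumentError instead of ValueError; the exact exception type and message depend on the version and execution mode. A defensive sketch therefore catches both; the flag failed is illustrative.

```python
import tensorflow as tf

w = tf.Variable(tf.random.uniform([2, 3, 4], -1, 1))

try:
    # 1 + 2 == 3, but axis 0 of w has only 2 elements
    tf.split(w, num_or_size_splits=[1, 2])
    failed = False
except (ValueError, tf.errors.InvalidArgumentError) as err:
    failed = True
    print(type(err).__name__)
```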

## Split a tensor into 2 sub-tensors along axis = 1

There are 3 elements along axis = 1, so we must make sure that the sum of num_or_size_splits equals 3.

```python
sub_w = tf.split(w, num_or_size_splits=[1, 2], axis=1)
```

You will get two sub-tensors: one with shape 2 * 1 * 4 and the other with shape 2 * 2 * 4.

```
[<tf.Tensor 'split:0' shape=(2, 1, 4) dtype=float32>, <tf.Tensor 'split:1' shape=(2, 2, 4) dtype=float32>]
```
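To see which elements end up in each piece, the sketch below compares the two sub-tensors with ordinary slices of w along axis 1. It assumes TensorFlow 2.x with eager execution; the names first and second are illustrative.

```python
import tensorflow as tf

w = tf.Variable(tf.random.uniform([2, 3, 4], -1, 1))

first, second = tf.split(w, num_or_size_splits=[1, 2], axis=1)
print(first.shape.as_list(), second.shape.as_list())  # [2, 1, 4] [2, 2, 4]

# The pieces hold the same data as slicing w along axis 1
print(bool(tf.reduce_all(tf.equal(first, w[:, :1, :]))))   # True
print(bool(tf.reduce_all(tf.equal(second, w[:, 1:, :]))))  # True
```

In other words, sizes [1, 2] along axis 1 correspond to the index ranges 0:1 and 1:3 on that axis.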