Understand numpy.array_split(): Split an Array into Multiple Sub-arrays – NumPy Tutorial

July 4, 2022

In this tutorial, we will introduce the numpy.array_split() function, which splits an array into multiple sub-arrays. We will use some examples to show you how to use it correctly.

numpy.array_split()

It is defined as:

numpy.array_split(ary, indices_or_sections, axis=0)

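numpy.array_split() is similar to numpy.split(). However, numpy.array_split() allows indices_or_sections to be an integer that does not equally divide the axis. In that case, an axis of length l split into n sections gives l % n sub-arrays of size l//n + 1, and the remaining sub-arrays have size l//n.

To learn how to use numpy.split(), you can view: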

Understand numpy.split(): Split an Array into Sub-Arrays – NumPy Tutorial

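Parameters explained

ary: the array we plan to split

indices_or_sections: int or 1-D array. An integer is the number of sub-arrays to create. A 1-D array of sorted integers gives the indices at which the array is split along the axis.

axis: the axis along which to split the array. The default is 0.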
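To make the three parameters more concrete, here is a minimal sketch. The variable names (arr, grid, by_count, by_index, by_column) are only for illustration and are not part of NumPy.

```python
import numpy as np

arr = np.arange(10)

# indices_or_sections as an integer: the number of sub-arrays to create.
# 10 elements do not divide evenly by 3, so the first sub-array gets one extra.
by_count = np.array_split(arr, 3)
print([p.tolist() for p in by_count])   # [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]

# indices_or_sections as a 1-D sequence: the positions where the array is cut.
by_index = np.array_split(arr, [2, 5])
print([p.tolist() for p in by_index])   # [[0, 1], [2, 3, 4], [5, 6, 7, 8, 9]]

# axis selects the dimension to split; axis=1 splits the columns of a 2-D array.
grid = arr.reshape(2, 5)
by_column = np.array_split(grid, 2, axis=1)
print([p.shape for p in by_column])     # [(2, 3), (2, 2)]
```

In every case the function returns a Python list of arrays, not a single array.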

How to use numpy.array_split()?

When indices_or_sections = integer

Here is an example:

import numpy as np

data = np.arange(12)
data = np.reshape(data, [3,4])
print(data)

sub_data = np.array_split(data, 2, axis = 0)
print(sub_data)

Here data contains 3 rows on axis = 0, and we split it into 2 sub-arrays.

Running this code, we will see:

[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]]
[array([[0, 1, 2, 3],
       [4, 5, 6, 7]]), array([[ 8,  9, 10, 11]])]

We can see that the first sub-array contains 2 rows and the second sub-array contains 1 row.

However, what happens if we use numpy.split() instead?

sub_data = np.split(data, 2, axis = 0)
print(sub_data)

Running this code raises: ValueError: array split does not result in an equal division

This is because data has only 3 rows on axis = 0 and 3 % 2 != 0, so numpy.split() cannot divide it into 2 equal parts.
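The following sketch checks the sizes that numpy.array_split() produces when the axis does not divide evenly, and confirms that numpy.split() fails on the same input. The helper expected_sizes() is a hypothetical function written for this example. It follows the rule given in the NumPy documentation: for an axis of length l and n sections, the first l % n sub-arrays have size l//n + 1 and the rest have size l//n.

```python
import numpy as np

def expected_sizes(length, sections):
    # Hypothetical helper (not part of NumPy): sizes array_split() should give.
    base, extra = divmod(length, sections)
    return [base + 1] * extra + [base] * (sections - extra)

data = np.arange(12).reshape(3, 4)

# array_split() accepts an uneven division of the 3 rows into 2 parts.
sub_data = np.array_split(data, 2, axis=0)
actual_sizes = [s.shape[0] for s in sub_data]
print(actual_sizes)          # [2, 1]
print(expected_sizes(3, 2))  # [2, 1]

# split() refuses the same request because 3 % 2 != 0.
try:
    np.split(data, 2, axis=0)
    split_failed = False
except ValueError as err:
    split_failed = True
    print("ValueError:", err)
```

If uneven sub-arrays are acceptable in your code, numpy.array_split() is the safer choice. If equal sizes are required, the ValueError from numpy.split() is a useful check.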

When indices_or_sections = list

For example:

indices_or_sections = [1, 3]
print(data)
sub_data = np.array_split(data, indices_or_sections, axis = 1)
print(sub_data)

Running this code, we will see:

[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]]
[array([[0],
       [4],
       [8]]), array([[ 1,  2],
       [ 5,  6],
       [ 9, 10]]), array([[ 3],
       [ 7],
       [11]])]

What is the result if we use numpy.split() here?

indices_or_sections = [1, 3]
print(data)
sub_data = np.split(data, indices_or_sections, axis = 1)
print(sub_data)

Running this code, we will see:

[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]]
[array([[0],
       [4],
       [8]]), array([[ 1,  2],
       [ 5,  6],
       [ 9, 10]]), array([[ 3],
       [ 7],
       [11]])]

When indices_or_sections is a list, numpy.array_split() and numpy.split() return the same result.
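Instead of comparing the printed output by eye, we can check this in code. This is a small sketch that uses numpy.array_equal() to compare each pair of sub-arrays. The names from_array_split, from_split and same are chosen only for this example.

```python
import numpy as np

data = np.arange(12).reshape(3, 4)
indices_or_sections = [1, 3]

from_array_split = np.array_split(data, indices_or_sections, axis=1)
from_split = np.split(data, indices_or_sections, axis=1)

# Both calls return a list of arrays, so compare them pair by pair.
same = len(from_array_split) == len(from_split) and all(
    np.array_equal(a, b) for a, b in zip(from_array_split, from_split)
)
print(same)  # True
```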