Understand and Read TensorFlow MNIST Dataset for Beginners – TensorFlow Tutorial

By | November 10, 2019

The MNIST dataset is a collection of images of handwritten digits, and it is commonly used in TensorFlow applications. In this tutorial, we will introduce this dataset to TensorFlow beginners to help them use it correctly.


Data in the MNIST dataset

The MNIST dataset contains three parts:

Train data (mnist.train): It contains 55000 images and labels.

We use the training data to train our model.

Validation data (mnist.validation): It contains 5000 images and labels.

We can use this data to tune the hyperparameters of our model.

Test data (mnist.test): It contains 10000 images and labels.

We use the test data to evaluate the performance of our trained model.

The sizes of the three parts are therefore in the ratio:

train : validation : test = 55000 : 5000 : 10000 = 11 : 1 : 2
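As a quick check of this arithmetic, the short plain-Python snippet below divides each size by the greatest common divisor of the three sizes. The variable names are our own; the sizes are the ones listed above.

```python
from functools import reduce
from math import gcd

# Sizes of the three MNIST parts as described above
sizes = {"train": 55000, "validation": 5000, "test": 10000}

# The greatest common divisor of all three sizes reduces the ratio to lowest terms
common = reduce(gcd, sizes.values())
ratio = {name: n // common for name, n in sizes.items()}

print(common)  # 5000
print(ratio)   # {'train': 11, 'validation': 1, 'test': 2}
```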

How to read the MNIST dataset in TensorFlow?

We can use the input_data.read_data_sets() function to load it. Here is an example:

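# Requires TensorFlow 1.x: tensorflow.examples.tutorials was removed in TensorFlow 2.x
import os
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Download MNIST (if needed) into ./MNIST-data/ and load it;
# one_hot=True encodes each label as a 10-element one-hot vector
mnist = input_data.read_data_sets(os.path.join(os.getcwd(), "MNIST-data"), one_hot=True)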

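After running this script, we will find the MNIST dataset in the MNIST-data folder. It contains four gzip-compressed files: train-images-idx3-ubyte.gz, train-labels-idx1-ubyte.gz, t10k-images-idx3-ubyte.gz and t10k-labels-idx1-ubyte.gz. The two train files hold 60000 images and labels; read_data_sets() uses the first 5000 of them as the validation set (its validation_size parameter defaults to 5000) and keeps the remaining 55000 as the training set. The two t10k files hold the 10000 test images and labels.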
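If you want to look inside these files without TensorFlow, it helps to know that MNIST uses the simple IDX binary format: a 4-byte header made of two zero bytes, a data-type code (0x08 for unsigned bytes, the type MNIST uses) and the number of dimensions, then one big-endian 32-bit integer per dimension, then the raw values. The functions parse_idx() and read_idx_gz() below are a minimal, hypothetical sketch of a reader for that layout, written with the Python standard library only; they are not part of TensorFlow, and they only handle the unsigned-byte type.

```python
import gzip
import struct


def parse_idx(data: bytes):
    """Parse IDX bytes holding unsigned-byte data; return (shape, flat list of values)."""
    zero1, zero2, dtype_code, ndim = struct.unpack(">BBBB", data[:4])
    if zero1 != 0 or zero2 != 0:
        raise ValueError("not an IDX file: the first two bytes must be zero")
    if dtype_code != 0x08:
        raise ValueError("this sketch only handles unsigned-byte (0x08) data")
    # Each dimension size is a big-endian unsigned 32-bit integer
    shape = struct.unpack(">" + "I" * ndim, data[4:4 + 4 * ndim])
    values = list(data[4 + 4 * ndim:])
    expected = 1
    for dim in shape:
        expected *= dim
    if len(values) != expected:
        raise ValueError(f"expected {expected} values, got {len(values)}")
    return shape, values


def read_idx_gz(path: str):
    """Read a gzip-compressed IDX file from disk and parse it."""
    with gzip.open(path, "rb") as f:
        return parse_idx(f.read())


# A tiny in-memory "labels" file: 1 dimension of size 3, holding the labels 7, 2, 1
fake = bytes([0, 0, 0x08, 1]) + struct.pack(">I", 3) + bytes([7, 2, 1])
print(parse_idx(fake))  # ((3,), [7, 2, 1])
```

For the real files, the image files should parse to 3 dimensions (number of images, 28, 28) and the label files to 1 dimension, so read_idx_gz("MNIST-data/t10k-labels-idx1-ubyte.gz") should report a shape of (10000,). For the large image files a plain Python list is slow and memory-hungry; numpy.frombuffer would be the natural choice there.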


Print mnist.train, mnist.validation and mnist.test

print("mnist train data")
print(mnist.train)
print("mnist validation data")
print(mnist.validation)
print("mnist test data")
print(mnist.test)

The result is:

mnist train data
<tensorflow.contrib.learn.python.learn.datasets.mnist.DataSet object at 0x000001F40B10FD68>
mnist validation data
<tensorflow.contrib.learn.python.learn.datasets.mnist.DataSet object at 0x000001F413A5CF28>
mnist test data
<tensorflow.contrib.learn.python.learn.datasets.mnist.DataSet object at 0x000001F413A5CF60>

From the result, we can see that mnist.train, mnist.validation and mnist.test are all DataSet objects.

Check the data type and dimensions of MNIST train, validation and test images and labels

For the training data, we print some information about its images and labels:

mnist_train_images = mnist.train.images
print("mnist train images")
print(type(mnist_train_images))
print(mnist_train_images.shape)
print(mnist_train_images)
mnist_train_labels = mnist.train.labels
print("mnist train labels")
print(type(mnist_train_labels))
print(mnist_train_labels.shape)
print(mnist_train_labels)

From the result, we can see that:


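1. We can use mnist.train.images to get the image data and mnist.train.labels to get the label data.

2. The images and labels in mnist.train are both numpy.ndarray objects.

3. The shape of the train images is 55000 * 784 and the shape of the train labels is 55000 * 10, which means mnist.train contains 55000 images and 55000 labels.

4. Each image is a row of 784 float values (a 28 * 28 grayscale image flattened into one vector, with pixel values scaled to [0, 1]), and each label is a one-hot row of 10 values, because we passed one_hot=True.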
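To make the last two points concrete, here is a plain-Python sketch (no TensorFlow needed) of how one 784-value row corresponds to a 28 * 28 image and how a one-hot label row maps back to a digit. The helper names unflatten(), one_hot_to_digit() and digit_to_one_hot() are hypothetical; with the NumPy arrays returned by mnist.train.images and mnist.train.labels, the same operations are simply image.reshape(28, 28) and numpy.argmax(label).

```python
IMAGE_SIDE = 28  # MNIST images are 28 x 28 pixels, flattened to 784 values


def unflatten(row):
    """Turn a flat list of 784 pixel values into 28 rows of 28 values."""
    if len(row) != IMAGE_SIDE * IMAGE_SIDE:
        raise ValueError("expected 784 values")
    return [row[r * IMAGE_SIDE:(r + 1) * IMAGE_SIDE] for r in range(IMAGE_SIDE)]


def one_hot_to_digit(label):
    """Return the index of the largest entry, i.e. the digit a one-hot label encodes."""
    return max(range(len(label)), key=lambda i: label[i])


def digit_to_one_hot(digit, num_classes=10):
    """Build a one-hot label row like those produced with one_hot=True."""
    return [1.0 if i == digit else 0.0 for i in range(num_classes)]


label = digit_to_one_hot(7)
print(label)                    # [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]
print(one_hot_to_digit(label))  # 7

pixels = [0.0] * 784
pixels[28 * 5 + 3] = 1.0        # light up row 5, column 3 of the flattened image
image = unflatten(pixels)
print(len(image), len(image[0]), image[5][3])  # 28 28 1.0
```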

We can print the validation and test data in the same way.

Read mnist train/validation/test batch data

In TensorFlow, we usually feed data to a model in batches when training, validating or testing it. To read a batch, we can call the next_batch(batch_size) method of a DataSet object.

For example, to read a batch of 64 test images and labels:

test_images_batch, test_labels_batch = mnist.test.next_batch(64)

print(type(test_images_batch))
print(test_images_batch.shape)
print(type(test_labels_batch))
print(test_labels_batch.shape)

The print result is:

<class 'numpy.ndarray'>
(64, 784)
<class 'numpy.ndarray'>
(64, 10)

From the result, we can see that we have read a 64 * 784 array of image data, which holds 64 images, along with a 64 * 10 array of label data, which holds 64 labels.
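To our understanding of TensorFlow 1.x's DataSet.next_batch(), which shuffles by default, it returns consecutive slices of the data and, when fewer than batch_size examples remain, combines the leftover examples with the start of a freshly shuffled pass (epoch). SimpleDataSet below is a hypothetical, simplified plain-Python class written only to illustrate that behavior; it is not TensorFlow's code and it uses lists instead of NumPy arrays.

```python
import random


class SimpleDataSet:
    """Hypothetical, simplified stand-in for a DataSet with next_batch()."""

    def __init__(self, images, labels, seed=0):
        if len(images) != len(labels):
            raise ValueError("images and labels must have the same length")
        self._images = list(images)
        self._labels = list(labels)
        self._num_examples = len(self._images)
        self._index_in_epoch = 0
        self.epochs_completed = 0
        self._rng = random.Random(seed)
        self._shuffle()  # shuffle once before the first epoch

    def _shuffle(self):
        # Apply the same permutation to images and labels so pairs stay aligned
        order = list(range(self._num_examples))
        self._rng.shuffle(order)
        self._images = [self._images[i] for i in order]
        self._labels = [self._labels[i] for i in order]

    def next_batch(self, batch_size):
        start = self._index_in_epoch
        if start + batch_size <= self._num_examples:
            # Enough examples left in this epoch: return the next slice
            self._index_in_epoch += batch_size
            end = self._index_in_epoch
            return self._images[start:end], self._labels[start:end]
        # Not enough left: keep the rest, reshuffle, and top up from the new epoch
        rest_images = self._images[start:]
        rest_labels = self._labels[start:]
        self.epochs_completed += 1
        self._shuffle()
        self._index_in_epoch = batch_size - len(rest_images)
        end = self._index_in_epoch
        return rest_images + self._images[:end], rest_labels + self._labels[:end]


# Ten fake "images" (plain integers) whose label is the value modulo 2
data = SimpleDataSet(images=list(range(10)), labels=[i % 2 for i in range(10)])
for _ in range(3):
    xs, ys = data.next_batch(4)
    print(len(xs), data.epochs_completed)
# prints "4 0", "4 0", then "4 1": the third batch spans two epochs
```

With the 55000 training examples and a batch size of 64, one epoch corresponds to 55000 / 64 = 859.375, i.e. roughly 860 calls to next_batch().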