Step Guide to Load MNIST Dataset for Training in PyTorch – PyTorch Tutorial

May 13, 2022

In this tutorial, we will show how to load the MNIST dataset for training in PyTorch. It is aimed at PyTorch beginners.

Preliminary

We can use torchvision to load the MNIST dataset in PyTorch. It provides an MNIST class for loading the data.

class torchvision.datasets.MNIST(root: str, train: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False)

Here:

root: string – Root directory of dataset where MNIST/raw/train-images-idx3-ubyte and MNIST/raw/t10k-images-idx3-ubyte exist.

train: if True, creates the dataset from train-images-idx3-ubyte; otherwise from t10k-images-idx3-ubyte.

transform and target_transform: optional callables that are applied to the image and to the target, respectively, before a sample is returned.

The source of the MNIST class shows how these two callables are applied.

[Image: the transform functions in the PyTorch MNIST dataset class]

In PyTorch, we usually convert the image and the target to tensors.
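To make the role of the two callables concrete, here is a minimal pure-Python sketch. TinyDataset, scale_pixels and one_hot are hypothetical names invented for illustration; this is not the actual torchvision source, only an assumption about the general pattern such a dataset follows.

```python
from typing import Any, Callable, Optional, Sequence, Tuple


class TinyDataset:
    """Hypothetical stand-in for a torchvision-style dataset (not the real MNIST class)."""

    def __init__(
        self,
        images: Sequence[Any],
        targets: Sequence[int],
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        self.images = images
        self.targets = targets
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        image, target = self.images[index], self.targets[index]
        # Each callable is applied only if it was supplied.
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return image, target


def scale_pixels(image):
    # Scale raw 0..255 pixel values to the range 0..1.
    return [p / 255 for p in image]


def one_hot(label):
    # Encode a digit label 0..9 as a list of ten 0/1 entries.
    return [1 if i == label else 0 for i in range(10)]


ds = TinyDataset(
    images=[[0, 255], [51, 102]],
    targets=[3, 7],
    transform=scale_pixels,
    target_transform=one_hot,
)
image, target = ds[0]
print(image)   # [0.0, 1.0]
print(target)  # [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
```

If either callable is left as None, the corresponding value is returned unchanged.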

Use torchvision to load MNIST data

We first import the required libraries and then load the data.

Here is the example code:

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

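# Training split; the files are downloaded into ./data if they are not already there
train_dt = datasets.MNIST(
    root='data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
# Test split, stored under the same root directory
test_dt = datasets.MNIST(
    root='data',
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)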

print(type(train_dt))
print(type(test_dt))

Run this code and we will see:

[Image: output of downloading the MNIST dataset with PyTorch]

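Here we use transforms.ToTensor() to convert each image to a PyTorch tensor with pixel values scaled to the range [0, 1]. Because target_transform is not set, the targets are returned as plain integer labels.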

Then we can find a data directory as follows:

[Image: the data directory in which the MNIST dataset is saved]

Then we can start iterating over all images.

Read MNIST images and labels in batches

We will create a DataLoader to read the training and test sets.

Here is an example:

batch_size = 32
gen_train = DataLoader(dataset=train_dt, batch_size=batch_size, shuffle=True, num_workers=0)
for iteration, batch in enumerate(gen_train):
    print(iteration, type(batch))
    print(len(batch))
    print(batch[0].shape, batch[1].shape)

In this example, we iterate over the images in the training set with batch_size = 32. Each batch is a list of two tensors: the images and their labels.

Run this code and we will see:

0 <class 'list'>
2
torch.Size([32, 1, 28, 28]) torch.Size([32])
1 <class 'list'>
2
torch.Size([32, 1, 28, 28]) torch.Size([32])
2 <class 'list'>
2
torch.Size([32, 1, 28, 28]) torch.Size([32])
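The test set can be read in the same way. The sketch below uses a synthetic TensorDataset (fake_test_dt, an assumed stand-in so the example runs without a download) and turns shuffling off, which is the usual choice for evaluation.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for the test set: 100 blank "images" with labels 0..9.
images = torch.zeros(100, 1, 28, 28)
labels = torch.arange(100) % 10
fake_test_dt = TensorDataset(images, labels)

# Evaluation does not need shuffling.
gen_test = DataLoader(fake_test_dt, batch_size=32, shuffle=False, num_workers=0)

print(len(gen_test))  # 4 batches: 32 + 32 + 32 + 4
batch_sizes = [batch_images.shape[0] for batch_images, batch_labels in gen_test]
print(batch_sizes)    # [32, 32, 32, 4]
```

When the dataset size is not a multiple of batch_size, the last batch is smaller. With the real 10,000-image MNIST test set and batch_size = 32, that would give 313 batches, the last holding 16 images.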

Then we can use the training set to train a PyTorch model.
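As a rough illustration of that last step, here is a minimal training loop. Everything in it is an assumption made for the sketch: the random data stands in for train_dt so the code runs without a download, and the single linear layer, the SGD optimizer and the learning rate are arbitrary choices, not a recommended model.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Synthetic stand-in for train_dt: 64 random "images" with random labels 0..9.
images = torch.rand(64, 1, 28, 28)
labels = torch.randint(0, 10, (64,))
gen_train = DataLoader(TensorDataset(images, labels), batch_size=32, shuffle=True)

# A deliberately tiny classifier: flatten 1x28x28 to 784, then one linear layer.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

losses = []
for epoch in range(2):
    for batch_images, batch_labels in gen_train:
        optimizer.zero_grad()                 # clear gradients from the last step
        logits = model(batch_images)          # shape [batch_size, 10]
        loss = loss_fn(logits, batch_labels)  # labels are integer class indices
        loss.backward()                       # compute gradients
        optimizer.step()                      # update the weights
        losses.append(loss.item())
    print(epoch, losses[-1])
```

To train on real data, the DataLoader built from train_dt earlier in this tutorial can be used in place of the synthetic one.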