Create a Custom Dataset for Loading Data in PyTorch – PyTorch Tutorial

By | April 15, 2022

If you plan to use PyTorch to load data for training a model, you should use a Dataset class. In this tutorial, we will show you how to create a custom Dataset class for loading data.

How to create a custom Dataset?

Here is the basic structure of a custom Dataset class.

from torch.utils.data import dataset

class CustomDataset(dataset.Dataset):

    def __init__(self):
        super(CustomDataset, self).__init__()

    def __getitem__(self, index):
        return None

    def __len__(self):
        return 0

Here CustomDataset is our custom Dataset class; it is a child class of PyTorch's dataset.Dataset.

We should override three basic methods: __init__(), __getitem__() and __len__().

Here:

__init__(): loads or prepares all the data.

__getitem__(): returns the sample at a given index.

__len__(): returns the number of samples in the dataset.

For example:

class CustomDataset(dataset.Dataset):

    def __init__(self):
        super(CustomDataset, self).__init__()
        # load all data for training or test
        self.all_data = [i for i in range(0, 100)]

    def __getitem__(self, index):
        # return a (feature, label) pair: the label is twice the feature
        return self.all_data[index], 2 * self.all_data[index]

    def __len__(self):
        return len(self.all_data)
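The Dataset protocol is essentially Python's sequence protocol: __getitem__() and __len__(). Before wrapping the dataset in a DataLoader, we can index it directly as a sanity check. Here is a minimal stand-in written in pure Python (no torch import, so it is not the CustomDataset above, but it returns the same (feature, label) pairs):

```python
class TinyDataset:
    """A pure-Python stand-in for the CustomDataset above,
    showing what __getitem__ and __len__ return."""

    def __init__(self):
        # same data as CustomDataset: the integers 0..99
        self.all_data = [i for i in range(0, 100)]

    def __getitem__(self, index):
        # same (feature, label) pair: the label is twice the feature
        return self.all_data[index], 2 * self.all_data[index]

    def __len__(self):
        return len(self.all_data)

ds = TinyDataset()
print(len(ds))  # 100
print(ds[3])    # (3, 6)
```

Indexing the dataset directly like this is a quick way to verify __getitem__() before any batching is involved.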

Then we can start to get batch samples from this dataset.

How to load batch data from custom Dataset?

We can use PyTorch's dataloader.DataLoader(). It is defined as:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None)

It allows us to load batch data from a dataset easily.

For example:

from torch.utils.data import dataloader

train_dataset = CustomDataset()

train_loader = dataloader.DataLoader(
    dataset=train_dataset,
    batch_size=8,
    shuffle=True
    )

print(train_loader)

for i_batch, batch_data in enumerate(train_loader):
    print(i_batch, batch_data, type(batch_data))
    print("batch end")

Running this code, we will see:

<torch.utils.data.dataloader.DataLoader object at 0x000002AB31EA9B70>
0 [tensor([85, 62, 12, 35, 67, 52, 60,  0]), tensor([170, 124,  24,  70, 134, 104, 120,   0])] <class 'list'>
batch end
1 [tensor([ 7, 21, 99, 41, 32, 23, 82, 45]), tensor([ 14,  42, 198,  82,  64,  46, 164,  90])] <class 'list'>
batch end
2 [tensor([34,  1, 36, 43, 78, 10, 56, 98]), tensor([ 68,   2,  72,  86, 156,  20, 112, 196])] <class 'list'>
batch end
3 [tensor([ 8, 92, 46, 44, 37, 33, 91, 19]), tensor([ 16, 184,  92,  88,  74,  66, 182,  38])] <class 'list'>
batch end
4 [tensor([ 6, 65, 81, 47, 17,  9, 29, 39]), tensor([ 12, 130, 162,  94,  34,  18,  58,  78])] <class 'list'>
batch end
5 [tensor([24, 30, 27, 28, 18,  4, 40, 51]), tensor([ 48,  60,  54,  56,  36,   8,  80, 102])] <class 'list'>
batch end
6 [tensor([16, 57, 93, 54, 22, 48, 71, 38]), tensor([ 32, 114, 186, 108,  44,  96, 142,  76])] <class 'list'>
batch end
7 [tensor([50, 58, 20, 59, 88, 55, 69, 25]), tensor([100, 116,  40, 118, 176, 110, 138,  50])] <class 'list'>
batch end
8 [tensor([72, 76, 90, 73, 53, 42, 63, 70]), tensor([144, 152, 180, 146, 106,  84, 126, 140])] <class 'list'>
batch end
9 [tensor([83, 96, 66, 75,  5, 77, 49, 61]), tensor([166, 192, 132, 150,  10, 154,  98, 122])] <class 'list'>
batch end
10 [tensor([94, 79, 68, 26, 31,  2, 74, 14]), tensor([188, 158, 136,  52,  62,   4, 148,  28])] <class 'list'>
batch end
11 [tensor([89, 97, 64, 11, 15, 84, 13,  3]), tensor([178, 194, 128,  22,  30, 168,  26,   6])] <class 'list'>
batch end
12 [tensor([95, 80, 87, 86]), tensor([190, 160, 174, 172])] <class 'list'>
batch end

From the output above, we can see:

Each batch is a Python list containing two tensors: one holding the batch of features and one holding the batch of labels, matching the (feature, label) pair returned by __getitem__().
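The last batch (index 12) has only 4 samples, because 100 is not divisible by the batch size of 8; passing drop_last=True to DataLoader discards such a partial batch. The batching itself can be sketched in pure Python — a simplified stand-in for what DataLoader does, without the shuffling or tensor conversion:

```python
def batches(data, batch_size, drop_last=False):
    """Yield consecutive slices of data, mimicking DataLoader's
    batching (no shuffling, no conversion to tensors)."""
    for start in range(0, len(data), batch_size):
        batch = data[start:start + batch_size]
        if drop_last and len(batch) < batch_size:
            break  # discard the final partial batch
        yield batch

# same (feature, label) pairs as our CustomDataset
data = [(i, 2 * i) for i in range(100)]

all_batches = list(batches(data, 8))
full_batches = list(batches(data, 8, drop_last=True))
print(len(all_batches), len(all_batches[-1]))  # 13 batches, last one has 4 samples
print(len(full_batches))                       # 12 batches with drop_last=True
```

This matches the output above: 12 full batches of 8 samples plus one partial batch of 4, or exactly 12 batches when the partial one is dropped.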