Understand TensorDataset with Examples in PyTorch – PyTorch Tutorial

By | April 3, 2023

In PyTorch, TensorDataset lets us zip a series of tensors together, much like Python's built-in zip() function. In this tutorial, we will use some examples to show you how to use it.

Syntax

TensorDataset is defined as:

class TensorDataset(Dataset[Tuple[Tensor, ...]]):
    r"""Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Args:
        *tensors (Tensor): tensors that have the same size of the first dimension.
    """
    tensors: Tuple[Tensor, ...]

    def __init__(self, *tensors: Tensor) -> None:
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors), "Size mismatch between tensors"
        self.tensors = tensors

    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        return self.tensors[0].size(0)

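From __getitem__(), we can see that it returns a tuple containing the element at the given index from each tensor. __len__() returns the size of the first dimension, and __init__() asserts that every tensor has the same size along that dimension.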
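To make this concrete, here is a small sketch with hand-picked integer tensors, so the results are predictable. The names features, labels and ds are just illustrative. Note that the size check is a plain assert, so it is skipped if Python runs with the -O flag.

```python
import torch
from torch.utils.data import TensorDataset

# Two tensors that share a first dimension of size 3
features = torch.arange(6).reshape(3, 2)  # [[0, 1], [2, 3], [4, 5]]
labels = torch.tensor([10, 20, 30])

ds = TensorDataset(features, labels)

# __len__ returns features.size(0)
print(len(ds))  # 3

# __getitem__ returns the tuple (features[1], labels[1])
x, y = ds[1]
print(x, y)  # tensor([2, 3]) tensor(20)

# Mismatched first dimensions fail the assert in __init__
mismatch_rejected = False
try:
    TensorDataset(torch.zeros(3, 2), torch.zeros(4))
except AssertionError as err:
    mismatch_rejected = True
    print("rejected:", err)
```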

How to use TensorDataset?

Example 1:

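import torch
from torch.utils.data import TensorDataset

x1 = torch.randn(4, 5)   # 4 samples with 5 values each
x2 = torch.randn(4, 10)  # 4 samples with 10 values each

d = TensorDataset(x1, x2)  # the first dimensions (4) must match
print(d)
for e in d:  # each e is a tuple (x1[i], x2[i])
    print(e)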

Running this code, we will see the output below (the values are random, so yours will differ):

<torch.utils.data.dataset.TensorDataset object at 0x0000021DBA92FB00>
(tensor([ 0.0644, -0.8627, -0.9599, -0.3772, -0.0840]), tensor([ 2.5645,  0.3732, -0.3954, -1.9667,  0.5432, -0.3737, -0.2884, -0.7295,
        -1.8462,  0.2305]))
(tensor([-0.9827, -1.5631,  0.0772,  0.7499,  0.3318]), tensor([ 0.9658,  1.9472, -0.1003, -1.1146,  0.7413, -1.0945, -0.0801, -0.3975,
        -0.6289, -0.4536]))
(tensor([-0.1613, -0.3813, -0.2677, -1.0164, -0.4861]), tensor([ 0.0522,  1.7411,  0.2216, -1.0339,  0.2794,  1.4683, -1.1677, -0.4825,
        -1.6060, -1.4113]))
(tensor([ 0.1819, -0.1794,  0.3319, -0.0702,  0.6290]), tensor([ 0.4300,  1.3721, -0.5497, -0.1086,  0.6109, -0.7664,  1.2882, -1.6521,
        -0.5760, -0.4642]))
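Each item is a tuple with one entry per wrapped tensor, so it can be unpacked directly in the loop. Also, because __getitem__() simply forwards the index to every tensor, a slice returns a tuple of sliced tensors rather than a list of samples. A minimal sketch reusing the shapes from the example above (first_two is just an illustrative name):

```python
import torch
from torch.utils.data import TensorDataset

x1 = torch.randn(4, 5)
x2 = torch.randn(4, 10)
d = TensorDataset(x1, x2)

print(len(d))  # 4, the size of the shared first dimension

# Each item is a tuple, so it can be unpacked in the loop header
for a, b in d:
    print(a.shape, b.shape)  # torch.Size([5]) torch.Size([10])

# A slice is applied to every tensor, giving (x1[0:2], x2[0:2])
first_two = d[0:2]
print(first_two[0].shape, first_two[1].shape)  # torch.Size([2, 5]) torch.Size([2, 10])
```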

Example 2:

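import torch
from torch.utils.data import TensorDataset

x1 = torch.randn(4, 2)
x2 = torch.randn(4, 3)
x3 = torch.randn(4, 2)

d = TensorDataset(x1, x2, x3)  # any number of tensors can be wrapped
for e in d:  # each e is a tuple (x1[i], x2[i], x3[i])
    print(e)

print(x1)  # compare with the first element of each tuple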

Then, we will see:

(tensor([ 0.2332, -0.4129]), tensor([0.0416, 0.0045, 0.9388]), tensor([0.5951, 0.4067]))
(tensor([ 1.0542, -0.9473]), tensor([-1.2757,  0.0499, -0.7282]), tensor([ 1.0244, -0.4466]))
(tensor([-0.4178,  0.9416]), tensor([-1.0980, -0.2778, -0.0483]), tensor([-2.4482,  1.3482]))
(tensor([ 3.0787, -0.5175]), tensor([ 0.2855,  1.2509, -1.4400]), tensor([ 1.1726, -0.0982]))
tensor([[ 0.2332, -0.4129],
        [ 1.0542, -0.9473],
        [-0.4178,  0.9416],
        [ 3.0787, -0.5175]])
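Comparing the tuples with x1, we can see that the first element of each tuple is one row of x1, so TensorDataset indexes every tensor along its first dimension. Since it only keeps the tensors in self.tensors, it does not copy the data. The sketch below checks this; the in-place zero_() call is only there to show that changes to x1 are visible through the dataset.

```python
import torch
from torch.utils.data import TensorDataset

x1 = torch.randn(4, 2)
x2 = torch.randn(4, 3)
x3 = torch.randn(4, 2)
d = TensorDataset(x1, x2, x3)

# The dataset stores the original tensor objects, not copies
print(d.tensors[0] is x1)  # True

# Row i of x1 is the first element of the i-th sample
print(all(torch.equal(d[i][0], x1[i]) for i in range(len(d))))  # True

# Because no copy is made, an in-place change to x1 is visible through d
x1[0].zero_()
print(d[0][0])  # tensor([0., 0.])
```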