Implement Sequence Mask to Shield Tensor in PyTorch – PyTorch Tutorial

By | August 31, 2023

PyTorch does not have a built-in function like tf.sequence_mask(), which makes it harder to process variable-length sequences. In this tutorial, we will show you how to create a function like tf.sequence_mask().

What is a sequence mask?

For a variable-length sequence, we can pad it to a fixed length with 0; however, we should then shield the padded positions with a mask. A sequence mask looks like this:

[Image: understand tensorflow tf.sequence_mask() with examples]
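As a rough illustration of the idea (the batch values below are made up for this sketch and are not part of the original example), a mask marks which positions hold real data, so padded positions can be excluded from a computation:

```python
import torch

# Hypothetical batch: 3 sequences with lengths 2, 3 and 4,
# each padded with 0 up to a fixed length of 5.
padded = torch.tensor([[7, 8, 0, 0, 0],
                       [4, 5, 6, 0, 0],
                       [1, 2, 3, 9, 0]])

# The matching mask, written out by hand: True = real value, False = padding.
mask = torch.tensor([[True, True, False, False, False],
                     [True, True, True, False, False],
                     [True, True, True, True, False]])

# Multiplying by the mask zeroes out every padded position before summing.
masked_sum = (padded * mask).sum(dim=1)
print(masked_sum)  # tensor([15, 15, 15])
```

The rest of this tutorial is about building such a mask from the sequence lengths instead of writing it by hand.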

How to implement a sequence mask in PyTorch?

We will use an example to show how to do it:

import torch
import numpy as np

def sequence_mask(length, max_length=None):
  if max_length is None:
    max_length = length.max()
  x = torch.arange(max_length, dtype=length.dtype, device=length.device)
  return x.unsqueeze(0) < length

The sequence_mask() function receives two parameters, length and max_length, which correspond to the lengths and maxlen arguments of tf.sequence_mask().

Then, we can use it as follows:

# 3 sentences, up to 5 words each
lx = np.array([[2],[3],[4]])
length = torch.from_numpy(lx)
print(length.shape)

max_lx = 5
x = sequence_mask(length, max_lx)
print(x.shape)
print(x)

Running this code, we will see:

torch.Size([3, 1])
torch.Size([3, 5])
tensor([[ True,  True, False, False, False],
        [ True,  True,  True, False, False],
        [ True,  True,  True,  True, False]])

Here max_length = 5. The shape of length is [3, 1], so we get a mask tensor with shape [3, 5].
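The [3, 5] shape comes from broadcasting a [1, 5] tensor of positions against the [3, 1] tensor of lengths. A minimal sketch of that step on its own, with the variable names positions, mask and short chosen only for this illustration:

```python
import torch

length = torch.tensor([[2], [3], [4]])       # shape [3, 1]
positions = torch.arange(5).unsqueeze(0)     # shape [1, 5]: [[0, 1, 2, 3, 4]]

# [1, 5] < [3, 1] broadcasts to [3, 5]: each row compares every
# position index against that row's length.
mask = positions < length
print(mask.shape)  # torch.Size([3, 5])

# If max_length is left as None, the function falls back to length.max(),
# which is 4 here, so the mask would have shape [3, 4] instead.
short = torch.arange(int(length.max())).unsqueeze(0) < length
print(short.shape)  # torch.Size([3, 4])
```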

The data type of the mask tensor x is boolean. To get a float mask tensor, we can convert it, for example with x.float().
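A minimal sketch of that conversion, assuming the sequence_mask() function from above and using Tensor.float() as one possible way to change the data type:

```python
import torch

def sequence_mask(length, max_length=None):
    # Same logic as the function above (max_length is cast to int here).
    if max_length is None:
        max_length = int(length.max())
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length

length = torch.tensor([[2], [3], [4]])
x = sequence_mask(length, 5)   # boolean mask with shape [3, 5]

x = x.float()                  # True/False become 1.0/0.0
print(x)
```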


Then, we will get:

tensor([[1., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 0.]])

We can also create another version. For example:

def sequence_mask(sequence_length, max_len=None):
    """Create a sequence mask for filtering padding in a sequence tensor.

    Args:
        sequence_length (torch.tensor): Sequence lengths.
        max_len (int, Optional): Maximum sequence length. Defaults to None.

    Shapes:
        - mask: :math:`[B, T_max]`
    """
    if max_len is None:
        max_len = sequence_length.max()
    seq_range = torch.arange(max_len, dtype=sequence_length.dtype, device=sequence_length.device)
    # B x T_max
    return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)

Here the shape of sequence_length is [B], not [B, 1].

Finally, we also get a mask tensor with the shape [B, T_max].
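As a usage sketch for this second version, the example below passes a 1-D lengths tensor and then uses the mask to average only the real time steps. The feats tensor and the masked-mean step are assumptions added for illustration; they are not part of the original function.

```python
import torch

def sequence_mask(sequence_length, max_len=None):
    # Same logic as the version above (max_len is cast to int here).
    if max_len is None:
        max_len = int(sequence_length.max())
    seq_range = torch.arange(max_len, dtype=sequence_length.dtype,
                             device=sequence_length.device)
    return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)

lengths = torch.tensor([2, 3, 4])           # shape [B]
mask = sequence_mask(lengths, max_len=5)    # shape [B, T_max]
print(mask.shape)  # torch.Size([3, 5])

# Hypothetical padded features with shape [B, T_max, D].
feats = torch.ones(3, 5, 2)

# Expand the mask to [B, T_max, 1] so it broadcasts over the feature axis,
# then average over the real time steps only.
m = mask.unsqueeze(-1).float()
mean = (feats * m).sum(dim=1) / lengths.unsqueeze(1).float()
print(mean.shape)  # torch.Size([3, 2])
```

Dividing by the true lengths rather than by T_max keeps the padded positions from pulling the average toward zero.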