# Implement Sequence Mask to Shield Tensor in PyTorch – PyTorch Tutorial

By | August 31, 2023

PyTorch does not provide a built-in function like tf.sequence_mask(), which makes variable-length sequences less convenient to process. In this tutorial, we will show you how to create such a function.

For a variable-length sequence, we usually pad it to a fixed length with 0. However, the padded positions must then be shielded by a mask. A sequence mask marks every real position with True (or 1) and every padded position with False (or 0).
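To make this concrete, the sketch below builds such a mask with plain Python loops for three sequences of lengths 2, 3 and 4, each padded to 5 positions. These example values are chosen for illustration. The loop version is slow and is only meant as a reference for checking a vectorized implementation.

```python
import torch

lengths = [2, 3, 4]   # real (unpadded) length of each sequence
max_len = 5           # every sequence is padded to this length

# mask[i][j] is True when position j holds a real token of sequence i
mask = torch.tensor(
    [[j < n for j in range(max_len)] for n in lengths],
    dtype=torch.bool,
)
print(mask)
```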

## How to implement sequence mask in pytorch?

We will use an example to show how to do it. First, define a function that builds the mask from a tensor of sequence lengths:

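```python
import torch
import numpy as np

def sequence_mask(length, max_length=None):
    # length: integer tensor of shape [B, 1] holding each sequence's real length
    if max_length is None:
        max_length = length.max()
    # Position indices 0 .. max_length - 1
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    # [1, T_max] < [B, 1] broadcasts to a [B, T_max] boolean mask
    return x.unsqueeze(0) < length
```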

Then, we can use it as follows:

```python
# 3 sentences, padded to at most 5 words
lx = np.array([[2], [3], [4]])
length = torch.from_numpy(lx)
print(length.shape)

max_lx = 5
x = sequence_mask(length, max_lx)
print(x.shape)
print(x)
```

Running this code prints:

```
torch.Size([3, 1])
torch.Size([3, 5])
tensor([[ True,  True, False, False, False],
        [ True,  True,  True, False, False],
        [ True,  True,  True,  True, False]])
```

Here max_length = 5. Because length has shape [3, 1] and the position range has shape [1, 5] after unsqueeze(0), broadcasting produces a mask tensor of shape [3, 5].
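The broadcasting step can be checked in isolation. This short sketch uses the same example lengths and prints the shapes of both operands and of the result, so you can see how [3, 1] and [1, 5] combine into [3, 5].

```python
import torch

length = torch.tensor([[2], [3], [4]])       # shape [3, 1]: one length per row
positions = torch.arange(5).unsqueeze(0)     # shape [1, 5]: indices 0..4

# Broadcasting stretches [3, 1] and [1, 5] to a common shape [3, 5],
# then compares every position index with every sequence length.
mask = positions < length
print(positions.shape, length.shape, mask.shape)
```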

Moreover, the data type of the mask tensor x is boolean. To get a float mask tensor, we can convert it as follows:

```python
print(x.float())
```

Then, we will get:

```
tensor([[1., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 0.]])
```
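A float mask is handy when padded values must not affect arithmetic. As one illustration, the sketch below averages per-token features over real tokens only. It assumes a batch of shape [B, T, D] whose padded positions contain arbitrary values. In this example they are filled with 100 so that any leakage would be obvious. The shapes and values are example assumptions.

```python
import torch

B, T, D = 3, 5, 2
lengths = torch.tensor([2, 3, 4])
features = torch.ones(B, T, D)

# Boolean [B, T] mask, then fill padded positions with a large value
# so that any leakage into the average would be obvious
mask = torch.arange(T).unsqueeze(0) < lengths.unsqueeze(1)
features[~mask] = 100.0

float_mask = mask.float().unsqueeze(-1)                      # [B, T, 1]
# Zero out padded positions, then divide by the number of real tokens
masked_mean = (features * float_mask).sum(dim=1) / float_mask.sum(dim=1)
print(masked_mean)  # every entry is 1.0: the padding was ignored
```

Multiplying by the float mask zeroes the padded rows. Dividing by `float_mask.sum(dim=1)` instead of `T` keeps the average from being diluted by padding.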

We can also write another version of the function that takes a 1-D tensor of lengths. For example:

```python
import torch

def sequence_mask(sequence_length, max_len=None):
    """Create a boolean mask that is True at valid (non-padded) positions.

    Args:
        sequence_length (torch.tensor): Sequence lengths.
        max_len (int, Optional): Maximum sequence length. Defaults to None.

    Shapes:
        - mask: :math:`[B, T_max]`
    """
    if max_len is None:
        max_len = sequence_length.max()
    seq_range = torch.arange(max_len, dtype=sequence_length.dtype, device=sequence_length.device)
    # B x T_max
    return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)
```

Here sequence_length has shape [B] rather than [B, 1], so the function calls unsqueeze(1) itself to add the trailing dimension before broadcasting.
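To see why unsqueeze(1) matters, the sketch below uses example lengths of shape [3] and a position range of length 5. With the extra dimension, the comparison yields a [3, 5] mask. Without it, PyTorch tries to align the trailing sizes 5 and 3 and raises a RuntimeError.

```python
import torch

lengths = torch.tensor([2, 3, 4])          # shape [B] = [3]
seq_range = torch.arange(5)                # shape [T_max] = [5]

# Correct: [1, T_max] vs [B, 1] broadcasts to [B, T_max]
mask = seq_range.unsqueeze(0) < lengths.unsqueeze(1)
print(mask.shape)

# Without unsqueeze(1): [1, 5] vs [3] -> trailing sizes 5 and 3 do not match
raised = False
try:
    seq_range.unsqueeze(0) < lengths
except RuntimeError as err:
    raised = True
    print("broadcast error:", err)
```

If B happened to equal T_max, the call without unsqueeze(1) would not fail. Instead, it would quietly return a [1, T_max] tensor that compares position j with the j-th length. Adding the dimension explicitly avoids this silent error.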

The result is again a mask tensor of shape [B, T_max].
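TensorFlow's `tf.sequence_mask(lengths, maxlen=None, dtype=tf.bool)` also accepts lengths of any shape and returns a mask of shape `lengths.shape + [maxlen]`. The sketch below shows one way to mirror that behaviour in PyTorch. The function name `tf_style_sequence_mask` and its `dtype` argument are assumptions for illustration, not an existing PyTorch API.

```python
import torch

def tf_style_sequence_mask(lengths, maxlen=None, dtype=torch.bool):
    """Hypothetical helper mimicking tf.sequence_mask for PyTorch tensors.

    lengths: integer tensor of any shape [d_1, ..., d_n]
    returns: tensor of shape [d_1, ..., d_n, maxlen] with the given dtype
    """
    if maxlen is None:
        maxlen = int(lengths.max())
    positions = torch.arange(maxlen, device=lengths.device)
    # A trailing axis on lengths lets the comparison broadcast for any input rank
    mask = positions < lengths.unsqueeze(-1)
    return mask.to(dtype)

print(tf_style_sequence_mask(torch.tensor([1, 3, 2]), 5))
print(tf_style_sequence_mask(torch.tensor([[1, 2], [3, 0]]), dtype=torch.float32).shape)
```

Using `unsqueeze(-1)` instead of `unsqueeze(1)` is what generalizes the second version above from 1-D lengths to lengths of any rank.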