Understand transformers.get_linear_schedule_with_warmup() with Examples – PyTorch Tutorial

January 18, 2023

transformers.get_linear_schedule_with_warmup() creates a learning rate schedule with two phases. During the warmup period, the learning rate increases linearly from 0 to the initial lr set in the optimizer. After warmup, it decreases linearly from that initial lr to 0. It is similar to transformers.get_cosine_schedule_with_warmup(). In this tutorial, we will use an example to show you how to use it.
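
The two phases described above can be written as a factor that scales the initial lr at each step. The following is a minimal sketch in plain Python based on that description, not the library's source code. The function name lr_multiplier and the step counts (10 warmup steps, 100 training steps) are assumptions chosen for illustration.

```python
def lr_multiplier(step, num_warmup_steps, num_training_steps):
    """Factor in [0, 1] that scales the optimizer's initial lr at a given step.

    Hypothetical helper for illustration; it is not part of transformers.
    """
    if step < num_warmup_steps:
        # Warmup phase: rise linearly from 0 to 1
        return step / max(1, num_warmup_steps)
    # Decay phase: fall linearly from 1 to 0, clamped at 0 once training is over
    remaining = num_training_steps - step
    return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))


initial_lr = 0.001
for step in (0, 5, 10, 55, 100, 120):
    factor = lr_multiplier(step, num_warmup_steps=10, num_training_steps=100)
    print(step, initial_lr * factor)
```

With these numbers the factor is 0 at step 0, reaches 1 at the end of warmup (step 10), is back to 0.5 halfway through the decay (step 55), and stays at 0 from step 100 onward.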

Syntax

transformers.get_linear_schedule_with_warmup() is defined as:

transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1)

We should notice: like transformers.get_cosine_schedule_with_warmup(), this scheduler counts training steps, not epochs. num_warmup_steps and num_training_steps are numbers of batches, so scheduler.step() should be called after every training batch, not once per epoch.
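
To see why this matters, the sketch below compares how far the schedule gets when it is stepped once per batch versus once per epoch. It uses a plain-Python stand-in for the schedule rather than the library itself. The helper name lr_multiplier is hypothetical, and the sizes (10 epochs, 1000 samples, batch size 64) match the example later in this tutorial.

```python
def lr_multiplier(step, num_warmup_steps, num_training_steps):
    # Hypothetical stand-in for the linear warmup/decay factor (not library code)
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    return max(0.0, (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps))


max_epoch = 10
iters_per_epoch = 1000 // 64                       # batches per epoch
num_training_steps = max_epoch * iters_per_epoch   # total batches
num_warmup_steps = int(num_training_steps * 0.1)

# Number of scheduler.step() calls made by the end of training in each case
steps_if_per_batch = max_epoch * iters_per_epoch
steps_if_per_epoch = max_epoch

final_per_batch = lr_multiplier(steps_if_per_batch, num_warmup_steps, num_training_steps)
final_per_epoch = lr_multiplier(steps_if_per_epoch, num_warmup_steps, num_training_steps)

print("stepped per batch, final factor:", final_per_batch)
print("stepped per epoch, final factor:", final_per_epoch)
```

Stepped per batch, the schedule runs to completion and the factor ends at 0. Stepped per epoch, the scheduler only advances 10 times, which is fewer than the warmup steps, so the learning rate is still warming up when training ends.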

How to use it?

Here is an example:

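import torch
import transformers
from matplotlib import pyplot as plt

if __name__ == "__main__":
    # A dummy "model": one trainable 2x2 tensor is enough to build an optimizer
    model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
    LR = 0.001

    max_epoch = 10
    train_dataset_len = 1000
    batch_size = 64
    # 1000 // 64 = 15 full batches per epoch
    train_iters_per_epoch = train_dataset_len // batch_size
    # 10 * 15 = 150 training steps in total
    num_training_steps = max_epoch * train_iters_per_epoch
    # Warm up over the first 10% of training: 15 steps
    num_warmup_steps = int(num_training_steps * 0.1)

    optimizer = torch.optim.Adam(model, lr=LR, weight_decay=2e-5)
    scheduler = transformers.get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps, num_training_steps
    )

    lr_list = []
    for epoch in range(max_epoch):
        for i in range(train_iters_per_epoch):
            optimizer.zero_grad()
            # A real training loop would compute a loss and call loss.backward() here
            optimizer.step()
            # Record the learning rate used for this step
            lr_list.append(optimizer.param_groups[0]['lr'])
            # Update the learning rate once per batch, not once per epoch
            scheduler.step()

    plt.plot(range(num_training_steps), lr_list, color='r')
    plt.xlabel('training step')
    plt.ylabel('learning rate')
    plt.show()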

For a more detailed explanation of this code, see this tutorial:

Understand transformers.get_cosine_schedule_with_warmup() with Examples – PyTorch Tutorial

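Run this code and we will see a plot of the learning rate over the 150 training steps. It rises linearly from 0 to 0.001 during the first 15 warmup steps, then falls linearly to almost 0 by the last step.

[Figure: learning rate curve produced by transformers.get_linear_schedule_with_warmup()]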