torch.nn.Linear() Weight Shape Explained – PyTorch Tutorial

By | January 31, 2024

The torch.nn.Linear() module is widely used in deep learning models. But what is the shape of its weight? In this tutorial, we will discuss it.

Syntax

torch.nn.Linear() is defined as:

torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)
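The optional arguments change which parameters are created and how they are stored. The small sketch below checks only bias and dtype. It assumes a PyTorch version that accepts the device and dtype arguments, and that the global default dtype is still float32.

```python
import torch
import torch.nn as nn

# Default: bias=True, parameters use the default dtype (float32 unless changed)
m = nn.Linear(20, 30)
print(m.bias is None)          # False
print(m.weight.dtype)          # torch.float32

# bias=False: no bias parameter is created, the attribute is None
m_no_bias = nn.Linear(20, 30, bias=False)
print(m_no_bias.bias is None)  # True

# dtype sets the dtype of the weight (and bias) at construction time
m_f64 = nn.Linear(20, 30, dtype=torch.float64)
print(m_f64.weight.dtype)      # torch.float64
```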

Here, in_features is the size of the last dimension of the input tensor, and out_features is the size of the last dimension of the output tensor.
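Because only the last dimension is involved, any leading dimensions (for example a batch axis and a sequence axis) pass through unchanged. A minimal sketch, where the shape (4, 5, 20) is an arbitrary choice:

```python
import torch
import torch.nn as nn

m = nn.Linear(20, 30)

# The last dimension must equal in_features; leading dimensions are kept as-is
x = torch.randn(4, 5, 20)
y = m(x)
print(y.shape)                 # torch.Size([4, 5, 30])

# A 1-D input of length in_features also works
y1 = m(torch.randn(20))
print(y1.shape)                # torch.Size([30])
```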

For example:

>>> import torch
>>> import torch.nn as nn
>>> m = nn.Linear(20, 30)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])

We get an output tensor with shape [128, 30].

What about the shape of the weight in torch.nn.Linear()?

Is it [in_features, out_features] or [out_features, in_features]?

Look at this example:

import torch
import torch.nn as nn

m = nn.Linear(20, 30)
print(m.weight.shape)

We will get:

torch.Size([30, 20])

This means the weight in torch.nn.Linear() has shape [out_features, in_features]. The layer computes y = xW^T + b, so the weight is transposed before it multiplies the input.
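We can check this layout by reproducing the layer's output by hand. The [30, 20] weight is transposed to [20, 30] so it can multiply a [128, 20] input, and the bias of shape [30] is broadcast over the rows. In this sketch the seed and the tolerance are arbitrary choices.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
m = nn.Linear(20, 30)
x = torch.randn(128, 20)

# [128, 20] @ [20, 30] -> [128, 30], then add the [30] bias to every row
manual = x @ m.weight.T + m.bias
print(torch.allclose(m(x), manual, atol=1e-5))  # True
print(m.weight.T.shape)                         # torch.Size([20, 30])
print(m.bias.shape)                             # torch.Size([30])
```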

We can also confirm this in the source code of torch.nn.Linear, where the weight is created as:

self.weight = Parameter(torch.empty((out_features, in_features), **factory_kwargs))
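To make the layout concrete, here is a hypothetical TinyLinear module that stores its weight the same way, as (out_features, in_features). The class name and the simple uniform initialization are assumptions for illustration. This is not PyTorch's actual implementation, although it delegates the math to torch.nn.functional.linear.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyLinear(nn.Module):
    """Hypothetical minimal linear layer, for illustration only."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        # Same layout as nn.Linear: (out_features, in_features)
        self.weight = nn.Parameter(torch.empty((out_features, in_features)))
        self.bias = nn.Parameter(torch.empty(out_features))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Assumed init: uniform in [-1/sqrt(in_features), 1/sqrt(in_features)]
        bound = 1.0 / math.sqrt(self.weight.shape[1])
        nn.init.uniform_(self.weight, -bound, bound)
        nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # F.linear computes x @ weight.T + bias
        return F.linear(x, self.weight, self.bias)

layer = TinyLinear(20, 30)
out = layer(torch.randn(128, 20))
print(layer.weight.shape)  # torch.Size([30, 20])
print(out.shape)           # torch.Size([128, 30])
```

With this layout, row i of the weight holds the in_features weights that produce output feature i.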
