Difference Between torch.matmul() and torch.mm() – PyTorch Tutorial

By | January 30, 2023

Both torch.matmul() and torch.mm() can perform matrix multiplication. In this tutorial, we will explain how they differ.

torch.matmul()

It is defined as:

torch.matmul(input, other, *, out=None)
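The behavior of torch.matmul() depends on the number of dimensions of its arguments. The sketch below is a small illustration with made-up tensors (v, m and b are names chosen here, not part of the API) that prints the result shape for a few combinations.

```python
import torch

v = torch.ones(3)        # 1-D tensor (vector)
m = torch.ones(2, 3)     # 2-D tensor (matrix)
b = torch.ones(4, 2, 3)  # 3-D tensor (batch of four 2x3 matrices)

dot = torch.matmul(v, v)        # 1-D x 1-D: dot product, a 0-D tensor
mat_vec = torch.matmul(m, v)    # 2-D x 1-D: matrix-vector product
batch_vec = torch.matmul(b, v)  # 3-D x 1-D: vector applied to every matrix
batch_mat = torch.matmul(b, m.T)  # 3-D x 2-D: m.T is broadcast over the batch

print(dot.shape, mat_vec.shape, batch_vec.shape, batch_mat.shape)
```

Only the 2-D x 2-D case corresponds to what torch.mm() does.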

torch.mm()

It is defined as:

torch.mm(input, mat2, *, out=None)

Here input is an (n×m) tensor, mat2 is an (m×p) tensor, and the output is an (n×p) tensor. Both arguments must be 2-D.
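As a quick check of the shape rule, here is a minimal sketch with small hand-picked matrices (the values are chosen only for illustration), so the product can be verified by hand.

```python
import torch

a = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.]])  # shape (2, 3): n=2, m=3
b = torch.tensor([[1., 0.],
                  [0., 1.],
                  [1., 1.]])      # shape (3, 2): m=3, p=2

c = torch.mm(a, b)                # shape (2, 2): n=2, p=2
print(c.shape)
print(c)  # rows are [4, 5] and [10, 11]
```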

Note that torch.mm() does not broadcast: it only accepts 2-D tensors. For matrix products that broadcast over batch dimensions, we can use torch.matmul().
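To make the broadcasting difference concrete, the following sketch multiplies a batch of matrices by a single weight matrix. The tensor names and shapes are assumptions chosen for illustration. torch.matmul() broadcasts the 2-D tensor over the batch, while torch.mm() is expected to raise a RuntimeError because its first argument is not 2-D.

```python
import torch

a = torch.randn(10, 3, 4)  # batch of ten 3x4 matrices
w = torch.randn(4, 5)      # a single 4x5 matrix

# torch.matmul() broadcasts w across the batch dimension.
out = torch.matmul(a, w)
print(out.shape)  # torch.Size([10, 3, 5])

# torch.mm() rejects the 3-D input instead of broadcasting.
try:
    torch.mm(a, w)
    mm_failed = False
except RuntimeError as err:
    mm_failed = True
    print("torch.mm() failed:", err)
```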

First, here is an example with 2-D tensors, where both functions can be used.

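import torch

# Two 2-D tensors with compatible shapes: (5, 200) x (200, 5) -> (5, 5)
x = torch.randn(5, 200)
y = torch.randn(200, 5)

# Both functions compute the same matrix product for 2-D inputs
print(torch.mm(x, y))
print(torch.matmul(x, y))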

Running this code, we find that the two results are the same. The inputs are random, so the exact values will differ from run to run.

tensor([[-27.3500,  -4.1209,  16.5225,   3.5113,  23.4709],
        [ 33.1420,   3.8333,  13.8869, -13.8083, -12.5489],
        [  0.1656,   6.9874,  -3.2829,  15.0789, -14.5087],
        [-21.0825,  -2.0691, -14.8276,   3.2864,   6.1927],
        [ -3.1685,  18.8321,   6.1552,   7.9586,   6.6930]])
tensor([[-27.3500,  -4.1209,  16.5225,   3.5113,  23.4709],
        [ 33.1420,   3.8333,  13.8869, -13.8083, -12.5489],
        [  0.1656,   6.9874,  -3.2829,  15.0789, -14.5087],
        [-21.0825,  -2.0691, -14.8276,   3.2864,   6.1927],
        [ -3.1685,  18.8321,   6.1552,   7.9586,   6.6930]])
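Instead of comparing the printed values by eye, we can also check the equality in code. This is a minimal sketch: the seed value is arbitrary, and torch.allclose() is used here as a tolerant comparison in case of small floating-point differences.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only to make the run repeatable

x = torch.randn(5, 200)
y = torch.randn(200, 5)

mm_result = torch.mm(x, y)
matmul_result = torch.matmul(x, y)

# True when the two results agree within floating-point tolerance
same = torch.allclose(mm_result, matmul_result)
print(same)
```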