Overview
The torch.matmul() function performs the matrix product of two tensors. The behavior depends on the dimensionality of the tensors as follows:
- If both tensors are 1-dimensional, the dot product (scalar) is returned.
- If both arguments are 2-dimensional, the matrix-matrix product is returned.
- If the first argument is 1-dimensional and the second argument is 2-dimensional, a 1 is prepended to the first argument's dimensions for the purpose of the matrix multiply. After the matrix multiply, the prepended dimension is removed.
- If the first argument is 2-dimensional and the second argument is 1-dimensional, the matrix-vector product is returned.
- If both arguments are at least 1-dimensional and at least one argument is N-dimensional (where N > 2), then a batched matrix multiply is returned. If the first argument is 1-dimensional, a 1 is prepended to its dimensions for the purpose of the batched matrix multiply and removed afterwards. If the second argument is 1-dimensional, a 1 is appended to its dimensions for the purpose of the batched matrix multiply and removed afterwards. The non-matrix (i.e. batch) dimensions are broadcasted (and thus must be broadcastable).
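The 1D-by-2D rule above can be checked with a minimal sketch. It uses torch.ones (chosen here purely so the result is predictable) to show that the vector is treated as a row matrix and the prepended dimension is dropped from the result:

```python
import torch

# 1D x 2D: the vector is treated as a (1, 4) row matrix for the multiply,
# then the prepended dimension is removed, leaving a 1D result of size 5
vec = torch.ones(4)        # shape (4,)
mat = torch.ones(4, 5)     # shape (4, 5)
result = torch.matmul(vec, mat)
print(result.shape)        # torch.Size([5])
```

Each entry of the result is 4.0, the dot product of a row of ones with a column of ones.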
Syntax and parameters
The syntax of the torch.matmul() function is shown below:
torch.matmul(input, other, *, out=None) -> Tensor
Where:
- input (Tensor) – the first tensor to be multiplied
- other (Tensor) – the second tensor to be multiplied
- out (Tensor, optional) – the output tensor
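The optional out parameter does not appear in the examples later in this article, so here is a minimal sketch of it; the input values are arbitrary and chosen only for illustration:

```python
import torch

# preallocate a result tensor and let torch.matmul() write into it
a = torch.eye(2)
b = torch.tensor([[1.0, 2.0],
                  [3.0, 4.0]])
out = torch.empty(2, 2)
torch.matmul(a, b, out=out)
print(out)  # identical to b, since a is the identity matrix
```

Writing into a preallocated tensor can avoid an extra allocation when the same result buffer is reused many times.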
Alternatives
The alternatives to the torch.matmul() function are listed below:
- torch.Tensor.matmul() – a method that is called on the input tensor object instead of passing it as an argument
- torch.Tensor.mm() – a method that only works for 2D tensors and performs a matrix-matrix product
- torch.mm() – a function that only works for 2D tensors and performs a matrix-matrix product
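For 2D inputs, all of the alternatives above should agree with torch.matmul(). A small sketch (with arbitrary example values) to confirm:

```python
import torch

# two small 2D tensors with deterministic values
a = torch.arange(6, dtype=torch.float32).reshape(2, 3)
b = torch.arange(12, dtype=torch.float32).reshape(3, 4)

r1 = torch.matmul(a, b)   # function form
r2 = a.matmul(b)          # Tensor method form
r3 = torch.mm(a, b)       # 2D-only function
r4 = a.mm(b)              # 2D-only Tensor method
print(torch.equal(r1, r2) and torch.equal(r1, r3) and torch.equal(r1, r4))  # True
```

The 2D-only variants raise an error for batched (3D or higher) inputs, which is when torch.matmul() or torch.bmm() is needed instead.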
Examples
The following examples demonstrate how to use the torch.matmul() function in practice.
Vector x vector
This example shows how to compute the dot product of two 1D tensors using torch.matmul():
import torch
torch.manual_seed(2023)
# create two 1D tensors of size 3
tensor1 = torch.randn(3)
tensor2 = torch.randn(3)
# compute the dot product (scalar) using torch.matmul()
result = torch.matmul(tensor1, tensor2)
# print the result
print(result)
Output:
tensor(-1.3027)
Matrix x matrix
This is perhaps the most common use case for torch.matmul().
import torch
torch.manual_seed(2024)
# create two 2D tensors of size (3, 4) and (4, 5)
tensor1 = torch.randn(3, 4)
tensor2 = torch.randn(4, 5)
# compute the matrix-matrix product (2D tensor) using torch.matmul()
result = torch.matmul(tensor1, tensor2)
# print the result
print(result)
Output:
tensor([[-1.4659, -0.7207, -6.5537, -0.1978, -4.0800],
        [ 1.2412, -0.4786, -1.0352, -1.7091, -0.7333],
        [-0.7588,  0.2746,  1.7423, -0.3800,  0.1791]])
Matrix x vector
This example shows how to compute the matrix-vector product of a 2D tensor and a 1D tensor with the help of torch.matmul():
import torch
torch.manual_seed(2024)
# create a 2D tensor of size (3, 4)
tensor1 = torch.randn(3, 4)
# create a 1D tensor of size 4
tensor2 = torch.randn(4)
# compute the matrix-vector product (1D tensor) using torch.matmul()
result = torch.matmul(tensor1, tensor2)
# print the result
print(result)
Output:
tensor([-2.9603, 1.0947, -1.4140])
Batched matrix x broadcasted vector
This example shows how to compute the batched matrix-vector product of a 3D tensor and a 1D tensor with torch.matmul(). The non-matrix dimensions are broadcasted to match the batch size.
import torch
torch.manual_seed(2024)
# create a 3D tensor of size (10, 3, 4)
tensor1 = torch.randn(10, 3, 4)
# create a 1D tensor of size 4
tensor2 = torch.randn(4)
# compute the batched matrix-vector product (2D tensor) using torch.matmul()
result = torch.matmul(tensor1, tensor2)
# print the result
print(result)
Output:
tensor([[-1.9808,  2.2655,  0.5833],
        [ 0.2252,  0.4640,  0.1878],
        [ 2.9779, -0.3317,  5.9854],
        [ 0.8479,  2.8258,  0.2135],
        [ 1.5425, -0.1028,  1.0848],
        [ 0.2890,  0.2828,  1.4975],
        [-4.0526, -2.9466, -1.2563],
        [-4.8923, -0.9569, -5.5869],
        [ 0.2150,  3.3648,  1.5289],
        [ 3.4914,  3.9322,  1.6943]])
Batched matrix x batched matrix
This example shows how to compute the batched matrix-matrix product of two 3D tensors by making use of torch.matmul(). The non-matrix dimensions are broadcasted if they are not equal.
import torch
torch.manual_seed(2023)
# create two 3D tensors of size (10, 3, 4) and (10, 4, 5)
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(10, 4, 5)
# compute the batched matrix-matrix product (3D tensor) using torch.matmul()
result = torch.matmul(tensor1, tensor2)
# print the result
print(result)
Output:
tensor([[[-1.7069e+00, -1.8534e+00, -2.1287e+00, -4.7003e-01,  1.0697e+00],
         [ 3.8941e+00,  9.4840e-01,  3.7984e+00, -9.8924e-02,  4.7634e-02],
         [-6.7208e-01,  2.0294e+00,  1.6539e+00,  4.7087e+00, -1.0837e+00]],

        [[ 4.5204e-02,  3.8919e+00, -5.8897e-03,  3.9597e+00,  1.6740e+00],
         [-2.6450e+00,  1.1118e+00, -7.4549e-01,  1.3357e+00, -1.5274e+00],
         [ 3.2368e+00, -1.7224e+00, -2.5382e+00,  1.9547e-01,  1.4201e+00]],

        [[-4.4614e-01,  2.1000e+00,  2.1122e-01,  4.0859e+00,  1.4184e+00],
         [ 4.5354e-01,  8.9127e-02,  3.6503e-01,  8.5806e-01, -1.3477e+00],
         [-2.4486e-01, -1.0420e+00, -6.5352e-01, -2.0812e+00,  4.1399e-01]],

        [[-6.1078e-01, -2.3098e+00,  1.0434e+00, -1.9856e+00, -2.2091e+00],
         [-3.2604e-02,  4.9101e-01, -2.5692e-01, -3.6945e-01, -3.3364e+00],
         [ 5.4829e-02, -3.6443e-01, -6.1170e-01, -8.1428e-01, -1.8848e+00]],

        [[-2.6326e+00, -3.3785e+00, -2.4368e+00,  1.2138e+00,  1.1739e+00],
         [ 3.8012e+00, -2.7200e-01,  2.4414e+00, -3.1762e+00, -2.8488e+00],
         [ 2.1198e+00, -1.0791e+00,  4.1354e-01,  7.2069e-01, -5.7256e-01]],

        [[-2.1189e+00, -1.2500e+00, -2.3749e+00,  2.1247e-02, -8.0348e-01],
         [-1.3381e+00, -3.2541e+00, -5.0417e+00, -4.2254e+00, -2.5890e+00],
         [ 1.3309e+00,  5.0646e-01, -1.8184e+00,  1.6543e+00, -1.1076e+00]],

        [[-1.3862e-01,  1.7765e+00,  3.9243e+00,  5.4390e+00,  3.1808e+00],
         [ 1.4334e+00,  3.4047e-01, -2.9210e+00, -3.4534e+00, -2.3460e+00],
         [-3.9575e+00,  4.9891e-02,  2.4687e+00,  1.8572e+00,  3.2704e+00]],

        [[ 7.9100e-01, -6.5746e-01, -1.8795e+00,  4.9998e-01, -6.6892e-01],
         [-3.0543e+00,  1.8478e+00,  4.5600e+00, -2.3450e+00, -1.3333e-01],
         [-5.4343e-01,  7.4177e-02,  3.8904e+00, -2.0984e+00, -2.4714e+00]],

        [[-1.0236e+00, -4.9922e+00,  7.3867e+00, -5.2357e+00, -1.7133e+00],
         [ 5.8022e-01, -3.0840e-01, -7.3925e-02, -5.5694e-02, -2.1630e+00],
         [ 4.7603e-01,  2.7360e+00, -6.6990e+00,  3.8872e+00, -2.0109e+00]],

        [[-1.6855e+00,  6.5613e-02,  5.2002e+00,  2.9535e+00, -1.8489e-01],
         [-2.5715e+00, -3.3931e+00,  3.4194e+00,  8.0750e-01, -2.9943e-01],
         [ 2.2488e-01, -2.7440e+00, -2.9225e+00, -2.6996e+00, -8.2288e-01]]])
Batched matrix x broadcasted matrix
The example below illustrates how to compute the batched matrix-matrix product of a 3D tensor and a 2D tensor by utilizing torch.matmul(). The non-matrix dimensions are broadcasted to match the batch size.
import torch
torch.manual_seed(2023)
# create a 3D tensor of size (10, 3, 4)
tensor1 = torch.randn(10, 3, 4)
# create a 2D tensor of size (4, 5)
tensor2 = torch.randn(4, 5)
# compute the batched matrix-matrix product (3D tensor) using torch.matmul()
result = torch.matmul(tensor1, tensor2)
# print the result
print(result)
Output:
tensor([[[ 3.5080e-02, -2.4184e+00, -1.2028e+00,  1.4250e+00, -1.6410e+00],
         [-1.9470e-01,  2.1141e+00,  5.6419e-01, -7.0405e-01,  3.1504e+00],
         [ 2.7579e+00, -7.6057e-01,  1.4620e+00,  2.5192e+00, -7.7006e-01]],

        [[-4.5815e-01,  8.5896e-01,  8.3757e-01, -2.8229e+00,  2.2722e-01],
         [-4.4352e-01, -1.5412e+00, -8.6720e-01,  4.4711e-01, -2.9628e+00],
         [ 2.0432e+00,  3.9481e+00,  3.8507e+00, -2.5680e+00,  4.0710e+00]],

        [[ 9.5828e-01,  5.6027e-01,  1.7352e+00, -1.5257e+00, -1.4508e+00],
         [-1.4288e+00, -2.6703e+00, -2.4472e+00,  1.4732e+00, -3.9988e+00],
         [ 5.9590e-01,  1.7817e+00,  1.0970e+00, -6.0984e-01,  3.5765e+00]],

        [[-1.7303e+00,  1.8793e+00, -6.9760e-01, -5.9631e-01,  9.1063e-02],
         [-1.1918e+00,  2.8247e+00,  6.6636e-01, -2.5371e+00,  1.6771e+00],
         [-1.7941e+00,  8.0749e-01, -1.0960e+00, -6.4452e-01, -3.6994e-01]],

        [[ 2.1279e+00, -4.7671e+00, -3.7989e-01,  2.4091e+00, -2.8382e+00],
         [-1.3139e+00,  3.8104e+00,  1.6157e+00, -4.3324e+00,  7.1359e-01],
         [-1.7975e+00,  5.6656e-04, -1.3343e+00, -9.2224e-01, -2.5261e-02]],

        [[ 2.6325e+00, -2.0290e+00,  1.2190e+00,  1.0536e+00,  9.9029e-01],
         [ 2.7471e+00,  2.9210e-01,  2.1788e+00,  6.4283e-02,  4.8397e+00],
         [ 8.6415e-01, -2.7562e+00, -8.5525e-01,  2.2541e+00, -1.5677e+00]],

        [[-3.5942e+00,  2.8794e+00, -1.8659e+00, -1.4797e+00,  1.5315e+00],
         [ 4.3266e+00, -1.4798e+00,  2.7324e+00,  1.4158e+00,  2.3671e+00],
         [-3.5717e+00, -3.1569e+00, -4.7173e+00,  1.7454e+00, -3.6640e+00]],

        [[-4.6379e-01,  1.5428e+00,  6.8418e-01, -1.7160e+00,  5.0970e-01],
         [-4.2143e-01, -3.1792e+00, -2.4101e+00,  3.5827e+00, -4.2032e+00],
         [-2.1132e+00, -5.7218e-01, -2.4433e+00,  1.4951e+00, -2.4490e+00]],

        [[ 1.4307e+00, -6.5002e+00, -2.2543e+00,  4.7914e+00, -5.2939e+00],
         [-1.1652e+00, -3.6335e-01, -1.0990e+00,  3.9657e-01, -2.7056e+00],
         [-1.1471e+00,  1.1457e+00, -1.0032e+00,  7.5988e-01,  1.3065e+00]],

        [[-2.2915e-01, -3.8629e+00, -2.3769e+00,  2.5113e+00, -1.3395e+00],
         [ 2.5257e+00, -3.6131e+00, -4.7740e-02,  2.9718e+00,  1.9309e-01],
         [ 2.2912e+00,  3.0198e+00,  3.6501e+00, -2.0257e+00,  2.9431e+00]]])