torch.min() and torch.max()
In PyTorch, you can use the built-in functions torch.min() and torch.max() to find the minimum and maximum values of a whole tensor or along a given dimension. Called on a tensor alone, each returns a single value (a 0-dim tensor). Called with a dim argument, each returns a named tuple of values and indices.
Example:
import torch
# set the seed for generating random numbers
torch.manual_seed(2023)
# create a random tensor of shape (3, 4)
a = torch.randn(3, 4)
print(a)
# tensor([[-1.2075,  0.5493, -0.3856,  0.6910],
#         [-0.7424,  0.1570,  0.0721,  1.1055],
#         [ 0.2218, -0.0794, -1.0846, -1.5421]])
# find the minimum value of the whole tensor
min_a = torch.min(a)
print(min_a)
# tensor(-1.5421)
# find the maximum value of the whole tensor
max_a = torch.max(a)
print(max_a)
# tensor(1.1055)
# find the minimum value of each row
min_a_row = torch.min(a, dim=1)
print(min_a_row)
# torch.return_types.min(
# values=tensor([-1.2075, -0.7424, -1.5421]),
# indices=tensor([0, 0, 3])
# )
# find the maximum value of each row
max_a_row = torch.max(a, dim=1)
print(max_a_row)
# torch.return_types.max(
# values=tensor([0.6910, 1.1055, 0.2218]),
# indices=tensor([3, 3, 0])
# )
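As a small follow-up sketch, the snippet below hard-codes the tensor printed above (so it does not depend on the random seed) and shows three things the example does not: unpacking the named tuple into values and indices, reducing along dim=0 to get per-column results, and passing keepdim=True so the result keeps a size-1 dimension that broadcasts back against the original tensor. The variable names are illustrative only.

```python
import torch

# Hard-coded copy of the 3x4 tensor printed above, so the results
# do not depend on the random seed.
a = torch.tensor([[-1.2075, 0.5493, -0.3856, 0.6910],
                  [-0.7424, 0.1570, 0.0721, 1.1055],
                  [0.2218, -0.0794, -1.0846, -1.5421]])

# With dim, the result is a named tuple that unpacks into (values, indices).
row_min_values, row_min_indices = torch.min(a, dim=1)
print(row_min_values)   # tensor([-1.2075, -0.7424, -1.5421])
print(row_min_indices)  # tensor([0, 0, 3])

# Reducing over dim=0 gives one result per column instead of per row.
col_max = torch.max(a, dim=0)
print(col_max.values)   # tensor([0.2218, 0.5493, 0.0721, 1.1055])
print(col_max.indices)  # tensor([2, 0, 1, 1])

# keepdim=True keeps the reduced dimension with size 1.
row_max = torch.max(a, dim=1, keepdim=True).values
print(row_max.shape)    # torch.Size([3, 1])

# The (3, 1) result broadcasts against the (3, 4) tensor, so every
# row can be shifted by its own maximum.
shifted = a - row_max
print(shifted.max(dim=1).values)  # tensor([0., 0., 0.])

# Without dim, the result is a 0-dim tensor; .item() gives a Python float.
overall_min = torch.min(a).item()
print(overall_min)
```

Shifting each row by its own maximum works here without any reshaping only because keepdim=True preserved the shape (3, 1); with the default keepdim=False the result would have shape (3,) and would not line up with the rows.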
torch.argmin() and torch.argmax()
PyTorch also offers torch.argmin() and torch.argmax() to get the indices of the minimum and maximum values of a tensor, respectively. They are similar to torch.min() and torch.max(), but they return only the indices, not the values. They can also operate on the whole tensor or along a specific dimension. Without dim, they return a single index into the flattened tensor (so for a 3x4 tensor, index 11 means row 2, column 3); with dim, they return a tensor of indices.
A code example is worth more than a thousand words:
import torch
# set the seed for generating random numbers
torch.manual_seed(2023)
# create a random tensor of shape (3, 4)
a = torch.randn(3, 4)
print(a)
# tensor([[-1.2075,  0.5493, -0.3856,  0.6910],
#         [-0.7424,  0.1570,  0.0721,  1.1055],
#         [ 0.2218, -0.0794, -1.0846, -1.5421]])
# find the index of the minimum value of the whole tensor
argmin_a = torch.argmin(a)
print(argmin_a)
# tensor(11)
# find the index of the maximum value of the whole tensor
argmax_a = torch.argmax(a)
print(argmax_a)
# tensor(7)
# find the index of the minimum value of each row
argmin_a_row = torch.argmin(a, dim=1)
print(argmin_a_row)
# tensor([0, 0, 3])
# find the index of the maximum value of each row
argmax_a_row = torch.argmax(a, dim=1)
print(argmax_a_row)
# tensor([3, 3, 0])
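To make the flattened indexing concrete, here is a small sketch (again with the tensor hard-coded from the printout above) that converts the whole-tensor result of torch.argmin() back to row and column coordinates with Python's divmod, and checks that torch.argmax(a, dim=1) matches the indices field returned by torch.max(a, dim=1). Using divmod is just one convenient choice for a 2-D tensor, not the only way to do the conversion.

```python
import torch

# Same 3x4 tensor as above, hard-coded for reproducibility.
a = torch.tensor([[-1.2075, 0.5493, -0.3856, 0.6910],
                  [-0.7424, 0.1570, 0.0721, 1.1055],
                  [0.2218, -0.0794, -1.0846, -1.5421]])

# Without dim, argmin returns an index into the flattened tensor.
flat_idx = torch.argmin(a)
print(flat_idx)  # tensor(11)

# For a 2-D tensor, divide by the number of columns to recover (row, col).
n_cols = a.shape[1]
row, col = divmod(flat_idx.item(), n_cols)
print(row, col)                # 2 3
print(a[row, col] == a.min())  # tensor(True)

# Along a dimension, argmax gives the same indices as torch.max(...).indices.
same = torch.equal(torch.argmax(a, dim=1), torch.max(a, dim=1).indices)
print(same)  # True
```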