This pithy, straightforward article will walk you through three different ways to select elements from a tensor in PyTorch. Without any further ado, let’s get started!
Indexing & Slicing
You can use square brackets [ ] to index a tensor by specifying the position of the elements you want to select. The example below selects the element at row 1 and column 2 from a tensor of shape (3, 4):
import torch
# set a seed for reproducibility
torch.manual_seed(2023)
# create a random tensor of shape (3, 4)
x = torch.rand(3, 4)
print(x)
# tensor([[0.4290, 0.7201, 0.9481, 0.4797],
#         [0.5414, 0.9906, 0.4086, 0.2183],
#         [0.1834, 0.2852, 0.7813, 0.1048]])
# select the element at row 1 and column 2 (note that indexing starts at 0)
e = x[1, 2]
print(e)
# tensor(0.4086)
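As the tensor(0.4086) output suggests, indexing every dimension with an integer gives you a 0-dimensional tensor rather than a plain Python number. The minimal sketch below shows how to confirm this and how to extract the value with .item(). It uses torch.arange instead of random numbers, an illustrative choice so the results are predictable.

```python
import torch

# A small deterministic tensor holding the values 0..11 in a (3, 4) layout
x = torch.arange(12, dtype=torch.float32).reshape(3, 4)

# Integer indexing on every dimension returns a 0-dimensional tensor
e = x[1, 2]
print(e.shape)   # torch.Size([])

# .item() converts a one-element tensor into a plain Python number
value = e.item()
print(value)     # 6.0
```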
You can also use slicing with the colon : operator to select a range of elements along a dimension. The stop index is exclusive, so the example below selects rows 0 to 1 and columns 1 to 2:
import torch
# set a seed for reproducibility
torch.manual_seed(2023)
# create a random tensor of shape (3, 4)
x = torch.rand(3, 4)
print(x)
# tensor([[0.4290, 0.7201, 0.9481, 0.4797],
#         [0.5414, 0.9906, 0.4086, 0.2183],
#         [0.1834, 0.2852, 0.7813, 0.1048]])
# select the elements from rows 0 to 1 and columns 1 to 2
selected_elements = x[0:2, 1:3]
print(selected_elements)
# tensor([[0.7201, 0.9481],
#         [0.9906, 0.4086]])
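To make the exclusive stop index easier to see, here is a small sketch on a deterministic torch.arange tensor (an illustrative substitute for the random data above). It also shows that an omitted bound means "from the start" or "to the end", and that a basic slice is a view of x, so writing to the slice changes x too.

```python
import torch

x = torch.arange(12).reshape(3, 4)
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11]])

# The stop index is exclusive: 0:2 covers rows 0 and 1, 1:3 covers columns 1 and 2
block = x[0:2, 1:3]
print(block)
# tensor([[1, 2],
#         [5, 6]])

# Omitting a bound means "from the start" or "to the end"
last_two_cols = x[:, 2:]
print(last_two_cols.shape)  # torch.Size([3, 2])

# Basic slicing returns a view, so writing through the slice also changes x
block[0, 0] = 100
print(x[0, 1])              # tensor(100)
```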
You can also use negative indices to count from the end of a dimension. For example, x[-1, -2] selects the element in the last row and the second-to-last column. You can also pass a list or a tensor of indices to select multiple elements along a dimension. For instance, x[[0, 2], :] selects the first and third rows of x.
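A quick sketch of both ideas on a deterministic torch.arange tensor (chosen for illustration instead of random data). Note that indexing with a list or tensor of indices produces a copy, unlike the basic slicing shown earlier.

```python
import torch

x = torch.arange(12).reshape(3, 4)

# Negative indices count from the end: -1 is the last row, -2 the second-to-last column
print(x[-1, -2])      # tensor(10)

# A list of indices selects several rows at once (the result is a copy, not a view)
rows = x[[0, 2], :]
print(rows)
# tensor([[ 0,  1,  2,  3],
#         [ 8,  9, 10, 11]])

# An index tensor works the same way, and along any dimension
cols = x[:, torch.tensor([1, 3])]
print(cols.shape)     # torch.Size([3, 2])
```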
Using the torch.select() function
You can use the torch.select() function to slice a tensor along one dimension at a given index. It returns a view of the input with that dimension removed, so the result has one fewer dimension than the input. The syntax of this function is:
torch.select(input, dim, index) -> Tensor
Where:
input: the input tensor that you want to select from.
dim: the dimension to slice along.
index: the position along dim that you want to select.
For example, if you have a tensor x of shape (2, 3, 4), you can use torch.select(x, 1, 2) to select the third row of each matrix in x and return a new tensor of shape (2, 4):
import torch
# set a seed for reproducibility
torch.manual_seed(2023)
# create a random tensor of shape (2, 3, 4)
x = torch.rand(2, 3, 4)
print(x)
# tensor([[[0.4290, 0.7201, 0.9481, 0.4797],
#          [0.5414, 0.9906, 0.4086, 0.2183],
#          [0.1834, 0.2852, 0.7813, 0.1048]],
#
#         [[0.6550, 0.8375, 0.1823, 0.5239],
#          [0.2432, 0.9644, 0.5034, 0.0320],
#          [0.8316, 0.3807, 0.3539, 0.2114]]])
# Select the third row of each matrix in x
result = torch.select(x, 1, 2)
print(result)
# tensor([[0.1834, 0.2852, 0.7813, 0.1048],
#         [0.8316, 0.3807, 0.3539, 0.2114]])
The Tensor.select() method is equivalent to the torch.select() function, but you call it directly on a Tensor object.
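The sketch below checks that equivalence on a deterministic (2, 3, 4) tensor built with torch.arange (an illustrative choice). It also compares both forms with plain integer indexing, x[:, 2], and shows that the result is a view: writing to it modifies x.

```python
import torch

x = torch.arange(24).reshape(2, 3, 4)

# Function form and method form of the same operation
a = torch.select(x, 1, 2)
b = x.select(1, 2)

# Both match plain integer indexing on dimension 1
c = x[:, 2]

print(a.shape)                               # torch.Size([2, 4])
print(torch.equal(a, b), torch.equal(a, c))  # True True

# select() returns a view, so writing through it changes x as well
a[0, 0] = -1
print(x[0, 2, 0])                            # tensor(-1)
```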
Using the torch.index_select() function
You can use the torch.index_select() function (or the Tensor.index_select() method) to select multiple entries along one dimension of a tensor, such as several rows or several columns. It returns a new tensor with the same number of dimensions as the input tensor.
Syntax:
torch.index_select(input, dim, index) -> Tensor
Parameters explained:
input: the input tensor that you want to select from.
dim: the dimension along which to select.
index: a 1-D integer tensor (int64 or int32) containing the indices to select along dim.
Suppose you have a tensor x of shape (3, 4). You can then use torch.index_select(x, 0, torch.tensor([0, 2])) to select the first and third rows of x and return a new tensor of shape (2, 4) like this:
import torch
# set a seed for reproducibility
torch.manual_seed(2023)
# create a random tensor of shape (3, 4)
x = torch.rand(3, 4)
print(x)
# tensor([[0.4290, 0.7201, 0.9481, 0.4797],
#         [0.5414, 0.9906, 0.4086, 0.2183],
#         [0.1834, 0.2852, 0.7813, 0.1048]])
# Select the first and third rows of the tensor
result = torch.index_select(x, 0, torch.tensor([0, 2]))
print(result)
# tensor([[0.4290, 0.7201, 0.9481, 0.4797],
#         [0.1834, 0.2852, 0.7813, 0.1048]])
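Here is a slightly broader sketch, again on a deterministic torch.arange tensor chosen for illustration. It selects along dim 1 with the method form, compares the results with indexing by the same index tensor, and shows that indices may repeat. Unlike torch.select(), index_select() returns a new tensor rather than a view, so modifying the result leaves x untouched.

```python
import torch

x = torch.arange(12).reshape(3, 4)

# The index must be a 1-D integer tensor (int64 here)
idx = torch.tensor([0, 2])

rows = torch.index_select(x, 0, idx)   # first and third rows, shape (3 -> 2, 4)
cols = x.index_select(1, idx)          # first and third columns, shape (3, 2)

# Same results as indexing with the index tensor
print(torch.equal(rows, x[idx]))       # True
print(torch.equal(cols, x[:, idx]))    # True

# Indices may repeat and appear in any order
print(x.index_select(0, torch.tensor([2, 2, 0])).shape)  # torch.Size([3, 4])

# The result is a new tensor, so modifying it does not change x
rows[0, 0] = 100
print(x[0, 0])                         # tensor(0)
```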
This tutorial ends here. Happy coding & have a nice day!