Squeezing and unsqueezing are two operations that change the shape of a tensor by removing or adding dimensions of size 1. This concise, straight-to-the-point article covers squeezing and unsqueezing tensors in PyTorch with the torch.squeeze() and torch.unsqueeze() functions, respectively.
Squeezing tensors
The torch.squeeze() function in PyTorch removes all dimensions of size 1 from the input tensor. For example, if you have a tensor of shape (A x 1 x B x C x 1 x D), then the squeeze function will return a tensor of shape (A x B x C x D). The returned tensor shares its underlying data with the input, so squeezing does not reduce memory usage, but it can simplify later tensor operations that expect fewer dimensions. You can also specify a dim argument to squeeze only a specific dimension of the input tensor.
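Because squeezing only changes how the data is indexed, the result can be checked to share storage with its input. The following is a minimal sketch of that check; the tensor shape and the value 7.0 are arbitrary choices for illustration.

```python
import torch

x = torch.zeros(2, 1, 3)
y = torch.squeeze(x)  # shape (2, 3)

# Both tensors start at the same memory address: no data was copied
shares_memory = y.data_ptr() == x.data_ptr()

# Writing through the squeezed tensor is visible in the original
y[0, 0] = 7.0

print(shares_memory)      # True
print(x[0, 0, 0].item())  # 7.0
```

If an independent copy is needed, calling .clone() on the result is one way to get it.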
Syntax:
torch.squeeze(input, dim=None) -> Tensor
Where:
input: the input tensor that you want to squeeze.
dim: an optional integer that specifies which dimension of the input tensor to squeeze. If dim is not given, then all the dimensions of size 1 will be squeezed. If dim is given, then only the dimension at that position will be squeezed, and only if it has size 1; otherwise the tensor is returned unchanged. PyTorch 2.0 and later also accept a tuple of dimensions.
Here is a simple code snippet that demonstrates how to create a tensor with some dimensions of size 1 and then squeeze it using the torch.squeeze() function:
# Import PyTorch library
import torch
# Create a tensor of shape (2 x 1 x 3 x 4 x 1 x 5)
x = torch.rand(2, 1, 3, 4, 1, 5)
print("The shape of x is:", x.shape)
# Squeeze the tensor to remove all the dimensions of size 1
y = torch.squeeze(x)
print("The shape of y is:", y.shape)
# Squeeze the tensor only in the second dimension
z = torch.squeeze(x, dim=1)
print("The shape of z is:", z.shape)
Output:
The shape of x is: torch.Size([2, 1, 3, 4, 1, 5])
The shape of y is: torch.Size([2, 3, 4, 5])
The shape of z is: torch.Size([2, 3, 4, 1, 5])
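Passing dim matters most when one of the size-1 dimensions is meaningful, such as a batch that happens to contain a single item. The sketch below assumes a hypothetical batch holding one grayscale image in (batch, channel, height, width) layout; the layout and sizes are illustrative, not something torch.squeeze() requires.

```python
import torch

# Hypothetical batch with a single grayscale image:
# (batch, channel, height, width)
batch = torch.zeros(1, 1, 28, 28)

# Without dim, every size-1 dimension goes, including the batch dimension
all_squeezed = torch.squeeze(batch)

# With dim=1, only the channel dimension is removed
channel_squeezed = torch.squeeze(batch, dim=1)

# Negative dims count from the end; dim=-3 is the channel dimension here
negative_dim = torch.squeeze(batch, dim=-3)

# A dimension whose size is not 1 is left alone
unchanged = torch.squeeze(batch, dim=2)

print(all_squeezed.shape)      # torch.Size([28, 28])
print(channel_squeezed.shape)  # torch.Size([1, 28, 28])
print(negative_dim.shape)      # torch.Size([1, 28, 28])
print(unchanged.shape)         # torch.Size([1, 1, 28, 28])
```

Naming the dimension explicitly keeps the output shape predictable even when the batch size changes.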
Unsqueezing tensors
The torch.unsqueeze() function is the inverse of the torch.squeeze() function. It inserts a new dimension of size 1 at the specified position of the input tensor. This function is useful when an operation or layer expects input with a specific number of dimensions, such as a convolutional layer that expects a leading batch dimension. For example, if you have a tensor of shape (A x B x C), then the unsqueeze function can add a new dimension at any position, such as (1 x A x B x C), (A x 1 x B x C), (A x B x 1 x C), or (A x B x C x 1).
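As a rough illustration of the convolutional-layer case, the sketch below adds a batch dimension to a single image before passing it to nn.Conv2d. The image size, channel counts, and kernel size are assumptions chosen for the example, and some PyTorch versions may also accept unbatched input directly.

```python
import torch
from torch import nn

# A single RGB image without a batch dimension: (channels, height, width)
image = torch.rand(3, 32, 32)

# Add a leading batch dimension: (1, channels, height, width)
batched = torch.unsqueeze(image, dim=0)

# Illustrative layer; these hyperparameters are arbitrary
conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)
out = conv(batched)

print(batched.shape)  # torch.Size([1, 3, 32, 32])
print(out.shape)      # torch.Size([1, 8, 30, 30])
```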
Syntax:
torch.unsqueeze(input, dim) -> Tensor
Where:
input: the input tensor that you want to unsqueeze.
dim: an integer that specifies the position at which to insert the new dimension of size 1. A dim value within the range [-input.dim() - 1, input.dim() + 1) can be used; the upper bound is exclusive. A negative dim corresponds to unsqueeze() applied at dim = dim + input.dim() + 1.
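The mapping from negative to non-negative dim values can be checked directly. The sketch below walks every accepted dim for a 2-dimensional tensor and then tries one value past the upper end of the range; the variable names are illustrative.

```python
import torch

x = torch.zeros(3, 4)  # x.dim() == 2, so valid dims run from -3 to 2
n = x.dim()

shapes = {}
for dim in range(-n - 1, n + 1):
    # Map a negative dim to its non-negative equivalent
    equivalent = dim if dim >= 0 else dim + n + 1
    shape = tuple(torch.unsqueeze(x, dim).shape)
    assert shape == tuple(torch.unsqueeze(x, equivalent).shape)
    shapes[dim] = shape
    print(dim, "->", equivalent, shape)

# One past the upper end of the range is rejected
try:
    torch.unsqueeze(x, n + 1)
    out_of_range_raised = False
except IndexError:
    out_of_range_raised = True

print(out_of_range_raised)  # True
```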
Example:
# Import PyTorch library
import torch
# Create a tensor of shape (3 x 4)
x = torch.rand(3, 4)
print("The shape of x is:", x.shape)
# Unsqueeze the tensor in the first dimension
y = torch.unsqueeze(x, dim=0)
print("The shape of y is:", y.shape)
# Unsqueeze the tensor in the second dimension
z = torch.unsqueeze(x, dim=1)
print("The shape of z is:", z.shape)
# Unsqueeze the tensor in the last dimension
w = torch.unsqueeze(x, dim=-1)
print("The shape of w is:", w.shape)
Output:
The shape of x is: torch.Size([3, 4])
The shape of y is: torch.Size([1, 3, 4])
The shape of z is: torch.Size([3, 1, 4])
The shape of w is: torch.Size([3, 4, 1])
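Since the two functions undo each other, a quick round trip is a useful sanity check. The sketch below also shows two equivalent spellings that are common in PyTorch code, indexing with None and the Tensor.unsqueeze() method; the tensor values are arbitrary.

```python
import torch

x = torch.arange(12.0).reshape(3, 4)

# Unsqueeze, then squeeze the same position: the original tensor comes back
y = torch.unsqueeze(x, dim=1)       # (3, 1, 4)
restored = torch.squeeze(y, dim=1)  # (3, 4)
print(torch.equal(restored, x))     # True

# Indexing with None inserts a size-1 dimension at that position
via_none = x[:, None, :]
print(via_none.shape)               # torch.Size([3, 1, 4])

# The method form is equivalent to the function form
via_method = x.unsqueeze(1)
print(torch.equal(via_method, y))   # True
```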