PyTorch: Squeezing and Unsqueezing Tensors

Updated: July 22, 2023 By: Frienzied Flame

Squeezing and unsqueezing are two operations that change the shape of a tensor by removing or adding dimensions of size 1, respectively. This concise, straight-to-the-point article shows how to do both in PyTorch with the torch.squeeze() and torch.unsqueeze() functions.

Squeezing tensors

The torch.squeeze() function removes every dimension of size 1 from the input tensor. For example, if the input has shape (A x 1 x B x C x 1 x D), torch.squeeze() returns a tensor of shape (A x B x C x D). No data is copied: the returned tensor shares its storage with the input, so squeezing does not reduce memory usage, but it does simplify shapes for later tensor operations. You can also pass a dim argument to squeeze only a specific dimension of the input tensor.
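A quick way to check that squeeze() only changes the shape metadata is to compare where the input and output keep their data. This is a minimal sketch; the tensor sizes are arbitrary, and Tensor.data_ptr() is used to read the address of the first element:

```python
import torch

# A tensor with two singleton dimensions
x = torch.zeros(2, 1, 3, 1)
y = torch.squeeze(x)

print(y.shape)                       # torch.Size([2, 3])

# Both tensors start at the same memory address: no data was copied
print(y.data_ptr() == x.data_ptr())  # True

# Because the storage is shared, writing through y is visible in x
y[0, 0] = 42.0
print(x[0, 0, 0, 0].item())          # 42.0
```

If you need an independent tensor, call .clone() on the squeezed result.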

Syntax:

torch.squeeze(input, dim=None) -> Tensor

Where:

  • input: the input tensor that you want to squeeze.
  • dim: an optional integer (or, in PyTorch 2.0 and later, a tuple of integers) that specifies which dimension of the input tensor to squeeze. If dim is not given, all dimensions of size 1 are removed. If dim is given, only that dimension is removed, and only if it has size 1; otherwise the tensor is returned with its shape unchanged.
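The sketch below exercises the dim argument on a tensor of shape (2 x 1 x 3 x 1). The tuple form in the last step assumes PyTorch 2.0 or later; older releases accept only a single integer:

```python
import torch

x = torch.rand(2, 1, 3, 1)

# dim points at a size-1 dimension: it is removed
a = torch.squeeze(x, dim=1)
print(a.shape)  # torch.Size([2, 3, 1])

# dim points at a dimension whose size is not 1: the shape is unchanged
b = torch.squeeze(x, dim=0)
print(b.shape)  # torch.Size([2, 1, 3, 1])

# Negative indices count from the end
c = torch.squeeze(x, dim=-1)
print(c.shape)  # torch.Size([2, 1, 3])

# PyTorch 2.0+: squeeze several size-1 dimensions at once
d = torch.squeeze(x, dim=(1, 3))
print(d.shape)  # torch.Size([2, 3])
```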

Here is a simple code snippet that demonstrates how to create a tensor with some dimensions of size 1 and then squeeze it using the torch.squeeze() function:

# Import PyTorch library
import torch

# Create a tensor of shape (2 x 1 x 3 x 4 x 1 x 5)
x = torch.rand(2, 1, 3, 4, 1, 5)
print("The shape of x is:", x.shape)

# Squeeze the tensor to remove all the dimensions of size 1
y = torch.squeeze(x)
print("The shape of y is:", y.shape)

# Squeeze the tensor only in the second dimension
z = torch.squeeze(x, dim=1)
print("The shape of z is:", z.shape)

Output:

The shape of x is: torch.Size([2, 1, 3, 4, 1, 5])
The shape of y is: torch.Size([2, 3, 4, 5])
The shape of z is: torch.Size([2, 3, 4, 1, 5])
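One practical consequence of squeezing every size-1 dimension is that a batch containing a single sample loses its batch dimension too. The sketch below uses a hypothetical helper, flatten_predictions(), to show the safer habit of naming the dimension to squeeze; the (N x 1) prediction shape is an assumption for illustration:

```python
import torch

def flatten_predictions(preds: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: turn (N, 1) predictions into shape (N,)
    # by squeezing only the trailing dimension, never the batch dimension
    return preds.squeeze(dim=-1)

batch = torch.rand(4, 1)
single = torch.rand(1, 1)

print(flatten_predictions(batch).shape)   # torch.Size([4])
print(flatten_predictions(single).shape)  # torch.Size([1])

# Without dim, a batch of one collapses to a 0-dimensional tensor
print(torch.squeeze(single).shape)        # torch.Size([])
```

Passing dim keeps the output shape predictable regardless of the batch size.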

Unsqueezing tensors

The torch.unsqueeze() function is the counterpart of torch.squeeze(): it inserts a new dimension of size 1 at the specified position of the input tensor, and the result shares its data with the input. This is useful when an operation or layer expects a particular number of dimensions, such as adding a batch dimension before passing a single image to a convolutional layer, or lining up dimensions for broadcasting. For a tensor of shape (A x B x C), the new dimension can go in any of four positions: (1 x A x B x C), (A x 1 x B x C), (A x B x 1 x C), or (A x B x C x 1).
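As an illustration of both use cases, here is a minimal sketch that adds a batch dimension to a single image before an nn.Conv2d layer and uses unsqueeze() to set up broadcasting. The image size, channel counts, and kernel size are arbitrary choices; (N x C x H x W) is the batched input layout documented for nn.Conv2d:

```python
import torch
import torch.nn as nn

# A single RGB image: (channels, height, width)
image = torch.rand(3, 32, 32)

# Add a leading batch dimension: (1, 3, 32, 32)
batch = image.unsqueeze(0)

# No padding, stride 1: a 3x3 kernel shrinks 32x32 to 30x30
conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)
out = conv(batch)
print(out.shape)  # torch.Size([1, 8, 30, 30])

# Broadcasting: (3,) -> (3, 1), multiplied by (4,) gives (3, 4)
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([10.0, 20.0, 30.0, 40.0])
outer = a.unsqueeze(1) * b  # outer[i, j] == a[i] * b[j]
print(outer.shape)  # torch.Size([3, 4])
```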

Syntax:

torch.unsqueeze(input, dim) -> Tensor

Where:

  • input: the input tensor that you want to unsqueeze.
  • dim: an integer that specifies the position at which to insert the new dimension of size 1. Any dim value within the half-open range [-input.dim() - 1, input.dim() + 1) can be used. A negative dim corresponds to unsqueeze() applied at dim = dim + input.dim() + 1.
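To make the valid range concrete, the sketch below tries every allowed dim for a 2-D tensor and then one value just outside the range. The exact error message may vary between PyTorch versions, so only the exception type is relied on:

```python
import torch

x = torch.rand(3, 4)  # x.dim() == 2, so valid dims are -3 through 2

shapes = {}
for d in range(-x.dim() - 1, x.dim() + 1):
    # A negative d behaves like d + x.dim() + 1
    shapes[d] = tuple(torch.unsqueeze(x, d).shape)
    print(d, shapes[d])

# dim == x.dim() + 1 is outside the range and raises an IndexError
raised = False
try:
    torch.unsqueeze(x, x.dim() + 1)
except IndexError as err:
    raised = True
    print("IndexError:", err)
```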

Example:

# Import PyTorch library
import torch

# Create a tensor of shape (3 x 4)
x = torch.rand(3, 4)
print("The shape of x is:", x.shape)

# Unsqueeze the tensor in the first dimension
y = torch.unsqueeze(x, dim=0)
print("The shape of y is:", y.shape)

# Unsqueeze the tensor in the second dimension
z = torch.unsqueeze(x, dim=1)
print("The shape of z is:", z.shape)

# Unsqueeze the tensor in the last dimension
w = torch.unsqueeze(x, dim=-1)
print("The shape of w is:", w.shape)

Output:

The shape of x is: torch.Size([3, 4])
The shape of y is: torch.Size([1, 3, 4])
The shape of z is: torch.Size([3, 1, 4])
The shape of w is: torch.Size([3, 4, 1])
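torch.unsqueeze() is not the only way to insert a size-1 dimension. The sketch below, assuming a 2-D input, compares the Tensor method form, indexing with None, and an explicit reshape, checks that squeeze() at the same position restores the original tensor, and shows the in-place unsqueeze_() variant:

```python
import torch

x = torch.rand(3, 4)

# Method form: same as torch.unsqueeze(x, 1)
y1 = x.unsqueeze(1)

# Indexing with None inserts a size-1 dimension at that position
y2 = x[:, None, :]

# Explicit reshape to the target shape
y3 = x.reshape(3, 1, 4)

print(y1.shape, torch.equal(y1, y2), torch.equal(y1, y3))
# torch.Size([3, 1, 4]) True True

# Squeezing the same position undoes the unsqueeze
print(torch.equal(y1.squeeze(1), x))  # True

# In-place variant: changes the shape of z itself
z = torch.rand(3, 4)
z.unsqueeze_(0)
print(z.shape)  # torch.Size([1, 3, 4])
```

The method and indexing forms are often more compact inside longer expressions, while torch.unsqueeze() makes the position explicit.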