This concise, practical article is about stacking tensors in PyTorch with the torch.stack(), torch.vstack(), and torch.hstack() functions.
torch.stack()
Syntax & Parameters
torch.stack() is a PyTorch function that joins a sequence of tensors along a new dimension: it inserts a new dimension and concatenates the tensors along it. All input tensors must have the same shape.
Syntax:
torch.stack(tensors, dim=0, *, out=None) -> Tensor
Where:
tensors: a sequence of tensors to stack. They must all have the same shape.
dim: an integer specifying the index of the new dimension. It must be between 0 and the number of dimensions of the input tensors (inclusive).
out: an optional tensor that stores the output. It must have the same shape as the expected output.
The function returns a tensor that is the concatenation of the input tensors along the specified dimension.
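To highlight what "a new dimension" means, here is a minimal comparison with torch.cat(), which joins along an existing dimension instead (both are standard PyTorch functions):

```python
import torch

a = torch.zeros(3, 4)
b = torch.zeros(3, 4)

# torch.stack() inserts a new dimension: two (3, 4) tensors become (2, 3, 4)
stacked = torch.stack([a, b], dim=0)
print(stacked.shape)  # torch.Size([2, 3, 4])

# torch.cat() joins along an existing dimension: two (3, 4) tensors become (6, 4)
concatenated = torch.cat([a, b], dim=0)
print(concatenated.shape)  # torch.Size([6, 4])
```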
Example
Let’s say you have two tensors a and b with shape (3, 4):
a = torch.tensor([[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]])
b = torch.tensor([[13, 14, 15, 16],
[17, 18, 19, 20],
[21, 22, 23, 24]])
You can stack them along the first dimension (dim=0) to get a tensor c with shape (2, 3, 4):
c = torch.stack([a, b], dim=0)
print(c)
# tensor([[[ 1, 2, 3, 4],
# [ 5, 6, 7, 8],
# [ 9, 10, 11, 12]],
#
# [[13, 14, 15, 16],
# [17, 18, 19, 20],
# [21, 22, 23, 24]]])
You can see that the tensors a and b are concatenated along a new dimension at the beginning of the tensor c. The first slice of c along dim=0 is equal to a, and the second slice is equal to b.
You can also stack them along the second dimension (dim=1) to get a tensor d with shape (3, 2, 4):
d = torch.stack([a, b], dim=1)
print(d)
# tensor([[[ 1, 2, 3, 4],
# [13, 14, 15, 16]],
#
# [[ 5, 6, 7, 8],
# [17, 18, 19, 20]],
#
# [[ 9, 10, 11, 12],
# [21, 22, 23, 24]]])
You can see that the tensors a and b are concatenated along a new dimension in the middle of the tensor d. The first column of d along dim=1 is equal to a, and the second column is equal to b.
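You can confirm these slicing relationships directly; here is a small check using the same a and b as above:

```python
import torch

a = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8],
                  [9, 10, 11, 12]])
b = torch.tensor([[13, 14, 15, 16],
                  [17, 18, 19, 20],
                  [21, 22, 23, 24]])

c = torch.stack([a, b], dim=0)  # shape (2, 3, 4)
d = torch.stack([a, b], dim=1)  # shape (3, 2, 4)

# Slices of c along dim=0 recover the inputs
assert torch.equal(c[0], a) and torch.equal(c[1], b)

# Slices of d along dim=1 recover the inputs
assert torch.equal(d[:, 0], a) and torch.equal(d[:, 1], b)
print("all slice checks passed")
```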
torch.vstack()
Syntax & Parameters
The torch.vstack() function is used to stack tensors in sequence vertically (row-wise). This means it concatenates tensors along the first axis, the dimension that represents the rows of a matrix. For example, if you have two tensors A and B, each with shape (3, 4), then torch.vstack((A, B)) will return a tensor with shape (6, 4), where the first three rows come from A and the last three rows come from B.
Syntax:
torch.vstack(tensors, *, out=None) -> Tensor
Where:
tensors: a sequence of tensors to be stacked vertically. The tensors must have the same shape along all dimensions except the first. They can be 1-D or higher-dimensional; 1-D tensors are reshaped to at least 2-D using torch.atleast_2d().
out: an optional pre-allocated output tensor to store the result. It must have the correct shape and dtype to hold the stacked tensors. If out is None, a new tensor is allocated and returned.
Examples
Here is an example of using torch.vstack() with two 1-D tensors:
import torch
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])
z = torch.vstack((x, y))
print(z)
Output:
tensor([[1, 2, 3],
[4, 5, 6]])
You can see that x and y are reshaped to (1, 3) tensors and then stacked vertically to form a (2, 3) tensor.
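As a sketch of the behavior described above, torch.vstack() on 1-D inputs matches promoting each tensor with torch.atleast_2d() and then concatenating along dim 0:

```python
import torch

x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])

# atleast_2d turns each (3,) tensor into shape (1, 3) before row-wise concatenation
manual = torch.cat([torch.atleast_2d(x), torch.atleast_2d(y)], dim=0)
assert torch.equal(torch.vstack((x, y)), manual)
print(manual.shape)  # torch.Size([2, 3])
```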
Another example that demonstrates how the function works with tensors that aren’t 1-D:
import torch
a = torch.tensor([
    [[1, 2, 3], [4, 5, 6]],
    [[7, 8, 9], [10, 11, 12]]
])
print(a.shape)
# torch.Size([2, 2, 3])
b = torch.tensor([
    [[13, 14, 15], [16, 17, 18]],
    [[19, 20, 21], [22, 23, 24]]
])
print(b.shape)
# torch.Size([2, 2, 3])
c = torch.vstack([a, b])
print(c.shape)
# torch.Size([4, 2, 3])
print(c)
# tensor([[[ 1, 2, 3],
# [ 4, 5, 6]],
# [[ 7, 8, 9],
# [10, 11, 12]],
# [[13, 14, 15],
# [16, 17, 18]],
# [[19, 20, 21],
# [22, 23, 24]]])
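For tensors that are already 2-D or higher, torch.vstack() is equivalent to torch.cat() along the first dimension, which you can verify with tensors of the same shape as above:

```python
import torch

# Two (2, 2, 3) tensors, built compactly with arange/reshape
a = torch.arange(12).reshape(2, 2, 3)
b = torch.arange(12, 24).reshape(2, 2, 3)

# For >=2-D inputs, vstack is plain concatenation along dim 0
v = torch.vstack([a, b])
assert torch.equal(v, torch.cat([a, b], dim=0))
print(v.shape)  # torch.Size([4, 2, 3])
```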
torch.hstack()
Syntax & Parameters
torch.hstack() is a function that stacks tensors in sequence horizontally (column-wise). It is equivalent to concatenation along the first axis for 1-D tensors, and along the second axis for all other tensors.
Syntax:
torch.hstack(tensors, *, out=None) -> Tensor
Parameters:
tensors: a sequence of tensors to concatenate.
out: an optional output tensor to store the result.
Examples
Using torch.hstack() with 1-D tensors:
import torch
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])
z = torch.hstack((x, y))
print(z)
# tensor([1, 2, 3, 4, 5, 6])
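As stated above, for 1-D tensors torch.hstack() matches concatenation along the first (and only) axis, and no new dimension is added:

```python
import torch

x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])

# For 1-D inputs, hstack concatenates along dim 0; the result is still 1-D
z = torch.hstack((x, y))
assert torch.equal(z, torch.cat((x, y), dim=0))
print(z)  # tensor([1, 2, 3, 4, 5, 6])
```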
Another example with 3-D tensors:
import torch
a = torch.tensor([
    [[1, 2, 3], [4, 5, 6]],
    [[7, 8, 9], [10, 11, 12]]
])
print(a.shape)
# torch.Size([2, 2, 3])
b = torch.tensor([
    [[13, 14, 15], [16, 17, 18]],
    [[19, 20, 21], [22, 23, 24]]
])
print(b.shape)
# torch.Size([2, 2, 3])
c = torch.hstack((a, b))
print(c.shape)
# torch.Size([2, 4, 3])
print(c)
# tensor([[[ 1, 2, 3],
# [ 4, 5, 6],
# [13, 14, 15],
# [16, 17, 18]],
# [[ 7, 8, 9],
# [10, 11, 12],
# [19, 20, 21],
# [22, 23, 24]]])
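For tensors with two or more dimensions, torch.hstack() concatenates along the second axis (dim=1), which the shapes above reflect. A quick check with tensors of the same shape:

```python
import torch

# Two (2, 2, 3) tensors, built compactly with arange/reshape
a = torch.arange(12).reshape(2, 2, 3)
b = torch.arange(12, 24).reshape(2, 2, 3)

# For >=2-D inputs, hstack is concatenation along dim 1
h = torch.hstack((a, b))
assert torch.equal(h, torch.cat((a, b), dim=1))
print(h.shape)  # torch.Size([2, 4, 3])
```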