Exact equality comparison
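Exact equality comparison means checking whether two PyTorch tensors have the same shape and exactly the same value at every position. The result is True only if they match exactly, and False otherwise. For this kind of comparison, we can use the built-in function torch.equal(), which returns a single Python bool (it returns False rather than raising an error when the shapes differ). The documentation for recent PyTorch releases notes that torch.equal() does not differentiate between dtypes. If the dtype must match too, also compare tensor_one.dtype == tensor_two.dtype.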
Example:
import torch
# Create some tensors
x = torch.tensor([
[1, 2],
[3, 4],
])
y = torch.tensor([
[1, 2],
[3, 4],
])
z = torch.tensor([
[1, 2, 3],
[4, 5, 6],
])
# Compare x and y
print(torch.equal(x, y))
# Compare x and z
print(torch.equal(x, z))
Output:
True
False
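When the dtype must match as well as the shape and values, one option is a small wrapper that checks the dtype explicitly before calling torch.equal(). The helper name strictly_equal is hypothetical; it is a sketch, not part of PyTorch's API.

```python
import torch

def strictly_equal(t1: torch.Tensor, t2: torch.Tensor) -> bool:
    # Hypothetical helper: compare dtypes explicitly first (this short-circuits),
    # then let torch.equal() check the shape and every element.
    return t1.dtype == t2.dtype and torch.equal(t1, t2)

ints = torch.tensor([1, 2, 3])          # int64
floats = torch.tensor([1.0, 2.0, 3.0])  # float32, same numeric values

print(strictly_equal(ints, ints.clone()))  # True
print(strictly_equal(ints, floats))        # False: int64 vs float32
```

Checking the dtype first also means torch.equal() is never asked to compare mismatched dtypes. This avoids depending on how a particular PyTorch version handles that case.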
Shape and dtype comparison
Shape and dtype comparison means checking whether two PyTorch tensors have the same shape and dtype, regardless of their values. Use tensor_one.shape == tensor_two.shape and tensor_one.dtype == tensor_two.dtype; each expression returns a Python bool.
Example:
import torch
a = torch.tensor([[1., 2.], [3., 4.]])  # float32, shape (2, 2)
b = torch.tensor([[1., 2.], [3., 4.]])  # float32, shape (2, 2)
c = torch.tensor([[1., 2.]])            # float32, shape (1, 2)
d = torch.tensor([[1, 2], [3, 4]])      # int64, shape (2, 2)
print(a.shape == b.shape) # True
print(a.shape == c.shape) # False
print(a.dtype == b.dtype) # True
print(a.dtype == d.dtype) # False
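The two checks can be folded into a single function. The sketch below uses a hypothetical helper, same_shape_and_dtype. It also shows that tensor.shape (a torch.Size) compares equal to a plain tuple. Note that torch.zeros() yields float32 here only because that is PyTorch's default dtype unless you change it.

```python
import torch

def same_shape_and_dtype(t1: torch.Tensor, t2: torch.Tensor) -> bool:
    # Hypothetical helper: ignores the values entirely
    return t1.shape == t2.shape and t1.dtype == t2.dtype

a = torch.tensor([[1., 2.], [3., 4.]])
e = torch.zeros(2, 2)               # float32 by default; values differ from a
d = torch.tensor([[1, 2], [3, 4]])  # int64

print(same_shape_and_dtype(a, e))  # True
print(same_shape_and_dtype(a, d))  # False
print(a.shape == (2, 2))           # True: torch.Size is a tuple subclass
```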
Approximate equality comparison
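You may want to use this kind of comparison when you need to check whether two tensors are close enough at every position, within some tolerance for floating-point differences. The torch.allclose(input, other) function does the job and returns a single boolean. Instead of a single epsilon, it takes two tolerance arguments: rtol (relative tolerance, default 1e-05) and atol (absolute tolerance, default 1e-08). It returns True when |input - other| <= atol + rtol * |other| holds at every position.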
Example:
import torch
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.0001, 2.0002, 3.0003])
c = torch.tensor([1.01, 2.02, 3.03])
print(torch.allclose(a, b, rtol=0.001)) # True
print(torch.allclose(a, c)) # False
print(torch.allclose(a, c, atol=0.03)) # True
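To make the tolerance rule concrete, here is a sketch that evaluates the formula by hand. The helper manual_allclose is hypothetical and ignores NaN/inf handling. The sketch also uses torch.isclose(), which applies the same test but returns a per-element boolean tensor. Because rtol is scaled by |other|, swapping the arguments can change the result.

```python
import torch

def manual_allclose(x: torch.Tensor, other: torch.Tensor,
                    rtol: float = 1e-05, atol: float = 1e-08) -> bool:
    # Hypothetical helper (no NaN/inf handling):
    # every element must satisfy |x - other| <= atol + rtol * |other|
    within = torch.abs(x - other) <= atol + rtol * torch.abs(other)
    return bool(within.all())

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.0001, 2.0002, 3.0003])
c = torch.tensor([1.01, 2.02, 3.03])

print(manual_allclose(a, b, rtol=0.001), torch.allclose(a, b, rtol=0.001))
print(manual_allclose(a, c), torch.allclose(a, c))

# torch.isclose gives the per-element verdict instead of one bool
print(torch.isclose(a, c, atol=0.015))

# The test is not symmetric: rtol is multiplied by |other|
x = torch.tensor([1.0])
y = torch.tensor([1.1])
print(torch.allclose(x, y, rtol=0.095, atol=0.0))  # True
print(torch.allclose(y, x, rtol=0.095, atol=0.0))  # False
```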
Element-wise comparison
Use this type of comparison when you want to compare two tensors element by element and get a tensor of booleans as the result. The torch.eq(tensor_one, tensor_two) function returns such a tensor, with True wherever the two inputs are equal.
Example:
import torch
a = torch.tensor([1, 2, 3])
b = torch.tensor([1, 4, 3])
c = torch.tensor([4, 5, 6])
print(torch.eq(a, b))
# Output: tensor([ True, False,  True])
print(torch.eq(a, c))
# Output: tensor([False, False, False])
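The boolean tensor returned by torch.eq() is often reduced further. This sketch shows a few common follow-ups:

- the == operator, which produces the same result as torch.eq();
- .sum(), .all() and .any() to count or summarize the matches;
- broadcasting a vector against a matrix.

```python
import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([1, 4, 3])

mask = a == b                                # same result as torch.eq(a, b)
print(torch.equal(mask, torch.eq(a, b)))     # True
print(mask.sum().item())                     # 2 matching positions
print(mask.all().item(), mask.any().item())  # False True

# torch.eq broadcasts, e.g. comparing every row of a matrix with one vector
m = torch.tensor([[1, 2], [3, 4]])
print(torch.eq(m, torch.tensor([1, 4])))
```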