When working with PyTorch, you may need to change the data type of a tensor, for example to match the data type of another tensor or a scalar operand in an arithmetic operation, to reduce memory usage, or to increase numerical precision.
This succinct, straight-to-the-point article shows you 2 different ways to change the data type of a given PyTorch tensor.
Using the type() method
Calling the type() method on a tensor returns its type if no argument is given, or returns a new tensor cast to the specified type if an argument is given (the original tensor is left unchanged).
In the example below, we will change the dtype of a tensor from torch.float32 to torch.int32:
import torch
input_tensor = torch.tensor([1, 2, 3], dtype=torch.float32)
# casting type
output_tensor = input_tensor.type(torch.int32)
# print the type
print(output_tensor.dtype)
Output:
torch.int32
When casting a tensor from a higher precision to a lower precision, you should be aware of some possible consequences, such as the following (a short demonstration comes after this list):
- Loss of information: Casting a tensor to a lower precision may result in losing some information due to rounding or truncation. For example, if you have a tensor x with values [1.23, 4.56, 7.89] and you cast it to torch.int32, the fractional part will be truncated, and the values will become [1, 4, 7].
- Overflow or underflow: Casting a tensor to a lower precision may produce values that fall outside the range the lower-precision type can represent. For example, if you have a tensor x with values [1e10, -1e10] and you cast it to torch.int32, neither value fits in the torch.int32 range (-2147483648 to 2147483647); the result of such an out-of-range cast is not well defined, and depending on the platform the values may be clamped to that range or wrap around.
- Numerical errors or instability: Casting a tensor to a lower precision may introduce some numerical errors or instability in the computation due to the reduced accuracy and range of the lower precision type. For example, if you have a tensor x with values [0.1, 0.2] and you cast it to torch.float16, the values will become [0.09998, 0.19995]. If you then add them together, you will get 0.29993 instead of 0.3.
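To make the truncation and reduced-precision points concrete, here is a minimal sketch you can run (the values in the comments are approximate and may differ slightly between PyTorch versions and platforms):

import torch

# Truncation: casting float to int drops the fractional part
x = torch.tensor([1.23, 4.56, 7.89])
print(x.type(torch.int32))  # tensor([1, 4, 7], dtype=torch.int32)

# Reduced precision: 0.1 and 0.2 are stored less accurately in float16
y = torch.tensor([0.1, 0.2]).type(torch.float16)
print(y.sum())  # close to, but not exactly, 0.3 (e.g. tensor(0.2998, dtype=torch.float16))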
Using the to() method
Besides the type() method, you can use the to() method with an argument specifying the desired type. In the example below, we will change the dtype of a tensor from torch.float32 to torch.int32:
import torch
# the default dtype is torch.float32
old_tensor = torch.zeros(3, 4)
# convert it to a torch.int32 tensor
new_tensor = old_tensor.to(torch.int32)
# print the dtype of new_tensor
print(new_tensor.dtype)
Output:
torch.int32
Both the type() and the to() methods change the data type of a PyTorch tensor by returning a converted tensor rather than modifying the original in place, but they have some differences:
- The type() method can only change the data type of a tensor, while the to() method can also change other attributes, such as the device the tensor lives on (see the sketch after the example below).
- The type() method returns the tensor's type (as a string such as 'torch.FloatTensor') if no argument is given, while the to() method has no such query form; to inspect the type, read the tensor's dtype attribute instead.
- The type() method accepts a torch.dtype object as an argument, while the to() method accepts a torch.dtype object or another tensor as an argument. If another tensor is given, the to() method returns a tensor whose data type and device match the given tensor (see the example below for more clarity).
Example:
import torch
x = torch.zeros(2, 3)
print("The data type of x in the beginning is", x.dtype)
y = torch.tensor([
    [1, 2, 3],
    [4, 5, 6]
], dtype=torch.int16)
x = x.to(y)
print("The data type of x after x.to(y) is", x.dtype)
Output:
The data type of x in the beginning is torch.float32
The data type of x after x.to(y) is torch.int16
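For completeness, the sketch below illustrates the other two differences listed above: type() with no argument reports the tensor type, and to() can change the dtype and the device in a single call. It assumes a CUDA GPU may or may not be present, so the device move is guarded by torch.cuda.is_available():

import torch

x = torch.zeros(2, 3)

# type() with no argument returns the tensor type as a string
print(x.type())  # 'torch.FloatTensor'

# to() can change the dtype and the device together in one call
if torch.cuda.is_available():
    y = x.to("cuda", torch.float16)  # move to the GPU and cast to float16
else:
    y = x.to(torch.float16)          # CPU-only fallback: just cast to float16

print(y.dtype, y.device)  # e.g. torch.float16 cuda:0 (or cpu)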