2024-07-12
import torch
x = torch.arange(16).reshape(1, 4, 4)
print(x)
print('--------')
a = x.sum(axis=1, keepdim=True)
a2 = x.sum(axis=1, keepdim=False)
a3 = x.sum(axis=0, keepdim=True)
a4 = x.sum(axis=0, keepdim=False)
a5 = x.sum(axis=2, keepdim=True)
print(a)
print(a2)
print('----------')
print(a3)
print(a4)
print(a5)
import torch
x = torch.arange(16).reshape(4, 4)
print(x)
print('--------')
a = x.sum(axis=1, keepdim=True)
a2 = x.sum(axis=1, keepdim=False)
print(a)
print(a2)
print(x / a)
Combining these two examples, let's explain in detail how the axes change in different cases.
Understanding tensor dimensionality and summing along specific axes in PyTorch does take a little time, so let's break these operations down step by step.
import torch
x = torch.arange(16).reshape(1, 4, 4)
print(x)
print('--------')
a = x.sum(axis=1, keepdim=True)
a2 = x.sum(axis=1, keepdim=False)
a3 = x.sum(axis=0, keepdim=True)
a4 = x.sum(axis=0, keepdim=False)
a5 = x.sum(axis=2, keepdim=True)
print(a)
print(a2)
print('----------')
print(a3)
print(a4)
print(a5)
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]]])
This tensor has shape (1, 4, 4).
We can think of it as a batch containing a single 4x4 matrix.
x.sum(axis=1, keepdim=True)
Sum along axis 1 (the second dimension, of size 4), keeping the dimension:
tensor([[[24, 28, 32, 36]]])
The shape becomes (1, 1, 4).
x.sum(axis=1, keepdim=False)
Sum along axis 1 without keeping the dimension:
tensor([[24, 28, 32, 36]])
The shape becomes (1, 4).
x.sum(axis=0, keepdim=True)
Sum along axis 0 (the first dimension, of size 1), keeping the dimension:
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]]])
Because the original tensor has only one element along axis 0, the result is the same as the original tensor, with shape (1, 4, 4).
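Since axis 0 has size 1, summing over it leaves the values untouched; keepdim only decides whether that size-1 axis survives. Continuing with the same x, a minimal check:

print(torch.equal(x.sum(axis=0, keepdim=True), x))      # True: identical tensor
print(torch.equal(x.sum(axis=0, keepdim=False), x[0]))  # True: same values, axis 0 dropped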
x.sum(axis=0, keepdim=False)
Sum along axis 0 without keeping the dimension:
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
The shape becomes (4, 4).
x.sum(axis=2, keepdim=True)
Sum along axis 2 (the third dimension, of size 4), keeping the dimension:
tensor([[[ 6],
         [22],
         [38],
         [54]]])
The shape becomes (1, 4, 1).
In short: with keepdim=True, the summed dimension is kept, so the number of dimensions stays the same and the size of the summed dimension becomes 1. With keepdim=False, the summed dimension is removed and the number of dimensions is reduced by 1.
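A quick way to see the difference is to compare the resulting shapes directly, a minimal sketch using the same tensor as above:

import torch
x = torch.arange(16).reshape(1, 4, 4)
print(x.sum(axis=1, keepdim=True).shape)   # torch.Size([1, 1, 4])
print(x.sum(axis=1, keepdim=False).shape)  # torch.Size([1, 4])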
import torch
x = torch.arange(16).reshape(4, 4)
print(x)
print('--------')
a = x.sum(axis=1, keepdim=True)
a2 = x.sum(axis=1, keepdim=False)
print(a)
print(a2)
print(x/a)
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
This is a tensor of shape (4, 4).
x.sum(axis=1, keepdim=True)
Sum along axis 1 (the second dimension, of size 4), keeping the dimension:
tensor([[ 6],
        [22],
        [38],
        [54]])
The shape becomes (4, 1).
x.sum(axis=1, keepdim=False)
Sum along axis 1 without keeping the dimension:
tensor([ 6, 22, 38, 54])
The shape becomes (4,).
x / a
tensor([[0.0000, 0.1667, 0.3333, 0.5000],
        [0.1818, 0.2273, 0.2727, 0.3182],
        [0.2105, 0.2368, 0.2632, 0.2895],
        [0.2222, 0.2407, 0.2593, 0.2778]])
Each element is divided by the sum of its row, i.e.:
tensor([[ 0/6,   1/6,   2/6,   3/6],
        [ 4/22,  5/22,  6/22,  7/22],
        [ 8/38,  9/38, 10/38, 11/38],
        [12/54, 13/54, 14/54, 15/54]])
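Because a has shape (4, 1), every row of x / a is that row normalized to sum to 1. Continuing with x and a from above, a quick check:

print((x / a).sum(axis=1))  # tensor([1., 1., 1., 1.])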
With keepdim=True, the summed dimension is kept with size 1; with keepdim=False, it is removed. Here keepdim=True is exactly what makes the division x / a broadcast row-wise, as the sketch below shows.
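To see why keepdim=True matters for the division, compare broadcasting against a (shape (4, 1)) and a2 (shape (4,)): a divides each row by its own sum, while a2 is broadcast as a row vector and divides each column instead. A minimal sketch:

import torch
x = torch.arange(16).reshape(4, 4)
a = x.sum(axis=1, keepdim=True)    # shape (4, 1): column vector of row sums
a2 = x.sum(axis=1, keepdim=False)  # shape (4,): broadcast as a row vector
print((x / a).shape, (x / a2).shape)  # both torch.Size([4, 4]), but values differ
print(torch.allclose(x / a, x / a2))  # False: a2 divides columns, not rows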
Why does x.sum(axis=1) produce column sums when the tensor is reshape(1, 4, 4), but row sums when it is reshape(4, 4)?
Let’s review the basic concepts first:
The (4, 4) case:
import torch
x = torch.arange(16).reshape(4, 4)
print(x)
Output:
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
The shape of this tensor is (4, 4), representing a 4x4 matrix:
Rows are horizontal:
[ 0, 1, 2, 3]
[ 4, 5, 6, 7]
[ 8, 9, 10, 11]
[12, 13, 14, 15]
Columns are vertical:
[ 0, 4, 8, 12]
[ 1, 5, 9, 13]
[ 2, 6, 10, 14]
[ 3, 7, 11, 15]
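Under this layout, indexing matches the picture above: the first index selects a row, the second a column. A small check:

import torch
x = torch.arange(16).reshape(4, 4)
print(x[0])     # first row: tensor([0, 1, 2, 3])
print(x[:, 0])  # first column: tensor([ 0,  4,  8, 12])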
The (1, 4, 4) case:
x = torch.arange(16).reshape(1, 4, 4)
print(x)
Output:
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]]])
The shape of this tensor is (1, 4, 4), representing a 1x4x4 three-dimensional tensor:
1: the batch size.
4: the number of rows of each matrix.
4: the number of columns of each matrix.
For (4, 4):
a = x.sum(axis=1, keepdim=True)
print(a)
Output:
tensor([[ 6],
        [22],
        [38],
        [54]])
For (1, 4, 4):
a = x.sum(axis=1, keepdim=True)
print(a)
Output:
tensor([[[24, 28, 32, 36]]])
In the (1, 4, 4) three-dimensional tensor, the first dimension is the batch size. Each 4x4 matrix still looks two-dimensional, but the added batch dimension in front shifts every original axis by one, so the sum operation behaves differently than it does on a plain two-dimensional tensor.
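A quick check of this axis shift, using the same 4x4 values: summing the batched tensor along axis 1 gives the same result as summing the plain matrix along axis 0.

import torch
x2d = torch.arange(16).reshape(4, 4)
x3d = x2d.reshape(1, 4, 4)
# The leading batch dimension shifts every original axis by one:
# axis 0 of the matrix (rows) becomes axis 1 of the batched tensor.
print(torch.equal(x3d.sum(axis=1), x2d.sum(axis=0, keepdim=True)))  # True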
Specifically: