
Understanding PyTorch's axis

2024-07-12



Understanding how summing along a specific axis changes the dimensions of a PyTorch tensor takes a little time. Let's break these operations down step by step through two examples, explaining in detail how the axes change in each case.

First example

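```python
import torch

# A single 4x4 matrix wrapped in a batch of size 1: shape (1, 4, 4)
x = torch.arange(16).reshape(1, 4, 4)
print(x)
print('--------')
a = x.sum(axis=1, keepdim=True)    # sum over axis 1, keep it with size 1 -> (1, 1, 4)
a2 = x.sum(axis=1, keepdim=False)  # sum over axis 1, drop it -> (1, 4)
a3 = x.sum(axis=0, keepdim=True)   # sum over axis 0 (size 1), keep it -> (1, 4, 4)
a4 = x.sum(axis=0, keepdim=False)  # sum over axis 0, drop it -> (4, 4)
a5 = x.sum(axis=2, keepdim=True)   # sum over axis 2, keep it with size 1 -> (1, 4, 1)
print(a)
print(a2)
print('----------')
print(a3)
print(a4)
print(a5)
```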
Initial tensor:
```
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]]])
```

This is a tensor of shape (1, 4, 4). We can think of it as a batch containing a single 4x4 matrix.

Sum along axis 1
  1. x.sum(axis=1, keepdim=True)

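Sum along axis 1 (the second dimension, of size 4) while keeping that dimension. In this tensor, axis 1 indexes the rows, so the four rows are added together element-wise, giving one sum per column.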

```
tensor([[[24, 28, 32, 36]]])
```

The shape becomes (1, 1, 4).

  2. x.sum(axis=1, keepdim=False)

Sum along axis 1 without keeping that dimension.

```
tensor([[24, 28, 32, 36]])
```

The shape becomes (1, 4).
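As a quick sanity check (a sketch, not part of the original article), the axis-1 result can be reproduced by adding the four rows of the single matrix by hand, and the keepdim=False result is just the keepdim=True result with the size-1 axis squeezed out:

```python
import torch

x = torch.arange(16).reshape(1, 4, 4)

a = x.sum(axis=1, keepdim=True)    # shape (1, 1, 4)
a2 = x.sum(axis=1, keepdim=False)  # shape (1, 4)

# Adding the four rows of the single 4x4 matrix element-wise gives the same values
manual = x[0, 0] + x[0, 1] + x[0, 2] + x[0, 3]
print(manual)                          # tensor([24, 28, 32, 36])

# keepdim=False is the keepdim=True result with the size-1 axis squeezed out
print(torch.equal(a.squeeze(1), a2))   # True
print(torch.equal(a2[0], manual))      # True
```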

Sum along axis 0
  1. x.sum(axis=0, keepdim=True)

Sum along axis 0 (the first dimension, of size 1) while keeping that dimension.

```
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]]])
```

Because the original tensor has only one element along axis 0, there is nothing to add up, so the result equals the original tensor, with shape (1, 4, 4).

  2. x.sum(axis=0, keepdim=False)

Sum along axis 0 without keeping that dimension.

```
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
```

The shape becomes (4, 4).
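Because axis 0 has size 1 here, summing over it leaves every value unchanged; with keepdim=False it behaves like simply dropping that axis. A minimal check, using `squeeze(0)` as the point of comparison:

```python
import torch

x = torch.arange(16).reshape(1, 4, 4)

a3 = x.sum(axis=0, keepdim=True)   # shape (1, 4, 4): values unchanged
a4 = x.sum(axis=0, keepdim=False)  # shape (4, 4): the size-1 axis is gone

print(torch.equal(a3, x))             # True
print(torch.equal(a4, x.squeeze(0)))  # True
```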

Sum along axis 2
  1. x.sum(axis=2, keepdim=True)

Sum along axis 2 (the third dimension, of size 4) while keeping that dimension. Axis 2 indexes the columns, so the elements within each row are added, giving one sum per row.

```
tensor([[[ 6],
         [22],
         [38],
         [54]]])
```

The shape becomes (1, 4, 1).
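Axis 2 is the last dimension of this tensor, so it can also be written as `axis=-1` (PyTorch counts negative indices from the end; the original article does not use this form). The sketch below compares the two spellings and checks the first row sum by hand:

```python
import torch

x = torch.arange(16).reshape(1, 4, 4)

a5 = x.sum(axis=2, keepdim=True)        # shape (1, 4, 1): one sum per row
a5_neg = x.sum(axis=-1, keepdim=True)   # -1 counts from the end, i.e. axis 2 here

print(torch.equal(a5, a5_neg))          # True
print(a5.flatten().tolist())            # [6, 22, 38, 54]
print(int(x[0, 0].sum()))               # 6 = 0 + 1 + 2 + 3
```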

Key points
  • keepdim=True: the summed dimension is kept with size 1, so the result has the same number of dimensions as the input.
  • keepdim=False: the summed dimension is removed, so the result has one dimension fewer.
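These two rules can be written as a small shape function. `reduced_shape` below is a hypothetical helper written only for illustration (it is not a PyTorch API); the loop compares its prediction with the shape `Tensor.sum` actually returns for every axis of the (1, 4, 4) tensor.

```python
import torch

def reduced_shape(shape, axis, keepdim):
    """Hypothetical helper: predict the shape of x.sum(axis=axis, keepdim=keepdim)."""
    axis = axis % len(shape)  # turn a negative index such as -1 into a positive one
    if keepdim:
        # keepdim=True: the summed dimension stays, with size 1
        return tuple(1 if i == axis else s for i, s in enumerate(shape))
    # keepdim=False: the summed dimension is dropped
    return tuple(s for i, s in enumerate(shape) if i != axis)

x = torch.arange(16).reshape(1, 4, 4)
for axis in range(x.ndim):
    for keepdim in (True, False):
        actual = tuple(x.sum(axis=axis, keepdim=keepdim).shape)
        predicted = reduced_shape(tuple(x.shape), axis, keepdim)
        print(axis, keepdim, actual, actual == predicted)
```

Every line printed by the loop should end in True.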

Second example

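```python
import torch

# A plain 4x4 matrix: shape (4, 4)
x = torch.arange(16).reshape(4, 4)
print(x)
print('--------')
a = x.sum(axis=1, keepdim=True)    # row sums as a column: shape (4, 1)
a2 = x.sum(axis=1, keepdim=False)  # row sums as a flat vector: shape (4,)
print(a)
print(a2)
print(x / a)  # true division: each element divided by its row sum (float result)
```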
Initial tensor:
```
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
```

This is a tensor of shape (4, 4).

Sum along axis 1
  1. x.sum(axis=1, keepdim=True)

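Sum along axis 1 (the second dimension, of size 4) while keeping that dimension. In a 2D tensor, axis 1 indexes the columns, so the elements within each row are added, giving one sum per row.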

```
tensor([[ 6],
        [22],
        [38],
        [54]])
```

The shape becomes (4, 1).

  2. x.sum(axis=1, keepdim=False)

Sum along axis 1 without keeping that dimension.

```
tensor([ 6, 22, 38, 54])
```

The shape becomes (4,).
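Both results hold the same four row sums; only the shape differs. A short check, rebuilding the same `x` so the snippet runs on its own:

```python
import torch

x = torch.arange(16).reshape(4, 4)

a = x.sum(axis=1, keepdim=True)    # shape (4, 1): a column of row sums
a2 = x.sum(axis=1, keepdim=False)  # shape (4,): a flat vector of row sums

print(a.shape, a2.shape)              # torch.Size([4, 1]) torch.Size([4])
print(torch.equal(a.squeeze(1), a2))  # True
print(a2.tolist())                    # [6, 22, 38, 54]
```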

Divide each element by its row sum
  1. x / a
```
tensor([[0.0000, 0.1667, 0.3333, 0.5000],
        [0.1818, 0.2273, 0.2727, 0.3182],
        [0.2105, 0.2368, 0.2632, 0.2895],
        [0.2222, 0.2407, 0.2593, 0.2778]])
```

This is each element divided by the sum of its corresponding row, resulting in:

```
tensor([[ 0/6,  1/6,  2/6,  3/6],
        [ 4/22,  5/22,  6/22,  7/22],
        [ 8/38,  9/38, 10/38, 11/38],
        [12/54, 13/54, 14/54, 15/54]])
```
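The division works because `a` has shape (4, 1): broadcasting stretches it across the four columns, so every element is divided by its own row's sum. The sketch below (an illustration that goes slightly beyond the original code) checks that each row of `x / a` sums to 1, and shows that dividing by the flat `a2` of shape (4,) instead broadcasts along the last dimension, so element (i, j) is divided by the sum of row j rather than row i.

```python
import torch

x = torch.arange(16).reshape(4, 4)
a = x.sum(axis=1, keepdim=True)    # shape (4, 1)
a2 = x.sum(axis=1, keepdim=False)  # shape (4,)

row_normalized = x / a   # (4, 4) / (4, 1): each row divided by its own sum
print(row_normalized.sum(axis=1))  # each row sums to 1 (up to float rounding)

wrong = x / a2           # (4, 4) / (4,): a2 is broadcast as if it had shape (1, 4)
# element (i, j) is divided by a2[j], the sum of row j, not row i
print(torch.allclose(wrong[0], x[0] / a2))      # True
print(torch.allclose(wrong, row_normalized))    # False
```

This is why keepdim=True is the convenient choice when the reduced result is immediately combined with the original tensor.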

Summary of axis and dimension changes

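  • axis=0: operates along the first dimension, and the remaining dimensions are kept. In the (1, 4, 4) tensor this is the batch dimension; in the (4, 4) tensor it is the row index, so summing over it adds the rows and gives column sums.
  • axis=1: operates along the second dimension, and the other dimensions are kept. In the (1, 4, 4) tensor this is the row index, so the result holds column sums; in the (4, 4) tensor it is the column index, so the result holds row sums.
  • axis=2: operates along the third dimension, and the first two dimensions are kept. In the (1, 4, 4) tensor this is the column index, so the result holds row sums.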

With keepdim=True, the summed dimension is kept with size 1. With keepdim=False, the summed dimension is removed.

Question:

Why do rows seem to become columns when reshape(1, 4, 4) is used? Why does a row behave like a row only with reshape(4, 4)?

The structure of a tensor

Let’s review the basic concepts first:

  • 2D tensor (matrix): has rows and columns.
  • 3D tensor: composed of multiple two-dimensional matrices; it can be regarded as a stack of matrices along an extra "depth" dimension.
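To see the "stack of matrices" view with more than one matrix, the sketch below stacks two 4x4 matrices into a (2, 4, 4) tensor. The second matrix (every entry plus 100) is an arbitrary choice for illustration. Summing over axis 0 then adds the matrices element-wise, which is the batch-dimension sum that does nothing visible when the batch size is 1.

```python
import torch

m0 = torch.arange(16).reshape(4, 4)
m1 = 100 + m0                        # a second, arbitrary matrix
batch = torch.stack([m0, m1])        # two matrices stacked along a new axis 0

print(batch.shape)                   # torch.Size([2, 4, 4])
summed = batch.sum(axis=0)           # adds the two matrices element-wise -> shape (4, 4)
print(torch.equal(summed, m0 + m1))  # True
```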

Example 1: 2D Tensor (4, 4)

```python
import torch
x = torch.arange(16).reshape(4, 4)
print(x)
```

Output:

```
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
```

The shape of this tensor is (4, 4), representing a 4x4 matrix:

  • Rows are horizontal:

    • Row 0: [ 0, 1, 2, 3]
    • Row 1: [ 4, 5, 6, 7]
    • Row 2: [ 8, 9, 10, 11]
    • Row 3: [12, 13, 14, 15]
  • Columns are vertical:

    • Column 0: [ 0, 4, 8, 12]
    • Column 1: [ 1, 5, 9, 13]
    • Column 2: [ 2, 6, 10, 14]
    • Column 3: [ 3, 7, 11, 15]
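In indexing terms, a row is selected by fixing the axis-0 index and a column by fixing the axis-1 index. A quick sketch:

```python
import torch

x = torch.arange(16).reshape(4, 4)

row1 = x[1]      # fix axis 0 -> row 1
col1 = x[:, 1]   # fix axis 1 -> column 1
print(row1.tolist())   # [4, 5, 6, 7]
print(col1.tolist())   # [1, 5, 9, 13]
```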

Example 2: 3D Tensor (1, 4, 4)

```python
import torch
x = torch.arange(16).reshape(1, 4, 4)
print(x)
```

Output:

```
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]]])
```

The shape of this tensor is (1, 4, 4), representing a 1x4x4 three-dimensional tensor:

  • The first dimension is 1, indicating the batch size.
  • The second dimension is 4, representing the number of rows (rows of each matrix).
  • The third dimension is 4, representing the number of columns (columns of each matrix).
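With the extra leading dimension, the same row and column need one more index: the batch index comes first, so a row is now selected by fixing axis 1 and a column by fixing axis 2. A sketch:

```python
import torch

x3 = torch.arange(16).reshape(1, 4, 4)

matrix = x3[0]        # the single 4x4 matrix in the batch
row1 = x3[0, 1]       # batch 0, axis 1 fixed -> row 1
col1 = x3[0, :, 1]    # batch 0, axis 2 fixed -> column 1
print(matrix.shape)   # torch.Size([4, 4])
print(row1.tolist())  # [4, 5, 6, 7]
print(col1.tolist())  # [1, 5, 9, 13]
```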

Explanation of summing along an axis

Sum along axis 1 (the second dimension)
  1. For a two-dimensional tensor (4, 4)
```python
import torch
x = torch.arange(16).reshape(4, 4)
a = x.sum(axis=1, keepdim=True)
print(a)
```

Output:

```
tensor([[ 6],
        [22],
        [38],
        [54]])
```
  • Axis 1 indexes the columns, so summing over it adds up the elements within each row:
    • [0, 1, 2, 3] => 0+1+2+3 = 6
    • [4, 5, 6, 7] => 4+5+6+7 = 22
    • [8, 9, 10, 11] => 8+9+10+11 = 38
    • [12, 13, 14, 15] => 12+13+14+15 = 54
  2. For a three-dimensional tensor (1, 4, 4)
```python
import torch
x = torch.arange(16).reshape(1, 4, 4)
a = x.sum(axis=1, keepdim=True)
print(a)
```

Output:

```
tensor([[[24, 28, 32, 36]]])
```
  • Axis 1 now indexes the rows of the matrix, so summing over it adds the four rows together element-wise:
    • [0, 1, 2, 3] + [4, 5, 6, 7] + [8, 9, 10, 11] + [12, 13, 14, 15] = [24, 28, 32, 36]
    • Each entry is a column sum: 24 = 0+4+8+12, 28 = 1+5+9+13, 32 = 2+6+10+14, 36 = 3+7+11+15
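Both results can be reproduced with plain Python on the same 16 numbers, which makes the direction of each sum explicit. This is an illustrative check, not part of the original code; `rows` is simply the list of the four rows.

```python
import torch

rows = torch.arange(16).reshape(4, 4).tolist()   # [[0, 1, 2, 3], [4, 5, 6, 7], ...]

row_sums = [sum(row) for row in rows]            # what axis=1 gives for the (4, 4) tensor
col_sums = [sum(col) for col in zip(*rows)]      # what axis=1 gives for the (1, 4, 4) tensor

x2 = torch.arange(16).reshape(4, 4)
x3 = torch.arange(16).reshape(1, 4, 4)
print(row_sums == x2.sum(axis=1).tolist())       # True: [6, 22, 38, 54]
print(col_sums == x3.sum(axis=1)[0].tolist())    # True: [24, 28, 32, 36]
```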

Why does it seem that "rows have become columns"?

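In the (1, 4, 4) three-dimensional tensor, the first dimension is the batch size, so each 4x4 matrix is still processed as a two-dimensional matrix. However, adding the batch dimension shifts every axis index up by one: the row index, which is axis 0 in the (4, 4) tensor, becomes axis 1 in the (1, 4, 4) tensor, and the column index moves from axis 1 to axis 2. As a result, x.sum(axis=1) on the 3D tensor does what x.sum(axis=0) does on the 2D tensor: it adds the rows together and produces one sum per column.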

Specifically:

  • When we sum along axis 1 (the second dimension), we add the rows of each matrix together, producing one sum per column.
  • When summing along axis 0 (the first dimension), we sum over the batch dimension.
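Adding a leading batch dimension shifts every axis index up by one, and this can be checked directly. In the sketch below, `x3` is `x2` with a size-1 axis added in front via `unsqueeze(0)`; summing `x3` over axis k + 1 matches summing `x2` over axis k, while a negative index such as `axis=-1` refers to the column axis in both tensors.

```python
import torch

x2 = torch.arange(16).reshape(4, 4)
x3 = x2.unsqueeze(0)                # shape (1, 4, 4), same data as reshape(1, 4, 4)

# axis k of the 2D tensor corresponds to axis k + 1 of the 3D tensor
print(torch.equal(x3.sum(axis=1)[0], x2.sum(axis=0)))    # True: column sums
print(torch.equal(x3.sum(axis=2)[0], x2.sum(axis=1)))    # True: row sums

# counting from the end avoids the shift: -1 is the column axis in both
print(torch.equal(x3.sum(axis=-1)[0], x2.sum(axis=-1)))  # True
```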

Summary

  • 2D Tensor: The concept of rows and columns is intuitive.
  • 3D tensor: after a batch dimension is introduced, the axis numbers of rows and columns shift up by one, so the same axis number refers to a different direction; the operations still act on each two-dimensional matrix.
  • When summing along an axis, the key is to check which dimension that axis number refers to in the tensor's shape.