Compartilhamento de tecnologia

Compreensão do eixo do pytorch

2024-07-12

한어Русский языкEnglishFrançaisIndonesianSanskrit日本語DeutschPortuguêsΕλληνικάespañolItalianoSuomalainenLatina

import torch
x = torch.arange(16).reshape(1,4,4)
print(x)
print('--------')
a = x.sum(axis = 1,keepdim=True)
a2 = x.sum(axis = 1,keepdim=False)
a3 = x.sum(axis = 0,keepdim=True)
a4 = x.sum(axis = 0,keepdim=False)
a5 = x.sum(axis = 2,keepdim=True)
print(a)
print(a2)
print('----------')
print(a3)
print(a4)
print(a5)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
 import torch
x = torch.arange(16).reshape(4,4)
print(x)
print('--------')
a = x.sum(axis = 1,keepdim=True)
a2 = x.sum(axis = 1,keepdim=False)
print(a)
print(a2)
print(x/a) 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

Combine estes dois exemplos para explicar detalhadamente as mudanças no eixo sob diferentes circunstâncias.
Compreender as operações dimensionais em tensores e somatórios ao longo de eixos específicos no PyTorch leva um pouco de tempo. Analisemos estas operações passo a passo através de dois exemplos, explicando detalhadamente as mudanças de eixo em diferentes situações.

primeiro exemplo

import torch
x = torch.arange(16).reshape(1, 4, 4)
print(x)
print('--------')
a = x.sum(axis=1, keepdim=True)
a2 = x.sum(axis=1, keepdim=False)
a3 = x.sum(axis=0, keepdim=True)
a4 = x.sum(axis=0, keepdim=False)
a5 = x.sum(axis=2, keepdim=True)
print(a)
print(a2)
print('----------')
print(a3)
print(a4)
print(a5)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
tensor inicial
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]]])
  • 1
  • 2
  • 3
  • 4

Esta é uma forma de (1, 4, 4) de tensores. Podemos pensar nisso como um lote contendo uma matriz 4x4.

Soma ao longo do eixo 1
  1. x.sum(axis=1, keepdim=True)

Soma ao longo do eixo 1 (ou seja, a direção da segunda dimensão, 4), mantendo as dimensões.

tensor([[[24, 28, 32, 36]]])
  • 1

A forma torna-se (1, 1, 4)

  1. x.sum(axis=1, keepdim=False)

Soma ao longo do eixo 1, sem dimensionalidade preservada.

tensor([[24, 28, 32, 36]])
  • 1

A forma torna-se (1, 4)

Soma ao longo do eixo 0
  1. x.sum(axis=0, keepdim=True)

Soma ao longo do eixo 0 (ou seja, a direção da primeira dimensão, 1), mantendo as dimensões.

tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]]])
  • 1
  • 2
  • 3
  • 4

Como o tensor original possui apenas um elemento no eixo 0, o resultado é o mesmo que o tensor original, com forma (1, 4, 4)

  1. x.sum(axis=0, keepdim=False)

Soma ao longo do eixo 0, a dimensionalidade não é preservada.

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
  • 1
  • 2
  • 3
  • 4

A forma torna-se (4, 4)

Soma ao longo do eixo 2
  1. x.sum(axis=2, keepdim=True)

Soma ao longo do eixo 2 (ou seja, a terceira dimensão, a direção de 4), mantendo as dimensões.

tensor([[[ 6],
         [22],
         [38],
         [54]]])
  • 1
  • 2
  • 3
  • 4

A forma torna-se (1, 4, 1)

ponto chave
  • keepdim=True As dimensões somadas serão mantidas, o número de dimensões do resultado permanece inalterado, mas o tamanho das dimensões somadas passa a ser 1.
  • keepdim=False As dimensões somadas serão removidas e o número de dimensões no resultado será reduzido em 1.

segundo exemplo

import torch
x = torch.arange(16).reshape(4, 4)
print(x)
print('--------')
a = x.sum(axis=1, keepdim=True)
a2 = x.sum(axis=1, keepdim=False)
print(a)
print(a2)
print(x/a)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
tensor inicial
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
  • 1
  • 2
  • 3
  • 4

Esta é uma forma de (4, 4) de tensores.

Soma ao longo do eixo 1
  1. x.sum(axis=1, keepdim=True)

Soma ao longo do eixo 1 (ou seja, a direção da segunda dimensão, 4), mantendo as dimensões.

tensor([[ 6],
        [22],
        [38],
        [54]])
  • 1
  • 2
  • 3
  • 4

A forma torna-se (4, 1)

  1. x.sum(axis=1, keepdim=False)

Soma ao longo do eixo 1, sem dimensionalidade preservada.

tensor([ 6, 22, 38, 54])
  • 1

A forma torna-se (4,)

Elementos divididos pela soma das linhas
  1. x / a
tensor([[0.0000, 0.1667, 0.3333, 0.5000],
        [0.1818, 0.2273, 0.2727, 0.3182],
        [0.2105, 0.2368, 0.2632, 0.2895],
        [0.2222, 0.2407, 0.2593, 0.2778]])
  • 1
  • 2
  • 3
  • 4

Esta é a soma de cada elemento dividida pela sua linha correspondente, resultando em:

tensor([[ 0/6,  1/6,  2/6,  3/6],
        [ 4/22,  5/22,  6/22,  7/22],
        [ 8/38,  9/38, 10/38, 11/38],
        [12/54, 13/54, 14/54, 15/54]])
  • 1
  • 2
  • 3
  • 4

Resumo das alterações de eixo e dimensão

  • eixo=0: Opera ao longo da primeira dimensão (linha) e a soma das outras dimensões permanece após a soma.
  • eixo=1: Opera ao longo da segunda dimensão (coluna), e a soma da primeira e terceira dimensões permanece após a soma.
  • eixo=2: Operando ao longo da terceira dimensão (profundidade), a soma das duas primeiras dimensões permanece após a soma.

usar keepdim=True Ao manter dimensões, a dimensão somada torna-se 1.usarkeepdim=False Quando , as dimensões somadas são removidas.

dúvida:

Por que as linhas são listadas em vez de colunas quando reshape(1, 4, 4) é usado Somente quando reshape(4, 4) é usado, as linhas são linhas?

Estrutura tensorial

Vamos revisar os conceitos básicos primeiro:

  • Tensor 2D (matriz): Possui linhas e colunas.
  • Tensor 3D: É composto por múltiplas matrizes bidimensionais e pode ser considerado como uma pilha de matrizes com dimensão de "profundidade".

Exemplo 1: tensor bidimensional (4, 4)

import torch
x = torch.arange(16).reshape(4, 4)
print(x)
  • 1
  • 2
  • 3

Saída:

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
  • 1
  • 2
  • 3
  • 4

A forma deste tensor é (4, 4), representa uma matriz 4x4:

  • OKé horizontal:

    • Linha 0: [ 0, 1, 2, 3]
    • Linha 1: [ 4, 5, 6, 7]
    • Linha 2: [ 8, 9, 10, 11]
    • Linha 3: [12, 13, 14, 15]
  • Listaé vertical:

    • Coluna 0: [ 0, 4, 8, 12]
    • Coluna 1: [ 1, 5, 9, 13]
    • Coluna 2: [ 2, 6, 10, 14]
    • Coluna 3: [ 3, 7, 11, 15]

Exemplo 2: tensor tridimensional (1, 4, 4)

x = torch.arange(16).reshape(1, 4, 4)
print(x)
  • 1
  • 2

Saída:

tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]]])
  • 1
  • 2
  • 3
  • 4

A forma deste tensor é (1, 4, 4), representa um tensor tridimensional 1x4x4:

  • A primeira dimensão é 1, indicando o tamanho do lote.
  • A segunda dimensão é 4, representa o número de linhas (linhas por matriz).
  • A terceira dimensão é 4, representa o número de colunas (colunas de cada matriz).

Explicação da soma ao longo do eixo

Soma ao longo do eixo 1 (segunda dimensão)
  1. Para um tensor bidimensional (4, 4)
a = x.sum(axis=1, keepdim=True)
print(a)
  • 1
  • 2

Saída:

tensor([[ 6],
        [22],
        [38],
        [54]])
  • 1
  • 2
  • 3
  • 4
  • O eixo 1 é a direção das linhas e os elementos de cada linha são somados:
    • [0, 1, 2, 3] => 0+1+2+3 = 6
    • [4, 5, 6, 7] => 4+5+6+7 = 22
    • [8, 9, 10, 11] => 8+9+10+11 = 38
    • [12, 13, 14, 15] => 12+13+14+15 = 54
  1. Para tensores tridimensionais (1, 4, 4)
a = x.sum(axis=1, keepdim=True)
print(a)
  • 1
  • 2

Saída:

tensor([[[24, 28, 32, 36]]])
  • 1
  • O eixo 1 é a direção da linha da primeira matriz. Some os elementos de cada linha:
    • [0, 1, 2, 3] + [4, 5, 6, 7] + [8, 9, 10, 11] + [12, 13, 14, 15]
    • Soma por coluna: 24 = 0+4+8+12, 28 = 1+5+9+13, 32 = 2+6+10+14, 36 = 3+7+11+15

Por que parece que "linhas tornam-se colunas"

existir (1, 4, 4) No tensor tridimensional, a primeira dimensão representa o tamanho do lote, então parece que cada matriz 4x4 ainda é processada de maneira bidimensional durante a operação. No entanto, como uma dimensão de lote é adicionada, ela se comporta de maneira diferente de um tensor bidimensional na operação de soma.

Especificamente:

  • Ao somar ao longo do eixo 1 (a segunda dimensão), estamos somando as linhas de cada matriz.
  • Ao somar ao longo do eixo 0 (a primeira dimensão), somamos as dimensões do lote.

Resumir

  • Tensor 2D: Os conceitos de linhas e colunas são intuitivos.
  • Tensor 3D: Após a introdução da dimensão do lote, as operações de linha e coluna parecerão diferentes, mas na verdade ainda estão operando em cada matriz bidimensional.
  • Ao somar ao longo de um eixo, compreender as dimensões da soma é fundamental.