Compartir tecnología

Comprensión del eje de pytorch

2024-07-12

한어Русский языкEnglishFrançaisIndonesianSanskrit日本語DeutschPortuguêsΕλληνικάespañolItalianoSuomalainenLatina

import torch
x = torch.arange(16).reshape(1,4,4)
print(x)
print('--------')
a = x.sum(axis = 1,keepdim=True)
a2 = x.sum(axis = 1,keepdim=False)
a3 = x.sum(axis = 0,keepdim=True)
a4 = x.sum(axis = 0,keepdim=False)
a5 = x.sum(axis = 2,keepdim=True)
print(a)
print(a2)
print('----------')
print(a3)
print(a4)
print(a5)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
 import torch
x = torch.arange(16).reshape(4,4)
print(x)
print('--------')
a = x.sum(axis = 1,keepdim=True)
a2 = x.sum(axis = 1,keepdim=False)
print(a)
print(a2)
print(x/a) 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

Combine estos dos ejemplos para explicar en detalle los cambios en el eje en diferentes circunstancias.
Comprender las operaciones dimensionales en tensores y la suma a lo largo de ejes específicos en PyTorch lleva un poco de tiempo. Analicemos estas operaciones paso a paso a través de dos ejemplos, explicando en detalle los cambios de eje en diferentes situaciones.

primer ejemplo

import torch
x = torch.arange(16).reshape(1, 4, 4)
print(x)
print('--------')
a = x.sum(axis=1, keepdim=True)
a2 = x.sum(axis=1, keepdim=False)
a3 = x.sum(axis=0, keepdim=True)
a4 = x.sum(axis=0, keepdim=False)
a5 = x.sum(axis=2, keepdim=True)
print(a)
print(a2)
print('----------')
print(a3)
print(a4)
print(a5)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
tensor inicial
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]]])
  • 1
  • 2
  • 3
  • 4

Esta es una forma de (1, 4, 4) de tensores. Podemos considerarlo como un lote que contiene una matriz de 4x4.

Suma a lo largo del eje 1
  1. x.sum(axis=1, keepdim=True)

Suma a lo largo del eje 1 (es decir, la dirección de la segunda dimensión, 4), manteniendo las dimensiones.

tensor([[[24, 28, 32, 36]]])
  • 1

La forma se vuelve (1, 1, 4)

  1. x.sum(axis=1, keepdim=False)

Suma a lo largo del eje 1, no se conserva ninguna dimensionalidad.

tensor([[24, 28, 32, 36]])
  • 1

La forma se vuelve (1, 4)

Suma a lo largo del eje 0
  1. x.sum(axis=0, keepdim=True)

Suma a lo largo del eje 0 (es decir, la dirección de la primera dimensión, 1), manteniendo las dimensiones.

tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]]])
  • 1
  • 2
  • 3
  • 4

Debido a que el tensor original tiene solo un elemento en el eje 0, el resultado es el mismo que el tensor original, con forma (1, 4, 4)

  1. x.sum(axis=0, keepdim=False)

Suma a lo largo del eje 0, la dimensionalidad no se conserva.

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
  • 1
  • 2
  • 3
  • 4

La forma se vuelve (4, 4)

Suma a lo largo del eje 2
  1. x.sum(axis=2, keepdim=True)

Suma a lo largo del eje 2 (es decir, la tercera dimensión, la dirección de 4), manteniendo las dimensiones.

tensor([[[ 6],
         [22],
         [38],
         [54]]])
  • 1
  • 2
  • 3
  • 4

La forma se vuelve (1, 4, 1)

punto clave
  • keepdim=True Las dimensiones sumadas se mantendrán, el número de dimensiones del resultado permanece sin cambios, pero el tamaño de las dimensiones sumadas pasa a ser 1.
  • keepdim=False Las dimensiones sumadas se eliminarán y el número de dimensiones en el resultado se reducirá en 1.

segundo ejemplo

import torch
x = torch.arange(16).reshape(4, 4)
print(x)
print('--------')
a = x.sum(axis=1, keepdim=True)
a2 = x.sum(axis=1, keepdim=False)
print(a)
print(a2)
print(x/a)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
tensor inicial
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
  • 1
  • 2
  • 3
  • 4

Esta es una forma de (4, 4) de tensores.

Suma a lo largo del eje 1
  1. x.sum(axis=1, keepdim=True)

Suma a lo largo del eje 1 (es decir, la dirección de la segunda dimensión, 4), manteniendo las dimensiones.

tensor([[ 6],
        [22],
        [38],
        [54]])
  • 1
  • 2
  • 3
  • 4

La forma se vuelve (4, 1)

  1. x.sum(axis=1, keepdim=False)

Suma a lo largo del eje 1, no se conserva ninguna dimensionalidad.

tensor([ 6, 22, 38, 54])
  • 1

La forma se vuelve (4,)

Elementos divididos por suma de filas
  1. x / a
tensor([[0.0000, 0.1667, 0.3333, 0.5000],
        [0.1818, 0.2273, 0.2727, 0.3182],
        [0.2105, 0.2368, 0.2632, 0.2895],
        [0.2222, 0.2407, 0.2593, 0.2778]])
  • 1
  • 2
  • 3
  • 4

Esta es la suma de cada elemento dividida por su fila correspondiente, dando como resultado:

tensor([[ 0/6,  1/6,  2/6,  3/6],
        [ 4/22,  5/22,  6/22,  7/22],
        [ 8/38,  9/38, 10/38, 11/38],
        [12/54, 13/54, 14/54, 15/54]])
  • 1
  • 2
  • 3
  • 4

Resumen de cambios de ejes y dimensiones.

  • eje=0: opere a lo largo de la primera dimensión (fila) y la suma de otras dimensiones permanece después de la suma.
  • eje=1: Opera a lo largo de la segunda dimensión (columna), y la suma de la primera y tercera dimensión permanece después de la suma.
  • eje=2: Opera a lo largo de la tercera dimensión (profundidad) y la suma de las dos primeras dimensiones permanece después de la suma.

usar keepdim=True Al mantener las dimensiones, la dimensión sumada pasa a ser 1.usarkeepdim=False Cuando , se eliminan las dimensiones sumadas.

duda:

¿Por qué las filas se convierten en columnas en lugar de remodelar (1, 4, 4)? ¿Solo cuando remodelar (4, 4) las filas aparecen como filas?

Estructura tensorial

Repasemos primero los conceptos básicos:

  • tensor 2D (matriz): Tiene filas y columnas.
  • tensores 3D: Se compone de múltiples matrices bidimensionales y puede considerarse como una pila de matrices con una dimensión de "profundidad".

Ejemplo 1: tensor bidimensional (4, 4)

import torch
x = torch.arange(16).reshape(4, 4)
print(x)
  • 1
  • 2
  • 3

Producción:

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
  • 1
  • 2
  • 3
  • 4

La forma de este tensor es (4, 4), representa una matriz de 4x4:

  • DE ACUERDOes horizontal:

    • Línea 0: [ 0, 1, 2, 3]
    • Línea 1: [ 4, 5, 6, 7]
    • Línea 2: [ 8, 9, 10, 11]
    • Línea 3: [12, 13, 14, 15]
  • Listaes vertical:

    • Columna 0: [ 0, 4, 8, 12]
    • Columna 1: [ 1, 5, 9, 13]
    • Columna 2: [ 2, 6, 10, 14]
    • Columna 3: [ 3, 7, 11, 15]

Ejemplo 2: tensor tridimensional (1, 4, 4)

x = torch.arange(16).reshape(1, 4, 4)
print(x)
  • 1
  • 2

Producción:

tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]]])
  • 1
  • 2
  • 3
  • 4

La forma de este tensor es (1, 4, 4), representa un tensor tridimensional de 1x4x4:

  • La primera dimensión es 1, indicando el tamaño del lote.
  • La segunda dimensión es 4, representa el número de filas (filas por matriz).
  • La tercera dimensión es 4, representa el número de columnas (columnas de cada matriz).

Explicación de la suma a lo largo del eje.

Suma a lo largo del eje 1 (segunda dimensión)
  1. Para un tensor bidimensional (4, 4)
a = x.sum(axis=1, keepdim=True)
print(a)
  • 1
  • 2

Producción:

tensor([[ 6],
        [22],
        [38],
        [54]])
  • 1
  • 2
  • 3
  • 4
  • El eje 1 es la dirección de las filas y los elementos de cada fila se suman:
    • [0, 1, 2, 3] => 0+1+2+3 = 6
    • [4, 5, 6, 7] => 4+5+6+7 = 22
    • [8, 9, 10, 11] => 8+9+10+11 = 38
    • [12, 13, 14, 15] => 12+13+14+15 = 54
  1. Para tensores tridimensionales (1, 4, 4)
a = x.sum(axis=1, keepdim=True)
print(a)
  • 1
  • 2

Producción:

tensor([[[24, 28, 32, 36]]])
  • 1
  • El eje 1 es la dirección de la fila de la primera matriz. Suma los elementos de cada fila:
    • [0, 1, 2, 3] + [4, 5, 6, 7] + [8, 9, 10, 11] + [12, 13, 14, 15]
    • Suma por columna: 24 = 0+4+8+12, 28 = 1+5+9+13, 32 = 2+6+10+14, 36 = 3+7+11+15

¿Por qué parece que "las filas se convierten en columnas"?

existir (1, 4, 4) En el tensor tridimensional, la primera dimensión representa el tamaño del lote, por lo que parece que cada matriz 4x4 todavía se procesa de manera bidimensional cuando está en funcionamiento. Sin embargo, debido a que se agrega una dimensión por lotes, se comporta de manera diferente a un tensor bidimensional en la operación de suma.

Específicamente:

  • Al sumar a lo largo del eje 1 (la segunda dimensión), estamos sumando las filas de cada matriz.
  • Al sumar a lo largo del eje 0 (la primera dimensión), sumamos las dimensiones del lote.

Resumir

  • tensores 2D: Los conceptos de filas y columnas son intuitivos.
  • tensores 3D: Después de introducir la dimensión del lote, las operaciones de fila y columna se verán diferentes, pero en realidad todavía se operan en cada matriz bidimensional.
  • Al sumar a lo largo de un eje, la clave es comprender las dimensiones de la suma.