技術共有

pytorch の軸の理解

2024-07-12

한어Русский языкEnglishFrançaisIndonesianSanskrit日本語DeutschPortuguêsΕλληνικάespañolItalianoSuomalainenLatina

import torch
x = torch.arange(16).reshape(1,4,4)
print(x)
print('--------')
a = x.sum(axis = 1,keepdim=True)
a2 = x.sum(axis = 1,keepdim=False)
a3 = x.sum(axis = 0,keepdim=True)
a4 = x.sum(axis = 0,keepdim=False)
a5 = x.sum(axis = 2,keepdim=True)
print(a)
print(a2)
print('----------')
print(a3)
print(a4)
print(a5)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
 import torch
x = torch.arange(16).reshape(4,4)
print(x)
print('--------')
a = x.sum(axis = 1,keepdim=True)
a2 = x.sum(axis = 1,keepdim=False)
print(a)
print(a2)
print(x/a) 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

これら 2 つの例を組み合わせて、さまざまな状況下での軸の変化を詳細に説明します。
PyTorch でのテンソルの次元操作と特定の軸に沿った合計を理解するには、少し時間がかかります。 2 つの例を通してこれらの操作を段階的に分析し、さまざまな状況下での軸の変化を詳細に説明します。

最初の例

import torch
x = torch.arange(16).reshape(1, 4, 4)
print(x)
print('--------')
a = x.sum(axis=1, keepdim=True)
a2 = x.sum(axis=1, keepdim=False)
a3 = x.sum(axis=0, keepdim=True)
a4 = x.sum(axis=0, keepdim=False)
a5 = x.sum(axis=2, keepdim=True)
print(a)
print(a2)
print('----------')
print(a3)
print(a4)
print(a5)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
初期テンソル
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]]])
  • 1
  • 2
  • 3
  • 4

これは、 (1, 4, 4) テンソルの。これは 4x4 行列を含むバッチと考えることができます。

軸 1 に沿った合計
  1. x.sum(axis=1, keepdim=True)

次元を維持しながら、軸 1 (つまり 2 番目の次元 4 の方向) に沿って合計します。

tensor([[[24, 28, 32, 36]]])
  • 1

形はこうなります (1, 1, 4)

  1. x.sum(axis=1, keepdim=False)

軸 1 に沿って合計し、次元は保持されません。

tensor([[24, 28, 32, 36]])
  • 1

形はこうなります (1, 4)

軸 0 に沿った合計
  1. x.sum(axis=0, keepdim=True)

次元を維持しながら、軸 0 (つまり、最初の次元 1 の方向) に沿って合計します。

tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]]])
  • 1
  • 2
  • 3
  • 4

元のテンソルには軸 0 上の要素が 1 つだけあるため、結果は元のテンソルと同じ形状になります。 (1, 4, 4)

  1. x.sum(axis=0, keepdim=False)

軸 0 に沿った合計。次元は保持されません。

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
  • 1
  • 2
  • 3
  • 4

形はこうなります (4, 4)

軸 2 に沿った合計
  1. x.sum(axis=2, keepdim=True)

次元を維持しながら、軸 2 (つまり、3 番目の次元、4 の方向) に沿って合計します。

tensor([[[ 6],
         [22],
         [38],
         [54]]])
  • 1
  • 2
  • 3
  • 4

形はこうなります (1, 4, 1)

キーポイント
  • keepdim=True 合計された次元は保持され、結果の次元数は変更されませんが、合計された次元のサイズは 1 になります。
  • keepdim=False 合計された次元が削除され、結果の次元数が 1 減ります。

2番目の例

import torch
x = torch.arange(16).reshape(4, 4)
print(x)
print('--------')
a = x.sum(axis=1, keepdim=True)
a2 = x.sum(axis=1, keepdim=False)
print(a)
print(a2)
print(x/a)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
初期テンソル
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
  • 1
  • 2
  • 3
  • 4

これは、 (4, 4) テンソルの。

軸 1 に沿った合計
  1. x.sum(axis=1, keepdim=True)

次元を維持しながら、軸 1 (つまり 2 番目の次元 4 の方向) に沿って合計します。

tensor([[ 6],
        [22],
        [38],
        [54]])
  • 1
  • 2
  • 3
  • 4

形はこうなります (4, 1)

  1. x.sum(axis=1, keepdim=False)

軸 1 に沿って合計し、次元は保持されません。

tensor([ 6, 22, 38, 54])
  • 1

形はこうなります (4,)

行の合計で要素を除算
  1. x / a
tensor([[0.0000, 0.1667, 0.3333, 0.5000],
        [0.1818, 0.2273, 0.2727, 0.3182],
        [0.2105, 0.2368, 0.2632, 0.2895],
        [0.2222, 0.2407, 0.2593, 0.2778]])
  • 1
  • 2
  • 3
  • 4

これは、各要素の合計を対応する行で割ったもので、次のようになります。

tensor([[ 0/6,  1/6,  2/6,  3/6],
        [ 4/22,  5/22,  6/22,  7/22],
        [ 8/38,  9/38, 10/38, 11/38],
        [12/54, 13/54, 14/54, 15/54]])
  • 1
  • 2
  • 3
  • 4

軸と寸法の変更の概要

  • 軸=0: 最初の次元 (行) に沿って操作し、他の次元の合計は合計後に残ります。
  • 軸=1: 2 番目の次元 (列) に沿って動作し、合計後に 1 番目と 3 番目の次元の合計が残ります。
  • 軸=2: 3 番目の次元 (深さ) に沿って動作し、最初の 2 つの次元の合計は合計後に残ります。

使用 keepdim=True 次元を維持する場合、次元の合計は 1 になります。使用keepdim=False の場合、合計された寸法が削除されます。

疑い:

reshape(1, 4, 4) ではなく行が列になるのはなぜですか? reshape(4, 4) の場合にのみ行が行として表示されます。

テンソル構造

まず基本的な概念を確認しましょう。

  • 2D テンソル (行列): 行と列があります。
  • 3Dテンソル: 複数の 2 次元行列で構成され、「深さ」次元を持つ行列のスタックとみなすことができます。

例 1: 2 次元テンソル (4, 4)

import torch
x = torch.arange(16).reshape(4, 4)
print(x)
  • 1
  • 2
  • 3

出力:

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
  • 1
  • 2
  • 3
  • 4

このテンソルの形状は、 (4, 4)、 は 4x4 行列を表します。

  • わかりました水平です:

    • 行0: [ 0, 1, 2, 3]
    • ライン1: [ 4, 5, 6, 7]
    • 2行目: [ 8, 9, 10, 11]
    • 3行目: [12, 13, 14, 15]
  • リスト垂直です:

    • 列 0: [ 0, 4, 8, 12]
    • 列 1: [ 1, 5, 9, 13]
    • 列 2: [ 2, 6, 10, 14]
    • 列 3: [ 3, 7, 11, 15]

例 2: 3 次元テンソル (1, 4, 4)

x = torch.arange(16).reshape(1, 4, 4)
print(x)
  • 1
  • 2

出力:

tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]]])
  • 1
  • 2
  • 3
  • 4

このテンソルの形状は、 (1, 4, 4)、 1x4x4 の 3 次元テンソルを表します。

  • 最初の次元は 1、バッチサイズを示します。
  • 2 番目の次元は 4、行数 (行列あたりの行数) を表します。
  • 3 番目の次元は、 4、列数(各行列の列)を表します。

軸に沿った合計の説明

軸 1 に沿った合計 (2 番目の次元)
  1. 2次元テンソルの場合 (4, 4)
a = x.sum(axis=1, keepdim=True)
print(a)
  • 1
  • 2

出力:

tensor([[ 6],
        [22],
        [38],
        [54]])
  • 1
  • 2
  • 3
  • 4
  • 軸 1 は行の方向であり、各行の要素が合計されます。
    • [0, 1, 2, 3] => 0+1+2+3 = 6
    • [4, 5, 6, 7] => 4+5+6+7 = 22
    • [8, 9, 10, 11] => 8+9+10+11 = 38
    • [12, 13, 14, 15] => 12+13+14+15 = 54
  1. 3次元テンソルの場合 (1, 4, 4)
a = x.sum(axis=1, keepdim=True)
print(a)
  • 1
  • 2

出力:

tensor([[[24, 28, 32, 36]]])
  • 1
  • 軸 1 は最初の行列の行方向です。各行の要素を合計します。
    • [0, 1, 2, 3] + [4, 5, 6, 7] + [8, 9, 10, 11] + [12, 13, 14, 15]
    • 列ごとの合計: 24 = 0+4+8+12、28 = 1+5+9+13、32 = 2+6+10+14、36 = 3+7+11+15

「行が列になる」ように見えるのはなぜですか

存在する (1, 4, 4) 3次元テンソルでは最初の次元がバッチサイズを表すため、動作時には各4x4行列が依然として2次元的に処理されるようです。ただし、バッチ次元が追加されるため、合計演算では 2 次元テンソルとは異なる動作をします。

具体的には:

  • 軸 1 (2 番目の次元) に沿って合計する場合、各行列の行を合計します。
  • 軸 0 (最初の次元) に沿って合計する場合、バッチの次元を合計します。

要約する

  • 2D テンソル: 行と列の概念は直感的です。
  • 3Dテンソル: バッチ ディメンションの導入後、行と列の操作は異なって見えますが、実際には依然として各 2 次元行列に対して操作されます。
  • 軸に沿って合計する場合、合計の次元を理解することが重要です。