技術共有

[PyTorch] torch.fmod は切り捨て正規分布を使用してニューラル ネットワークの重みを初期化します

2024-07-11

한어Русский языкEnglishFrançaisIndonesianSanskrit日本語DeutschPortuguêsΕλληνικάespañolItalianoSuomalainenLatina

このコード スニペットは、PyTorch を使用して、切り捨てられた正規分布を使用してニューラル ネットワークの重みを初期化する方法を示しています。切り捨てられた正規分布とは、生成された値が極端な値を避けるために範囲内で切り捨てられることを意味します。ここで使われている torch.fmod この効果を達成するための回避策として。

詳しい説明

1. 切り捨てられた正規分布

切り捨て正規分布は、生成された値が特定の範囲内に収まるように正規分布を修正したものです。具体的には、torch.fmod この関数は、入力テンソルを 2 で割った余りを返します (つまり、結果の値が -2 と 2 の間になるように)。

2. 重みの初期化

コードでは、4 つの重みテンソルが異なる標準偏差に従って計算されます (init_sd_first, init_sd_middle, init_sd_last ) は切り捨てられた正規分布から生成されます。具体的な寸法は次のとおりです。

  • 最初の層の重みテンソルの形状は次のとおりです。 (x_dim, width n_double)
  • 中間層の 2 つの重みテンソルの形状は次のとおりです。 (width, width n_double)
  • 最後の層の重みテンソルの形状は次のとおりです。 (width, 1)

これらの重みテンソルは次のように生成されます。

initial_weights = [
    torch.fmod(torch.normal(0, init_sd_first, size=(x_dim, width   n_double)), 2),
    torch.fmod(torch.normal(0, init_sd_middle, size=(width, width   n_double)), 2),
    torch.fmod(torch.normal(0, init_sd_middle, size=(width, width   n_double)), 2),
    torch.fmod(torch.normal(0, init_sd_last, size=(width, 1)), 2)
]