प्रौद्योगिकी साझेदारी

आदर्श छंटाई ज्ञानबिन्दून् संकलनम्

2024-07-12

한어Русский языкEnglishFrançaisIndonesianSanskrit日本語DeutschPortuguêsΕλληνικάespañolItalianoSuomalainenLatina

आदर्श छंटाई ज्ञानबिन्दून् संकलनम्

छंटनी इतिगहनशिक्षणप्रतिरूपम्अनुकूलनार्थं द्वौ सामान्यौ तकनीकौ मॉडलजटिलतां न्यूनीकर्तुं अनुमानवेगं च सुधारयितुम् उपयुज्यते, संसाधन-संकुचितवातावरणानां कृते च उपयुक्तौ स्तः

छंटनी

छंटाई मॉडल् मध्ये अमहत्त्वपूर्णान् अथवा अनावश्यकमापदण्डान् दूरीकृत्य मॉडलस्य आकारं गणनाप्रयत्नं च न्यूनीकर्तुं पद्धतिः अस्ति । सामान्यतया छटाकरणं निम्नलिखितप्रकारेषु विभक्तं भवति ।

1. भार छंटनी

भारस्य छंटाई भारमात्रिकायां शून्यसमीपस्थानि तत्त्वानि निष्कास्य मॉडलस्य मापदण्डानां संख्यां न्यूनीकरोति । सामान्यविधयः सन्ति : १.

  • असंरचित छंटनी: भारमात्रिकायां लघुभारं एकैकं निष्कासयन्तु।
  • संरचित छंटनी: विशिष्टसंरचनाद्वारा (यथा सम्पूर्णपङ्क्तयः सम्पूर्णस्तम्भाः वा) भारं निष्कासयन्तु ।

उदाहरण:

import torch

# 假设有一个全连接层
fc = torch.nn.Linear(100, 100)

# 获取权重矩阵
weights = fc.weight.data.abs()

# 设定剪枝阈值
threshold = 0.01

# 应用剪枝
mask = weights > threshold
fc.weight.data *= mask
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

2. चैनल छंटनी

मुख्यतया चैनलस्य छंटाई इत्यस्य उपयोगः भवतिconvolutional तंत्रिका जाल , कन्वोल्यूशनल् लेयर इत्यस्मिन् अमहत्त्वपूर्णचैनलम् अपसारयित्वा गणनायाः परिमाणं न्यूनीकरोति । सामान्यविधयः सन्ति : १.

  • महत्त्वस्य आधारेण स्कोरः: प्रत्येकस्य चैनलस्य महत्त्व-अङ्कस्य गणनां कुर्वन्तु तथा च न्यून-अङ्क-युक्तानि चैनलानि निष्कासयन्तु।
  • विरलतायाः आधारेण: विरलनियमनपदानि योजयित्वा प्रशिक्षणप्रक्रियायाः समये केचन चैनलाः स्वाभाविकतया विरलाः भविष्यन्ति ततः छंटनीः भविष्यन्ति।
import torch
import torch.nn as nn

class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x

model = ConvNet()

# 获取卷积层的权重
weights = model.conv1.weight.data.abs()

# 计算每个通道的L1范数
channel_importance = torch.sum(weights, dim=[1, 2, 3])

# 设定剪枝阈值
threshold = torch.topk(channel_importance, k=32, largest=True).values[-1]

# 应用剪枝
mask = channel_importance > threshold
model.conv1.weight.data *= mask.view(-1, 1, 1, 1)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29

3. स्तर छंटनी

लेयर प्रूनिंग् इत्यनेन मॉडलस्य गणनागहनतां न्यूनीकर्तुं सम्पूर्णं जालस्तरं निष्कासितम् । एषः उपायः अधिकं कट्टरपंथी अस्ति, प्रायः Model Architecture Search (NAS) इत्यनेन सह उपयुज्यते ।

import torch.nn as nn

class LayerPrunedNet(nn.Module):
    def __init__(self, use_layer=True):
        super(LayerPrunedNet, self).__init__()
        self.use_layer = use_layer
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
    
    def forward(self, x):
        x = self.conv1(x)
        if self.use_layer:
            x = self.conv2(x)
        return x

# 初始化网络,选择是否使用第二层
model = LayerPrunedNet(use_layer=False)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18