기술나눔

VGG16은 이미지 분류의 pytorch 구현을 구현하고 단계를 자세히 설명합니다.

2024-07-12

한어Русский языкEnglishFrançaisIndonesianSanskrit日本語DeutschPortuguêsΕλληνικάespañolItalianoSuomalainenLatina

VGG16은 이미지 분류를 구현합니다.

여기서는 CIFAR 데이터 세트를 분류하기 위해 VGG-16 네트워크를 구현합니다.

VGG16 네트워크 소개

머리말

《대규모 이미지 인식을 위한 매우 깊은 합성 신경망》

ICLR 2015

VGG옥스퍼드 출신이에요V실제G기하측량G 그룹이 제안한 그룹입니다(VGG라는 이름의 유래를 알 수 있을 것입니다). 이 네트워크는 ILSVRC 2014의 관련 작업입니다. 주요 작업은 네트워크의 깊이를 늘리는 것이 네트워크의 최종 성능에 어느 정도 영향을 미칠 수 있음을 증명하는 것입니다. VGG에는 VGG16과 VGG19라는 두 가지 구조가 있습니다. 둘 사이에는 본질적인 차이가 없지만 네트워크 깊이가 다릅니다.

VGG 원리

AlexNet에 비해 VGG16의 개선점은 다음과 같습니다.여러 개의 연속적인 3x3 컨볼루션 커널을 사용하여 AlexNet(11x11, 7x7, 5x5)의 더 큰 컨볼루션 커널을 대체합니다. . 주어진 수용 필드(출력에 대한 입력 이미지의 로컬 크기)의 경우 여러 개의 비선형 레이어가 네트워크 깊이를 증가시켜 더 복잡한 학습 모드를 보장할 수 있기 때문에 누적된 작은 컨볼루션 커널을 사용하는 것이 더 좋습니다. 비용은 상대적으로 적습니다(매개변수 수가 적음).

간단히 말하면, VGG에서는 7x7 컨볼루션 커널을 대체하기 위해 3x3 컨볼루션 커널을 사용하고, 5*5 컨볼루션 커널을 대체하기 위해 3x3 컨볼루션 커널을 2개 사용하는 것이 주요 목적입니다. 수용 필드가 향상되면 네트워크의 깊이가 향상되고 신경망의 효과가 어느 정도 향상됩니다.

예를 들어, 스트라이드가 1인 3개의 3x3 컨볼루션 커널의 레이어별 중첩은 크기 7의 수용 필드로 간주될 수 있습니다(실제로 이는 3개의 3x3 연속 컨볼루션이 7x7 컨볼루션과 동일함을 의미합니다). 전체 매개변수는 3x(9xC^2)입니다. 7x7 컨볼루션 커널을 직접 사용하는 경우 전체 매개변수 수는 49xC^2입니다. 여기서 C는 입력 및 출력 채널 수를 나타냅니다.당연히 27xC2 49xC 미만2, 즉 매개변수가 감소하고 3x3 컨볼루션 커널이 이미지 속성을 더 잘 유지하는 데 도움이 됩니다.

다음은 5*5 컨볼루션 커널 대신 두 개의 3x3 컨볼루션 커널을 사용할 수 있는 이유에 대한 설명입니다.

5x5 컨볼루션은 5x5 영역에서 미끄러지는 작은 완전 연결 네트워크로 간주됩니다. 먼저 3x3 컨볼루션 필터를 사용하여 완전 연결 레이어를 사용하여 3x3 컨볼루션 출력을 연결할 수도 있습니다. 3x3 컨볼루션 레이어로 볼 수 있습니다. 이런 방식으로 우리는 하나의 5x5 컨볼루션 대신 두 개의 3x3 컨볼루션을 계단식으로 배열(중첩)할 수 있습니다.

자세한 내용은 아래 그림에 나와 있습니다.

여기에 이미지 설명을 삽입하세요.

VGG 네트워크 구조

여기에 이미지 설명을 삽입하세요.

다음은 VGG 네트워크의 구조입니다(VGG16 및 VGG19가 모두 존재함).

여기에 이미지 설명을 삽입하세요.

GG 네트워크 구조

VGG16에는 위 그림의 D열에 표시된 대로 16개의 숨겨진 레이어(13개의 컨벌루션 레이어와 3개의 완전 연결 레이어)가 포함되어 있습니다.

VGG19에는 위 그림의 E열에 표시된 대로 19개의 숨겨진 레이어(16개의 컨벌루션 레이어와 3개의 완전 연결 레이어)가 포함되어 있습니다.

VGG 네트워크의 구조는 처음부터 끝까지 3x3 컨볼루션과 2x2 최대 풀링을 사용하여 매우 일관됩니다.

VGG의 장점

VGGNet의 구조는 매우 간단합니다. 전체 네트워크는 동일한 컨볼루션 커널 크기(3x3)와 최대 풀링 크기(2x2)를 사용합니다.

여러 개의 작은 필터(3x3) 컨벌루션 레이어 조합이 하나의 큰 필터(5x5 또는 7x7) 컨벌루션 레이어보다 낫습니다.

네트워크 구조를 지속적으로 심화시켜 성능을 향상시킬 수 있음이 검증되었습니다.

VGG의 단점

VGG는 더 많은 컴퓨팅 리소스를 소비하고 더 많은 매개변수를 사용하므로(3x3 컨볼루션의 포트가 아님) 더 많은 메모리 사용량(140M)이 발생합니다.

데이터 세트 처리

데이터세트 소개

CIFAR(Canadian Institute For Advanced Research) 데이터 세트는 컴퓨터 비전 분야에서 널리 사용되는 작은 이미지 데이터 세트로, 특히 이미지 인식 및 분류와 같은 작업에서 기계 학습 및 컴퓨터 비전 알고리즘을 훈련하는 데 주로 사용됩니다. CIFAR 데이터 세트는 CIFAR-10과 CIFAR-100의 두 가지 주요 부분으로 구성됩니다.

CIFAR-10은 60,000개의 32x32 컬러 이미지를 포함하는 데이터세트로, 각 카테고리에는 6,000개의 이미지가 포함된 10개의 카테고리로 나뉩니다. 10개 카테고리는 비행기, 자동차, 새, 고양이, 사슴, 개, 개구리, 말, 보트, 트럭입니다. 데이터세트에서는 50,000개의 이미지가 훈련에 사용되고 10,000개의 이미지가 테스트에 사용됩니다. CIFAR-10 데이터 세트는 적당한 크기와 풍부한 클래스 정보로 인해 컴퓨터 비전 분야의 연구 및 교육에서 매우 인기 있는 데이터 세트 중 하나가 되었습니다.

데이터 세트 특성
  • 중간 사이즈: CIFAR 데이터 세트(32x32)의 작은 이미지 크기로 인해 새로운 컴퓨터 비전 알고리즘을 빠르게 훈련하고 테스트하는 데 이상적입니다.
  • 다양한 카테고리: CIFAR-10은 기본적인 이미지 분류 작업을 제공하는 반면, CIFAR-100은 알고리즘의 세분화된 분류 기능에 더욱 도전합니다.
  • 광대하게 사용 된: 이러한 특성으로 인해 CIFAR 데이터 세트는 컴퓨터 비전, 기계 학습, 딥 러닝 및 기타 분야의 연구 및 교육에 널리 사용됩니다.
사용되는 장면

CIFAR 데이터 세트는 일반적으로 이미지 분류, 객체 인식, CNN(컨볼루션 신경망) 훈련 및 테스트와 같은 작업에 사용됩니다. 적당한 크기와 풍부한 카테고리 정보로 인해 이미지 인식 알고리즘을 탐구하는 초보자와 연구자에게 이상적입니다. 또한 많은 컴퓨터 비전 및 기계 학습 대회에서는 참가자의 알고리즘 성능을 평가하기 위한 벤치마크로 CIFAR 데이터 세트를 사용합니다.

데이터 세트를 준비하기 위해 이미 다운로드했습니다. 작동하지 않으면 공식 웹 사이트에서 다운로드하거나 직접 제공해드립니다.

데이터 세트가 필요한 경우 이메일로 문의하십시오: [email protected]

내 데이터 세트는 원래 토치비전에서 다운로드한 데이터를 통해 생성되었습니다. 지금은 데이터 세트의 정의와 DataLoader의 로딩을 단계별로 구현하고 이 프로세스를 이해하고 싶습니다. 데이터 세트 처리 과정을 통해 딥 러닝에 대해 더욱 깊이 있게 이해할 수 있습니다.

데이터 세트 스타일은 다음과 같습니다.

여기에 이미지 설명을 삽입하세요.

데이터세트의 모든 라벨을 구문 분석합니다.

데이터세트의 라벨 카테고리는.meta파일이 저장되었으므로 구문 분석이 필요합니다. .meta 모든 태그 데이터를 읽는 파일입니다. 파싱 ​​코드는 다음과 같습니다.

# 首先了解所有的标签,TODO 可以详细了解一下这个解包的过程
import pickle


def unpickle(file):
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict


meta_data = unpickle('./dataset_method_1/cifar-10-batches-py/batches.meta')
label_names = meta_data[b'label_names']
# 将字节标签转换为字符串
label_names = [label.decode('utf-8') for label in label_names]
print(label_names)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

분석 결과는 다음과 같습니다.

['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
  • 1

간단한 데이터 테스트를 위해 단일 데이터 배치 로드

데이터 세트가 다운로드되었으므로 파일의 내용을 읽어야 합니다. 파일이 바이너리 파일이므로 읽으려면 바이너리 읽기 모드를 사용해야 합니다.

읽는 코드는 다음과 같습니다.

# 载入单个批次的数据
import numpy as np


def load_data_batch(file):
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
        X = dict[b'data']
        Y = dict[b'labels']
        X = X.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1)  # reshape and transpose to (10000, 32, 32, 3)
        Y = np.array(Y)
    return X, Y


# 加载第一个数据批次
data_batch_1 = './dataset_method_1/cifar-10-batches-py/data_batch_1'
X1, Y1 = load_data_batch(data_batch_1)

print(f'数据形状: {X1.shape}, 标签形状: {Y1.shape}')

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20

결과:

数据形状: (10000, 32, 32, 3), 标签形状: (10000,)
  • 1

모든 데이터 로드

위의 테스트를 마친 후 데이터를 로드하는 방법을 알았습니다. 이제 모든 데이터를 로드해 보겠습니다.

훈련 세트 로드:

# 整合所有批次的数据
def load_all_data_batches(batch_files):
    X_list, Y_list = [], []
    for file in batch_files:
        X, Y = load_data_batch(file)
        X_list.append(X)
        Y_list.append(Y)
    X_all = np.concatenate(X_list)
    Y_all = np.concatenate(Y_list)
    return X_all, Y_all


batch_files = [
    './dataset_method_1/cifar-10-batches-py/data_batch_1',
    './dataset_method_1/cifar-10-batches-py/data_batch_2',
    './dataset_method_1/cifar-10-batches-py/data_batch_3',
    './dataset_method_1/cifar-10-batches-py/data_batch_4',
    './dataset_method_1/cifar-10-batches-py/data_batch_5'
]

X_train, Y_train = load_all_data_batches(batch_files)
print(f'训练数据形状: {X_train.shape}, 训练标签形状: {Y_train.shape}')
Y_train = Y_train.astype(np.int64)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23

산출:

训练数据形状: (50000, 32, 32, 3), 训练标签形状: (50000,)
  • 1

테스트 세트 로드:

test_batch = './dataset_method_1/cifar-10-batches-py/test_batch'
X_test, Y_test = load_data_batch(test_batch)
Y_test = Y_test.astype(np.int64)
print(f'测试数据形状: {X_test.shape}, 测试标签形状: {Y_test.shape}')

  • 1
  • 2
  • 3
  • 4
  • 5

산출:

测试数据形状: (10000, 32, 32, 3), 测试标签形状: (10000,)
  • 1

데이터 세트의 하위 클래스 정의

Dataset 클래스의 하위 클래스를 정의하는 것은 일괄 훈련을 위해 Dataloader의 후속 로드를 용이하게 하기 위한 것입니다.

Dataset의 하위 클래스가 구현해야 하는 세 가지 메서드가 있습니다.

  • __init__()클래스 생성자
  • __len__()데이터 세트의 길이를 반환합니다.
  • __getitem__()데이터 세트에서 데이터 조각 가져오기

여기서 내 구현은 다음과 같습니다.

from torch.utils.data import DataLoader, Dataset


# 定义 Pytorch 的数据集 
class CIFARDataset(Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21

데이터세트를 Dataloader로 로드

  1. 데이터를 향상하기 위한 변환을 정의합니다. 먼저 훈련 세트를 4px로 확장하고, 정규화하고, 수평으로 뒤집고, 회색조로 처리한 다음, 마지막으로 32 * 32의 원래 픽셀로 반환해야 합니다.
transform_train = transforms.Compose(
    [transforms.Pad(4),
     transforms.ToTensor(),
     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
     transforms.RandomHorizontalFlip(),
     transforms.RandomGrayscale(),
     transforms.RandomCrop(32, padding=4),
     ])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  1. 여기에는 이미지 처리가 포함되고 바이너리 파일에서 읽는 데이터는 numpy 데이터이므로 이미지 처리를 용이하게 하려면 numpy 배열을 이미지 데이터로 변환해야 합니다. 다음과 같이 처리합니다.
# 把数据集变成 Image 的数组,不然好像不能进行数据的增强
# 改变训练数据
from PIL import Image
def get_PIL_Images(origin_data):
    datas = []
    for i in range(len(origin_data)):
        data = Image.fromarray(origin_data[i])
        datas.append(data)
    return datas
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  1. 학습된 데이터로더 가져오기
train_data = get_PIL_Images(X_train)
train_loader = DataLoader(CIFARDataset(train_data, Y_train, transform_train), batch_size=24, shuffle=True)
  • 1
  • 2
  1. 테스트 데이터로더 테스트 세트를 얻는 데는 많은 처리가 필요하지 않습니다. 코드는 여기에 직접 제공됩니다.
# 测试集的预处理
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]
)
test_loader = DataLoader(CIFARDataset(X_test, Y_test, transform_test), batch_size=24, shuffle=False)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

네트워크 정의

위에서 언급한 VGG16 네트워크를 기반으로 Pytorch 프레임워크를 구현합니다.

주로 나누어진다:

  • 컨볼루션 레이어
  • 완전 연결 레이어
  • 분류층

구현은 다음과 같습니다.

class VGG16(nn.Module):
    def __init__(self):
        super(VGG16, self).__init__()
        # 卷积层,这里进行卷积
        self.convolusion = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=3, padding=1), # 设置为padding=1 卷积完后,数据大小不会变
            nn.BatchNorm2d(96),
            nn.ReLU(inplace=True),
            nn.Conv2d(96, 96, kernel_size=3, padding=1),
            nn.BatchNorm2d(96),
            nn.ReLU(inplace=True),
            nn.Conv2d(96, 96, kernel_size=3, padding=1),
            nn.BatchNorm2d(96),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(96, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.AvgPool2d(kernel_size=1, stride=1)
        )
        # 全连接层
        self.dense = nn.Sequential(
            nn.Linear(512, 4096), # 32*32 的图像大小经过 5 次最大化池化后就只有 1*1 了,所以就是 512 个通道的数据输入全连接层
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
        )
        # 输出层
        self.classifier = nn.Linear(4096, 10)

    def forward(self, x):
        out = self.convolusion(x)
        out = out.view(out.size(0), -1)
        out = self.dense(out)
        out = self.classifier(out)
        return out
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75

훈련과 테스트

학습 및 테스트를 위해서는 모델을 인스턴스화한 다음 최적화 함수, 손실 함수 및 손실률을 정의한 다음 학습 및 테스트를 수행하면 됩니다.

코드는 아래와 같이 표시됩니다.

초매개변수 정의:

# 定义模型进行训练
model = VGG16()
# model.load_state_dict(torch.load('./my-VGG16.pth'))
optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=5e-3)
loss_func = nn.CrossEntropyLoss()
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.4, last_epoch=-1)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

테스트 기능:

def test():
    model.eval()
    correct = 0  # 预测正确的图片数
    total = 0  # 总共的图片数
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            images = images.to(device)
            outputs = model(images).to(device)
            outputs = outputs.cpu()
            outputarr = outputs.numpy()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum()
    accuracy = 100 * correct / total
    accuracy_rate.append(accuracy)
    print(f'准确率为:{accuracy}%'.format(accuracy))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

훈련 시대:

# 定义训练步骤
total_times = 40
total = 0
accuracy_rate = []
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

for epoch in range(total_times):
    model.train()
    model.to(device)
    running_loss = 0.0
    total_correct = 0
    total_trainset = 0
    print("epoch: ",epoch)
    for i, (data,labels) in enumerate(train_loader):
        data = data.to(device)
        outputs = model(data).to(device)
        labels = labels.to(device)
        loss = loss_func(outputs,labels).to(device)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        _,pred = outputs.max(1)
        correct = (pred == labels).sum().item()
        total_correct += correct
        total_trainset += data.shape[0]
        if i % 100 == 0 and i > 0:
            print(f"正在进行第{i}次训练, running_loss={running_loss}".format(i, running_loss))
            running_loss = 0.0
    test()
    scheduler.step()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31

훈련된 모델을 저장합니다.

torch.save(model.state_dict(), './my-VGG16.pth')
accuracy_rate = np.array(accuracy_rate)
times = np.linspace(1, total_times, total_times)
plt.xlabel('times')
plt.ylabel('accuracy rate')
plt.plot(times, accuracy_rate)
plt.show()
print(accuracy_rate)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

시험

  1. 모델 정의
model_my_vgg = VGG16()
  • 1
  1. 훈련된 모델 데이터 추가
model_my_vgg.load_state_dict(torch.load('./my-VGG16-best.pth',map_location='cpu'))
  • 1
  1. 내가 직접 찾은 인증이미지 처리 중
from torchvision import transforms
from PIL import Image

# 定义图像预处理步骤
preprocess = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    # transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
])

def load_image(image_path):
    image = Image.open(image_path)
    image = preprocess(image)
    image = image.unsqueeze(0)  # 添加批次维度
    return image

image_data = load_image('./plane2.jpg')
print(image_data.shape)
output = model_my_vgg(image_data)
verify_data = X1[9]
verify_label = Y1[9]
output_verify = model_my_vgg(transform_test(verify_data).unsqueeze(0))
print(output)
print(output_verify)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24

산출:

torch.Size([1, 3, 32, 32])
tensor([[ 1.5990, -0.5269,  0.7254,  0.3432, -0.5036, -0.3267, -0.5302, -0.9417,
          0.4186, -0.1213]], grad_fn=<AddmmBackward0>)
tensor([[-0.6541, -2.0759,  0.6308,  1.9791,  0.8525,  1.2313,  0.1856,  0.3243,
         -1.3374, -1.0211]], grad_fn=<AddmmBackward0>)
  • 1
  • 2
  • 3
  • 4
  • 5
  1. 결과 인쇄
print(label_names[torch.argmax(output,dim=1,keepdim=False)])
print(label_names[verify_label])
print("pred:",label_names[torch.argmax(output_verify,dim=1,keepdim=False)])
  • 1
  • 2
  • 3
airplane
cat
pred: cat
  • 1
  • 2
  • 3

여기에 이미지 설명을 삽입하세요.

말 확인

여기에 이미지 설명을 삽입하세요.

검증견

여기에 이미지 설명을 삽입하세요.