Partage de technologie

VGG16 implémente l'implémentation pytorch de la classification d'images et explique les étapes en détail

2024-07-12

한어Русский языкEnglishFrançaisIndonesianSanskrit日本語DeutschPortuguêsΕλληνικάespañolItalianoSuomalainenLatina

VGG16 implémente la classification des images

Ici, nous implémentons un réseau VGG-16 pour classer l'ensemble de données CIFAR

Présentation du réseau VGG16

Préface

« Réseaux convolutionnels très profonds pour la reconnaissance d'images à grande échelle »

ICLR 2015

VGGÇa vient d'OxfordVisuelgéométrieg Proposé par le groupe du groupe (vous devriez pouvoir voir l'origine du nom VGG). Ce réseau est un travail lié à l'ILSVRC 2014. Le travail principal est de prouver que l'augmentation de la profondeur du réseau peut affecter dans une certaine mesure les performances finales du réseau. VGG a deux structures, à savoir VGG16 et VGG19. Il n'y a pas de différence essentielle entre les deux, mais la profondeur du réseau est différente.

Principe VGG

Une amélioration de VGG16 par rapport à AlexNet est queUtilisez plusieurs noyaux de convolution 3x3 consécutifs pour remplacer les noyaux de convolution plus gros dans AlexNet (11x11, 7x7, 5x5) . Pour un champ récepteur donné (la taille locale de l'image d'entrée par rapport à la sortie), il est préférable d'utiliser de petits noyaux de convolution empilés plutôt que de grands noyaux de convolution, car plusieurs couches non linéaires peuvent augmenter la profondeur du réseau pour garantir un mode d'apprentissage plus complexe. le coût est relativement faible (moins de paramètres).

Pour faire simple, dans VGG, trois noyaux de convolution 3x3 sont utilisés pour remplacer le noyau de convolution 7x7, et deux noyaux de convolution 3x3 sont utilisés pour remplacer le noyau de convolution 5*5. L'objectif principal est d'assurer la même condition. Du champ récepteur, la profondeur du réseau est améliorée et l'effet du réseau neuronal est amélioré dans une certaine mesure.

Par exemple, la superposition couche par couche de trois noyaux de convolution 3x3 avec une foulée de 1 peut être considérée comme un champ récepteur de taille 7 (en fait, cela signifie que trois convolutions continues 3x3 sont équivalentes à une convolution 7x7), et son le nombre total de paramètres est de 3x(9xC^2). Si le noyau de convolution 7x7 est utilisé directement, le nombre total de paramètres est de 49xC^2, où C fait référence au nombre de canaux d'entrée et de sortie.Évidemment, 27xC2 moins de 49xC2, c'est-à-dire que les paramètres sont réduits ; et le noyau de convolution 3x3 permet de mieux conserver les propriétés de l'image.

Voici une explication de la raison pour laquelle deux noyaux de convolution 3x3 peuvent être utilisés au lieu de noyaux de convolution 5*5 :

La convolution 5x5 est considérée comme un petit réseau entièrement connecté glissant dans la zone 5x5. Nous pouvons d'abord convoluer avec un filtre de convolution 3x3, puis utiliser une couche entièrement connectée pour connecter la sortie de convolution 3x3. être considéré comme une couche convolutive 3x3. De cette façon, nous pouvons cascader (superposer) deux convolutions 3x3 au lieu d'une convolution 5x5.

Les détails sont présentés dans la figure ci-dessous :

Insérer la description de l'image ici

Structure du réseau VGG

Insérer la description de l'image ici

Voici la structure du réseau VGG (VGG16 et VGG19 sont présents) :

Insérer la description de l'image ici

Structure du réseau GG

VGG16 contient 16 couches cachées (13 couches convolutives et 3 couches entièrement connectées), comme indiqué dans la colonne D de la figure ci-dessus

VGG19 contient 19 couches cachées (16 couches convolutives et 3 couches entièrement connectées), comme indiqué dans la colonne E de la figure ci-dessus

La structure du réseau VGG est très cohérente, utilisant une convolution 3x3 et un pooling maximum 2x2 du début à la fin.

Avantages du VGG

La structure de VGGNet est très simple. L'ensemble du réseau utilise la même taille de noyau de convolution (3x3) et la même taille de pool maximale (2x2).

La combinaison de plusieurs couches convolutives de petit filtre (3x3) est meilleure qu'une couche convolutive de grand filtre (5x5 ou 7x7) :

Il est vérifié que les performances peuvent être améliorées en approfondissant continuellement la structure du réseau.

Inconvénients du VGG

VGG consomme plus de ressources informatiques et utilise plus de paramètres (ce n'est pas le pot de convolution 3x3), ce qui entraîne une utilisation plus importante de la mémoire (140 Mo).

Traitement des ensembles de données

Introduction à l'ensemble de données

L'ensemble de données CIFAR (Institut canadien de recherches avancées) est un petit ensemble de données d'images largement utilisé dans le domaine de la vision par ordinateur. Il est principalement utilisé pour la formation d'algorithmes d'apprentissage automatique et de vision par ordinateur, notamment dans des tâches telles que la reconnaissance et la classification d'images. L'ensemble de données CIFAR se compose de deux parties principales : CIFAR-10 et CIFAR-100.

CIFAR-10 est un ensemble de données contenant 60 000 images couleur 32x32, divisées en 10 catégories, chaque catégorie contenant 6 000 images. Les 10 catégories sont : les avions, les voitures, les oiseaux, les chats, les cerfs, les chiens, les grenouilles, les chevaux, les bateaux et les camions. Dans l'ensemble de données, 50 000 images sont utilisées pour la formation et 10 000 images sont utilisées pour les tests. L'ensemble de données CIFAR-10 est devenu l'un des ensembles de données les plus populaires en recherche et en enseignement dans le domaine de la vision par ordinateur en raison de sa taille modérée et de la richesse de ses informations sur les classes.

Caractéristiques des ensembles de données
  • taille moyenne: La petite taille des images de l'ensemble de données CIFAR (32 x 32) les rend idéales pour former et tester rapidement de nouveaux algorithmes de vision par ordinateur.
  • Diverses catégories: CIFAR-10 fournit des tâches de classification d'images de base, tandis que CIFAR-100 remet en question les capacités de classification fine de l'algorithme.
  • largement utilisé: En raison de ces caractéristiques, l’ensemble de données du CIFAR est largement utilisé dans la recherche et l’enseignement en vision par ordinateur, en apprentissage automatique, en apprentissage profond et dans d’autres domaines.
scènes à utiliser

L'ensemble de données du CIFAR est couramment utilisé pour des tâches telles que la classification d'images, la reconnaissance d'objets ainsi que la formation et les tests de réseaux neuronaux convolutifs (CNN). En raison de sa taille modérée et de ses riches informations sur les catégories, il est idéal pour les débutants et les chercheurs explorant les algorithmes de reconnaissance d'images. De plus, de nombreux concours de vision par ordinateur et d'apprentissage automatique utilisent également l'ensemble de données du CIFAR comme référence pour évaluer la performance des algorithmes des candidats.

Pour préparer l'ensemble de données, je l'ai déjà téléchargé. Si cela ne fonctionne pas, téléchargez-le simplement depuis le site officiel, ou je vous le donnerai directement.

Si vous avez besoin de l'ensemble de données, veuillez contacter l'e-mail : [email protected]

Mon ensemble de données a été initialement généré via les données téléchargées dans torchvision. Je ne veux pas vraiment faire cela maintenant, je veux implémenter la définition de l'ensemble de données et le chargement du DataLoader étape par étape, comprendre ce processus et comprendre. le processus de traitement des ensembles de données peut vous permettre d'approfondir l'apprentissage en profondeur.

Le style de l'ensemble de données est le suivant :

Insérer la description de l'image ici

Analyser toutes les étiquettes de l'ensemble de données

La catégorie d'étiquette de l'ensemble de données utilise un.metaLe fichier est stocké, nous devons donc l'analyser .meta fichier pour lire toutes les données de balise. Le code d'analyse est le suivant :

# 首先了解所有的标签,TODO 可以详细了解一下这个解包的过程
import pickle


def unpickle(file):
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict


meta_data = unpickle('./dataset_method_1/cifar-10-batches-py/batches.meta')
label_names = meta_data[b'label_names']
# 将字节标签转换为字符串
label_names = [label.decode('utf-8') for label in label_names]
print(label_names)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

Les résultats de l'analyse sont les suivants :

['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
  • 1

Chargez un seul lot de données pour tester simplement les données

Notre ensemble de données a été téléchargé, nous devons donc lire le contenu du fichier puisque le fichier est un fichier binaire, nous devons utiliser le mode de lecture binaire pour le lire.

Le code de lecture est le suivant :

# 载入单个批次的数据
import numpy as np


def load_data_batch(file):
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
        X = dict[b'data']
        Y = dict[b'labels']
        X = X.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1)  # reshape and transpose to (10000, 32, 32, 3)
        Y = np.array(Y)
    return X, Y


# 加载第一个数据批次
data_batch_1 = './dataset_method_1/cifar-10-batches-py/data_batch_1'
X1, Y1 = load_data_batch(data_batch_1)

print(f'数据形状: {X1.shape}, 标签形状: {Y1.shape}')

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20

résultat:

数据形状: (10000, 32, 32, 3), 标签形状: (10000,)
  • 1

Charger toutes les données

Après le test ci-dessus, nous savons comment charger les données. Chargeons maintenant toutes les données.

Chargement de l'ensemble de formation :

# 整合所有批次的数据
def load_all_data_batches(batch_files):
    X_list, Y_list = [], []
    for file in batch_files:
        X, Y = load_data_batch(file)
        X_list.append(X)
        Y_list.append(Y)
    X_all = np.concatenate(X_list)
    Y_all = np.concatenate(Y_list)
    return X_all, Y_all


batch_files = [
    './dataset_method_1/cifar-10-batches-py/data_batch_1',
    './dataset_method_1/cifar-10-batches-py/data_batch_2',
    './dataset_method_1/cifar-10-batches-py/data_batch_3',
    './dataset_method_1/cifar-10-batches-py/data_batch_4',
    './dataset_method_1/cifar-10-batches-py/data_batch_5'
]

X_train, Y_train = load_all_data_batches(batch_files)
print(f'训练数据形状: {X_train.shape}, 训练标签形状: {Y_train.shape}')
Y_train = Y_train.astype(np.int64)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23

Sortir:

训练数据形状: (50000, 32, 32, 3), 训练标签形状: (50000,)
  • 1

Chargement de l'ensemble de test :

test_batch = './dataset_method_1/cifar-10-batches-py/test_batch'
X_test, Y_test = load_data_batch(test_batch)
Y_test = Y_test.astype(np.int64)
print(f'测试数据形状: {X_test.shape}, 测试标签形状: {Y_test.shape}')

  • 1
  • 2
  • 3
  • 4
  • 5

Sortir:

测试数据形状: (10000, 32, 32, 3), 测试标签形状: (10000,)
  • 1

Définir une sous-classe de Dataset

Définir une sous-classe de la classe Dataset vise à faciliter le chargement ultérieur du Dataloader pour la formation par lots.

Il existe trois méthodes que les sous-classes de Dataset doivent implémenter.

  • __init__()constructeur de classe
  • __len__()Renvoie la longueur de l'ensemble de données
  • __getitem__()Récupérer une donnée de l'ensemble de données

Ici, ma mise en œuvre est la suivante :

from torch.utils.data import DataLoader, Dataset


# 定义 Pytorch 的数据集 
class CIFARDataset(Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21

Charger l'ensemble de données en tant que Dataloader

  1. Définissez une transformation pour améliorer les données. Voici d'abord l'ensemble d'entraînement. L'ensemble d'entraînement doit être élargi de 4 px, normalisé, retourné horizontalement, traité en niveaux de gris et enfin revenu aux pixels d'origine de 32 * 32.
transform_train = transforms.Compose(
    [transforms.Pad(4),
     transforms.ToTensor(),
     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
     transforms.RandomHorizontalFlip(),
     transforms.RandomGrayscale(),
     transforms.RandomCrop(32, padding=4),
     ])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  1. Parce que cela implique un traitement d'image et que les données que nous lisons à partir du fichier binaire sont des données numpy, nous devons convertir le tableau numpy en données image pour faciliter le traitement de l'image. Procédez comme suit :
# 把数据集变成 Image 的数组,不然好像不能进行数据的增强
# 改变训练数据
from PIL import Image
def get_PIL_Images(origin_data):
    datas = []
    for i in range(len(origin_data)):
        data = Image.fromarray(origin_data[i])
        datas.append(data)
    return datas
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  1. Obtenez le chargeur de données formé
train_data = get_PIL_Images(X_train)
train_loader = DataLoader(CIFARDataset(train_data, Y_train, transform_train), batch_size=24, shuffle=True)
  • 1
  • 2
  1. L'obtention de l'ensemble de test du chargeur de données de test ne nécessite pas trop de traitement. Le code est donné directement ici.
# 测试集的预处理
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]
)
test_loader = DataLoader(CIFARDataset(X_test, Y_test, transform_test), batch_size=24, shuffle=False)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

Définir le réseau

Nous implémentons le framework Pytorch basé sur le réseau VGG16 mentionné ci-dessus.

principalement divisé:

  • couche de convolution
  • Couche entièrement connectée
  • couche de classification

La mise en œuvre est la suivante :

class VGG16(nn.Module):
    def __init__(self):
        super(VGG16, self).__init__()
        # 卷积层,这里进行卷积
        self.convolusion = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=3, padding=1), # 设置为padding=1 卷积完后,数据大小不会变
            nn.BatchNorm2d(96),
            nn.ReLU(inplace=True),
            nn.Conv2d(96, 96, kernel_size=3, padding=1),
            nn.BatchNorm2d(96),
            nn.ReLU(inplace=True),
            nn.Conv2d(96, 96, kernel_size=3, padding=1),
            nn.BatchNorm2d(96),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(96, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.AvgPool2d(kernel_size=1, stride=1)
        )
        # 全连接层
        self.dense = nn.Sequential(
            nn.Linear(512, 4096), # 32*32 的图像大小经过 5 次最大化池化后就只有 1*1 了,所以就是 512 个通道的数据输入全连接层
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
        )
        # 输出层
        self.classifier = nn.Linear(4096, 10)

    def forward(self, x):
        out = self.convolusion(x)
        out = out.view(out.size(0), -1)
        out = self.dense(out)
        out = self.classifier(out)
        return out
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75

formation et tests

Pour la formation et les tests, il vous suffit d'instancier le modèle, puis de définir la fonction d'optimisation, la fonction de perte et le taux de perte, puis d'effectuer la formation et les tests.

le code s'affiche comme ci-dessous :

Définition de l'hyperparamètre :

# 定义模型进行训练
model = VGG16()
# model.load_state_dict(torch.load('./my-VGG16.pth'))
optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=5e-3)
loss_func = nn.CrossEntropyLoss()
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.4, last_epoch=-1)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

Fonction de test :

def test():
    model.eval()
    correct = 0  # 预测正确的图片数
    total = 0  # 总共的图片数
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            images = images.to(device)
            outputs = model(images).to(device)
            outputs = outputs.cpu()
            outputarr = outputs.numpy()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum()
    accuracy = 100 * correct / total
    accuracy_rate.append(accuracy)
    print(f'准确率为:{accuracy}%'.format(accuracy))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

Époques de formation :

# 定义训练步骤
total_times = 40
total = 0
accuracy_rate = []
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

for epoch in range(total_times):
    model.train()
    model.to(device)
    running_loss = 0.0
    total_correct = 0
    total_trainset = 0
    print("epoch: ",epoch)
    for i, (data,labels) in enumerate(train_loader):
        data = data.to(device)
        outputs = model(data).to(device)
        labels = labels.to(device)
        loss = loss_func(outputs,labels).to(device)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        _,pred = outputs.max(1)
        correct = (pred == labels).sum().item()
        total_correct += correct
        total_trainset += data.shape[0]
        if i % 100 == 0 and i > 0:
            print(f"正在进行第{i}次训练, running_loss={running_loss}".format(i, running_loss))
            running_loss = 0.0
    test()
    scheduler.step()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31

Enregistrez le modèle entraîné :

torch.save(model.state_dict(), './my-VGG16.pth')
accuracy_rate = np.array(accuracy_rate)
times = np.linspace(1, total_times, total_times)
plt.xlabel('times')
plt.ylabel('accuracy rate')
plt.plot(times, accuracy_rate)
plt.show()
print(accuracy_rate)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

test

  1. Définir le modèle
model_my_vgg = VGG16()
  • 1
  1. Ajouter des données de modèle entraîné
model_my_vgg.load_state_dict(torch.load('./my-VGG16-best.pth',map_location='cpu'))
  • 1
  1. Traitement des images de vérification que je me suis retrouvé
from torchvision import transforms
from PIL import Image

# 定义图像预处理步骤
preprocess = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    # transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
])

def load_image(image_path):
    image = Image.open(image_path)
    image = preprocess(image)
    image = image.unsqueeze(0)  # 添加批次维度
    return image

image_data = load_image('./plane2.jpg')
print(image_data.shape)
output = model_my_vgg(image_data)
verify_data = X1[9]
verify_label = Y1[9]
output_verify = model_my_vgg(transform_test(verify_data).unsqueeze(0))
print(output)
print(output_verify)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24

Sortir:

torch.Size([1, 3, 32, 32])
tensor([[ 1.5990, -0.5269,  0.7254,  0.3432, -0.5036, -0.3267, -0.5302, -0.9417,
          0.4186, -0.1213]], grad_fn=<AddmmBackward0>)
tensor([[-0.6541, -2.0759,  0.6308,  1.9791,  0.8525,  1.2313,  0.1856,  0.3243,
         -1.3374, -1.0211]], grad_fn=<AddmmBackward0>)
  • 1
  • 2
  • 3
  • 4
  • 5
  1. Imprimer les résultats
print(label_names[torch.argmax(output,dim=1,keepdim=False)])
print(label_names[verify_label])
print("pred:",label_names[torch.argmax(output_verify,dim=1,keepdim=False)])
  • 1
  • 2
  • 3
airplane
cat
pred: cat
  • 1
  • 2
  • 3

Insérer la description de l'image ici

Vérifier le cheval

Insérer la description de l'image ici

chien de vérification

Insérer la description de l'image ici