
【pytorch】CNN

June 5, 2022    Author: Guofei

Category: 0x26_torch    Article number: 262


Copyright notice: the author of this article is Guo Fei. Feel free to repost, but please cite the original link and notify the author.
Original link: https://www.guofei.site/2022/06/05/torch_cnn.html


Task categories:

  • Object detection: detect objects and return bounding boxes
  • Image classification and image retrieval
  • Super-resolution reconstruction

A simple MNIST model

import torch.utils.data
from torchvision import transforms, datasets
from torch import nn

input_size = 28
num_classes = 10
num_epochs = 3
batch_size = 64

train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor(), download=True)

# build batches
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)
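
# A quick sanity check (an added sketch, not part of the original post): pull one
# batch from the loader and confirm the expected shapes, i.e. images of
# [batch_size, 1, 28, 28] and labels of [batch_size]. Variable names are illustrative.
sample_images, sample_labels = next(iter(train_loader))
print(sample_images.shape)  # torch.Size([64, 1, 28, 28])
print(sample_labels.shape)  # torch.Size([64])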


# %% build the model

class MyCNN(nn.Module):
    def __init__(self):
        super(MyCNN, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1  # single input channel (grayscale image)
                      , out_channels=16
                      , kernel_size=5
                      , stride=1
                      , padding=2  # to keep the output size equal to the input size, use padding = (kernel_size-1)/2 when stride=1
                      )
            , nn.ReLU()
            , nn.MaxPool2d(kernel_size=2)
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2)
            , nn.ReLU()
            , nn.MaxPool2d(2)
        )

        # final fully connected layer
        self.out = nn.Linear(32 * 7 * 7, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)  # flatten to shape (batch_size, -1)
        output = self.out(x)
        return output
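

# Shape-check sketch (added here, not from the original post): trace the feature-map
# sizes with a dummy input. With stride=1 and padding=2 each conv keeps the 28x28
# spatial size, and each MaxPool2d(2) halves it: 28 -> 14 -> 7, so the flattened
# feature has 32 * 7 * 7 = 1568 values, matching the nn.Linear input above.
_probe = MyCNN()
_dummy = torch.zeros(1, 1, 28, 28)
print(_probe.conv2(_probe.conv1(_dummy)).shape)  # torch.Size([1, 32, 7, 7])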


# %% training
my_cnn = MyCNN()
criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = torch.optim.Adam(my_cnn.parameters())


def accuracy(predictions, labels):
    pred = torch.max(predictions.data, 1)[1]
    rights = pred.eq(labels.data.view_as(pred)).sum()
    return rights, len(labels)
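

# Illustrative usage of accuracy() on toy tensors (an added example, not from the
# original post): two of the three argmax predictions match the labels, so the
# helper returns (tensor(2), 3).
_logits = torch.tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
_labels = torch.tensor([1, 0, 0])
print(accuracy(_logits, _labels))  # (tensor(2), 3)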


for epoch in range(num_epochs):
    # collect the results of the current epoch
    train_rights = []
    for batch_idx, (data, target) in enumerate(train_loader):

        my_cnn.train()  # switch to training mode so layers such as Dropout take effect
        output = my_cnn(data)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        right = accuracy(output, target)
        train_rights.append(right)

        if batch_idx % 100 == 0:
            my_cnn.eval()  # switch to evaluation mode
            test_rights = []
            for (data, target) in test_loader:
                output = my_cnn(data)
                right = accuracy(output, target)
                test_rights.append(right)

            # compute training and test accuracy
            train_r = (sum(tup[0] for tup in train_rights), sum(tup[1] for tup in train_rights))
            test_r = (sum(tup[0] for tup in test_rights), sum(tup[1] for tup in test_rights))

            print('epoch {} [{}/{} ({:.0f}%)] loss = {:.6f} train_acc={:.4f} test_acc={:.4f}'
                  .format(epoch, batch_idx * batch_size, len(train_loader.dataset), 100 * batch_idx / len(train_loader),
                          loss.item(),
                          train_r[0].numpy() / train_r[1],
                          test_r[0].numpy() / test_r[1]
                          ))
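

# Optional follow-up sketch (added, not in the original post; the file name
# 'mnist_cnn.pt' is an illustrative choice): run one full pass over the test set
# under torch.no_grad() to get a final accuracy, then save the trained weights.
my_cnn.eval()
correct, total = 0, 0
with torch.no_grad():
    for data, target in test_loader:
        pred = my_cnn(data).argmax(dim=1)
        correct += (pred == target).sum().item()
        total += len(target)
print('final test accuracy: {:.4f}'.format(correct / total))
torch.save(my_cnn.state_dict(), 'mnist_cnn.pt')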

