
PyTorch Model Saving and Retraining

source link: https://xujinzh.github.io/2023/06/05/pytorch-model-save-and-retrain/


Published 2023-06-05 | Updated 2023-06-05 | technology · python

Saving and retraining a PyTorch model, demonstrated on the MNIST dataset.

Import dependencies

import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
n_epochs = 10           # number of training epochs
batch_size_train = 128  # mini-batch size for training
batch_size_test = 1000  # mini-batch size for evaluation
learning_rate = 0.01    # SGD learning rate
momentum = 0.5          # SGD momentum
log_interval = 20       # print/checkpoint every `log_interval` batches
torch.manual_seed(33)   # fix the RNG seed for reproducibility
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
device
device(type='cuda', index=1)
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ]
)
train_data = torchvision.datasets.MNIST(
    "/workspace/disk1/datasets",
    train=True,
    download=True,
    transform=transform,
)
train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=batch_size_train,
    shuffle=True,
)

test_data = torchvision.datasets.MNIST(
    "/workspace/disk1/datasets/",
    train=False,
    download=True,
    transform=transform,
)
test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=batch_size_test,
    shuffle=True,
)
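The Normalize constants (0.1307, 0.3081) are the mean and standard deviation of the MNIST training pixels after ToTensor() scales them to [0, 1]. A minimal sketch to verify them, added here rather than taken from the original post, reusing the train_data defined above:

raw_pixels = train_data.data.float() / 255.0  # `.data` is the raw uint8 tensor, shape [60000, 28, 28]
print(raw_pixels.mean().item(), raw_pixels.std().item())  # ≈ 0.1307, 0.3081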
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)
# print(example_targets)
print(example_data.shape)
torch.Size([1000, 1, 28, 28])
fig = plt.figure()
for i in range(6):
    plt.subplot(2, 3, i + 1)
    plt.tight_layout()
    plt.imshow(example_data[i][0], cmap="gray", interpolation="none")
    plt.title(f"Ground Truth: {example_targets[i]}")
    plt.xticks([])
    plt.yticks([])
plt.show()

[Figure: six sample test digits with their ground-truth labels]

Building the model and optimizer

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = F.log_softmax(self.fc2(x), dim=1)
        return x
network = Net().to(device)
optimizer = optim.SGD(
    network.parameters(),
    lr=learning_rate,
    momentum=momentum,
)
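The 320 that fc1 expects comes from the two conv/pool stages: a 28×28 input becomes 24×24 after conv1 (kernel 5) and 12×12 after pooling, then 8×8 after conv2 and 4×4 after pooling, leaving 20 × 4 × 4 = 320 features. A quick sanity check, not part of the original post:

with torch.no_grad():
    dummy = torch.zeros(1, 1, 28, 28, device=device)  # one blank MNIST-sized image
    print(network(dummy).shape)  # torch.Size([1, 10]): log-probabilities per digit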

Model training and saving

train_losses = []
train_counter = []
test_losses = []
# x-coordinates for the test-loss points: test() runs once at the end of each epoch
test_counter = [(i + 1) * len(train_loader.dataset) for i in range(n_epochs)]
def train(epoch):
    network.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.to(device)
        target = target.to(device)
        optimizer.zero_grad()
        output = network(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print(
                f"Train Epoch: {epoch} [{str(batch_idx * len(data)).zfill(5)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
            )
            train_losses.append(loss.item())
            train_counter.append(
                (batch_idx * batch_size_train) + ((epoch - 1) * len(train_loader.dataset))
            )

            # Save a checkpoint for the current epoch (overwritten at each
            # logging interval, so the files always hold the newest weights).
            model_checkpoint = f"/workspace/disk1/datasets/models/mnist/model_epoch{epoch}.pth"
            optimizer_checkpoint = (
                f"/workspace/disk1/datasets/models/mnist/optimizer_epoch{epoch}.pth"
            )
            latest_model = "/workspace/disk1/datasets/models/mnist/model_latest.pth"
            latest_optimizer = "/workspace/disk1/datasets/models/mnist/optimizer_latest.pth"

            torch.save(network.state_dict(), model_checkpoint)
            torch.save(optimizer.state_dict(), optimizer_checkpoint)

            # Repoint the "latest" symlinks at the newest checkpoint files.
            if os.path.exists(latest_model):
                os.remove(latest_model)
            if os.path.exists(latest_optimizer):
                os.remove(latest_optimizer)
            os.symlink(model_checkpoint, latest_model)
            os.symlink(optimizer_checkpoint, latest_optimizer)
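Saving the model and optimizer as separate files works, but a common alternative is to bundle everything a restart needs into a single checkpoint dict. A sketch of that pattern, not the author's code; the path checkpoint_latest.pth is hypothetical and unused elsewhere in this post:

checkpoint_path = "/workspace/disk1/datasets/models/mnist/checkpoint_latest.pth"

def save_checkpoint(epoch):
    # Bundle epoch, model weights, and optimizer state into one file.
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": network.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        },
        checkpoint_path,
    )

def load_checkpoint():
    checkpoint = torch.load(checkpoint_path)
    network.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return checkpoint["epoch"] + 1  # first epoch to run after resuming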
def test():
    network.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            target = target.to(device)
            output = network(data)
            # size_average is deprecated; use reduction="sum" instead
            test_loss += F.nll_loss(output, target, reduction="sum").item()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)
    print(
        f"\nTest set: Avg. loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n"
    )
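Note that test() switches the network into eval mode: Dropout2d and F.dropout(..., training=self.training) are active only in training mode, so evaluation would otherwise be randomly perturbed. A small illustration, added here rather than taken from the original post:

network.eval()  # disable dropout
x = torch.ones(1, 1, 28, 28, device=device)
with torch.no_grad():
    same = torch.allclose(network(x), network(x))
print(same)  # True: the forward pass is deterministic in eval mode
network.train()  # re-enable dropout before further training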
for epoch in range(1, n_epochs + 1):
    train(epoch)
    test()
Train Epoch: 1 [00000/60000 (0%)]	Loss: 2.288714
Train Epoch: 1 [02560/60000 (4%)]	Loss: 2.255195
Train Epoch: 1 [05120/60000 (9%)]	Loss: 2.198278
Train Epoch: 1 [07680/60000 (13%)]	Loss: 2.091046
Train Epoch: 1 [10240/60000 (17%)]	Loss: 1.832017
Train Epoch: 1 [12800/60000 (21%)]	Loss: 1.800682
Train Epoch: 1 [15360/60000 (26%)]	Loss: 1.533700
Train Epoch: 1 [17920/60000 (30%)]	Loss: 1.332488
Train Epoch: 1 [20480/60000 (34%)]	Loss: 1.289864
Train Epoch: 1 [23040/60000 (38%)]	Loss: 1.195423
Train Epoch: 1 [25600/60000 (43%)]	Loss: 1.075095
Train Epoch: 1 [28160/60000 (47%)]	Loss: 1.004397
Train Epoch: 1 [30720/60000 (51%)]	Loss: 0.832168
Train Epoch: 1 [33280/60000 (55%)]	Loss: 0.849826
Train Epoch: 1 [35840/60000 (60%)]	Loss: 0.968393
Train Epoch: 1 [38400/60000 (64%)]	Loss: 0.866699
Train Epoch: 1 [40960/60000 (68%)]	Loss: 0.721661
Train Epoch: 1 [43520/60000 (72%)]	Loss: 0.962982
Train Epoch: 1 [46080/60000 (77%)]	Loss: 0.695050
Train Epoch: 1 [48640/60000 (81%)]	Loss: 0.799360
Train Epoch: 1 [51200/60000 (85%)]	Loss: 0.683803
Train Epoch: 1 [53760/60000 (90%)]	Loss: 0.561625
Train Epoch: 1 [56320/60000 (94%)]	Loss: 0.579419
Train Epoch: 1 [58880/60000 (98%)]	Loss: 0.568427

Test set: Avg. loss: 0.3283, Accuracy: 9096/10000 (91%)

Train Epoch: 2 [00000/60000 (0%)]	Loss: 0.781268
Train Epoch: 2 [02560/60000 (4%)]	Loss: 0.679616
Train Epoch: 2 [05120/60000 (9%)]	Loss: 0.572369
Train Epoch: 2 [07680/60000 (13%)]	Loss: 0.638919
Train Epoch: 2 [10240/60000 (17%)]	Loss: 0.721595
Train Epoch: 2 [12800/60000 (21%)]	Loss: 0.723331
Train Epoch: 2 [15360/60000 (26%)]	Loss: 0.604582
Train Epoch: 2 [17920/60000 (30%)]	Loss: 0.689362
Train Epoch: 2 [20480/60000 (34%)]	Loss: 0.548551
Train Epoch: 2 [23040/60000 (38%)]	Loss: 0.650297
Train Epoch: 2 [25600/60000 (43%)]	Loss: 0.506893
Train Epoch: 2 [28160/60000 (47%)]	Loss: 0.581708
Train Epoch: 2 [30720/60000 (51%)]	Loss: 0.557465
Train Epoch: 2 [33280/60000 (55%)]	Loss: 0.461637
Train Epoch: 2 [35840/60000 (60%)]	Loss: 0.619341
Train Epoch: 2 [38400/60000 (64%)]	Loss: 0.464800
Train Epoch: 2 [40960/60000 (68%)]	Loss: 0.473921
Train Epoch: 2 [43520/60000 (72%)]	Loss: 0.567576
Train Epoch: 2 [46080/60000 (77%)]	Loss: 0.447070
Train Epoch: 2 [48640/60000 (81%)]	Loss: 0.503476
Train Epoch: 2 [51200/60000 (85%)]	Loss: 0.500809
Train Epoch: 2 [53760/60000 (90%)]	Loss: 0.553329
Train Epoch: 2 [56320/60000 (94%)]	Loss: 0.504529
Train Epoch: 2 [58880/60000 (98%)]	Loss: 0.457889

Test set: Avg. loss: 0.2128, Accuracy: 9357/10000 (94%)

Train Epoch: 3 [00000/60000 (0%)]	Loss: 0.369795
Train Epoch: 3 [02560/60000 (4%)]	Loss: 0.414531
Train Epoch: 3 [05120/60000 (9%)]	Loss: 0.604378
Train Epoch: 3 [07680/60000 (13%)]	Loss: 0.426111
Train Epoch: 3 [10240/60000 (17%)]	Loss: 0.492895
Train Epoch: 3 [12800/60000 (21%)]	Loss: 0.393350
Train Epoch: 3 [15360/60000 (26%)]	Loss: 0.555914
Train Epoch: 3 [17920/60000 (30%)]	Loss: 0.476940
Train Epoch: 3 [20480/60000 (34%)]	Loss: 0.430539
Train Epoch: 3 [23040/60000 (38%)]	Loss: 0.571562
Train Epoch: 3 [25600/60000 (43%)]	Loss: 0.488061
Train Epoch: 3 [28160/60000 (47%)]	Loss: 0.598932
Train Epoch: 3 [30720/60000 (51%)]	Loss: 0.417797
Train Epoch: 3 [33280/60000 (55%)]	Loss: 0.486182
Train Epoch: 3 [35840/60000 (60%)]	Loss: 0.375228
Train Epoch: 3 [38400/60000 (64%)]	Loss: 0.384777
Train Epoch: 3 [40960/60000 (68%)]	Loss: 0.428213
Train Epoch: 3 [43520/60000 (72%)]	Loss: 0.443456
Train Epoch: 3 [46080/60000 (77%)]	Loss: 0.308731
Train Epoch: 3 [48640/60000 (81%)]	Loss: 0.347535
Train Epoch: 3 [51200/60000 (85%)]	Loss: 0.439240
Train Epoch: 3 [53760/60000 (90%)]	Loss: 0.460515
Train Epoch: 3 [56320/60000 (94%)]	Loss: 0.521822
Train Epoch: 3 [58880/60000 (98%)]	Loss: 0.451665

Test set: Avg. loss: 0.1591, Accuracy: 9513/10000 (95%)

Train Epoch: 4 [00000/60000 (0%)]	Loss: 0.510825
Train Epoch: 4 [02560/60000 (4%)]	Loss: 0.279729
Train Epoch: 4 [05120/60000 (9%)]	Loss: 0.333447
Train Epoch: 4 [07680/60000 (13%)]	Loss: 0.434208
Train Epoch: 4 [10240/60000 (17%)]	Loss: 0.516646
Train Epoch: 4 [12800/60000 (21%)]	Loss: 0.339009
Train Epoch: 4 [15360/60000 (26%)]	Loss: 0.342047
Train Epoch: 4 [17920/60000 (30%)]	Loss: 0.315687
Train Epoch: 4 [20480/60000 (34%)]	Loss: 0.365422
Train Epoch: 4 [23040/60000 (38%)]	Loss: 0.408456
Train Epoch: 4 [25600/60000 (43%)]	Loss: 0.512156
Train Epoch: 4 [28160/60000 (47%)]	Loss: 0.246768
Train Epoch: 4 [30720/60000 (51%)]	Loss: 0.254370
Train Epoch: 4 [33280/60000 (55%)]	Loss: 0.403202
Train Epoch: 4 [35840/60000 (60%)]	Loss: 0.364687
Train Epoch: 4 [38400/60000 (64%)]	Loss: 0.313392
Train Epoch: 4 [40960/60000 (68%)]	Loss: 0.256359
Train Epoch: 4 [43520/60000 (72%)]	Loss: 0.306669
Train Epoch: 4 [46080/60000 (77%)]	Loss: 0.459862
Train Epoch: 4 [48640/60000 (81%)]	Loss: 0.227380
Train Epoch: 4 [51200/60000 (85%)]	Loss: 0.368363
Train Epoch: 4 [53760/60000 (90%)]	Loss: 0.329823
Train Epoch: 4 [56320/60000 (94%)]	Loss: 0.304764
Train Epoch: 4 [58880/60000 (98%)]	Loss: 0.288302

Test set: Avg. loss: 0.1322, Accuracy: 9596/10000 (96%)

Train Epoch: 5 [00000/60000 (0%)]	Loss: 0.247250
Train Epoch: 5 [02560/60000 (4%)]	Loss: 0.383674
Train Epoch: 5 [05120/60000 (9%)]	Loss: 0.315615
Train Epoch: 5 [07680/60000 (13%)]	Loss: 0.259549
Train Epoch: 5 [10240/60000 (17%)]	Loss: 0.322165
Train Epoch: 5 [12800/60000 (21%)]	Loss: 0.474284
Train Epoch: 5 [15360/60000 (26%)]	Loss: 0.254566
Train Epoch: 5 [17920/60000 (30%)]	Loss: 0.466621
Train Epoch: 5 [20480/60000 (34%)]	Loss: 0.320438
Train Epoch: 5 [23040/60000 (38%)]	Loss: 0.316798
Train Epoch: 5 [25600/60000 (43%)]	Loss: 0.192171
Train Epoch: 5 [28160/60000 (47%)]	Loss: 0.478387
Train Epoch: 5 [30720/60000 (51%)]	Loss: 0.346078
Train Epoch: 5 [33280/60000 (55%)]	Loss: 0.315769
Train Epoch: 5 [35840/60000 (60%)]	Loss: 0.284116
Train Epoch: 5 [38400/60000 (64%)]	Loss: 0.354763
Train Epoch: 5 [40960/60000 (68%)]	Loss: 0.369070
Train Epoch: 5 [43520/60000 (72%)]	Loss: 0.233857
Train Epoch: 5 [46080/60000 (77%)]	Loss: 0.222160
Train Epoch: 5 [48640/60000 (81%)]	Loss: 0.325135
Train Epoch: 5 [51200/60000 (85%)]	Loss: 0.226506
Train Epoch: 5 [53760/60000 (90%)]	Loss: 0.407618
Train Epoch: 5 [56320/60000 (94%)]	Loss: 0.359771
Train Epoch: 5 [58880/60000 (98%)]	Loss: 0.290950

Test set: Avg. loss: 0.1152, Accuracy: 9643/10000 (96%)

Train Epoch: 6 [00000/60000 (0%)]	Loss: 0.371519
Train Epoch: 6 [02560/60000 (4%)]	Loss: 0.294658
Train Epoch: 6 [05120/60000 (9%)]	Loss: 0.195041
Train Epoch: 6 [07680/60000 (13%)]	Loss: 0.247292
Train Epoch: 6 [10240/60000 (17%)]	Loss: 0.307761
Train Epoch: 6 [12800/60000 (21%)]	Loss: 0.183960
Train Epoch: 6 [15360/60000 (26%)]	Loss: 0.255224
Train Epoch: 6 [17920/60000 (30%)]	Loss: 0.564046
Train Epoch: 6 [20480/60000 (34%)]	Loss: 0.217146
Train Epoch: 6 [23040/60000 (38%)]	Loss: 0.364980
Train Epoch: 6 [25600/60000 (43%)]	Loss: 0.237876
Train Epoch: 6 [28160/60000 (47%)]	Loss: 0.344803
Train Epoch: 6 [30720/60000 (51%)]	Loss: 0.347686
Train Epoch: 6 [33280/60000 (55%)]	Loss: 0.197488
Train Epoch: 6 [35840/60000 (60%)]	Loss: 0.346718
Train Epoch: 6 [38400/60000 (64%)]	Loss: 0.256105
Train Epoch: 6 [40960/60000 (68%)]	Loss: 0.211900
Train Epoch: 6 [43520/60000 (72%)]	Loss: 0.264353
Train Epoch: 6 [46080/60000 (77%)]	Loss: 0.339571
Train Epoch: 6 [48640/60000 (81%)]	Loss: 0.198715
Train Epoch: 6 [51200/60000 (85%)]	Loss: 0.335813
Train Epoch: 6 [53760/60000 (90%)]	Loss: 0.244630
Train Epoch: 6 [56320/60000 (94%)]	Loss: 0.260668
Train Epoch: 6 [58880/60000 (98%)]	Loss: 0.281284

Test set: Avg. loss: 0.1029, Accuracy: 9682/10000 (97%)

Train Epoch: 7 [00000/60000 (0%)]	Loss: 0.259941
Train Epoch: 7 [02560/60000 (4%)]	Loss: 0.353288
Train Epoch: 7 [05120/60000 (9%)]	Loss: 0.345746
Train Epoch: 7 [07680/60000 (13%)]	Loss: 0.263094
Train Epoch: 7 [10240/60000 (17%)]	Loss: 0.370562
Train Epoch: 7 [12800/60000 (21%)]	Loss: 0.184917
Train Epoch: 7 [15360/60000 (26%)]	Loss: 0.358648
Train Epoch: 7 [17920/60000 (30%)]	Loss: 0.358313
Train Epoch: 7 [20480/60000 (34%)]	Loss: 0.455060
Train Epoch: 7 [23040/60000 (38%)]	Loss: 0.157829
Train Epoch: 7 [25600/60000 (43%)]	Loss: 0.255777
Train Epoch: 7 [28160/60000 (47%)]	Loss: 0.296378
Train Epoch: 7 [30720/60000 (51%)]	Loss: 0.220109
Train Epoch: 7 [33280/60000 (55%)]	Loss: 0.207805
Train Epoch: 7 [35840/60000 (60%)]	Loss: 0.333757
Train Epoch: 7 [38400/60000 (64%)]	Loss: 0.351853
Train Epoch: 7 [40960/60000 (68%)]	Loss: 0.225360
Train Epoch: 7 [43520/60000 (72%)]	Loss: 0.220420
Train Epoch: 7 [46080/60000 (77%)]	Loss: 0.281292
Train Epoch: 7 [48640/60000 (81%)]	Loss: 0.224555
Train Epoch: 7 [51200/60000 (85%)]	Loss: 0.300659
Train Epoch: 7 [53760/60000 (90%)]	Loss: 0.155560
Train Epoch: 7 [56320/60000 (94%)]	Loss: 0.322972
Train Epoch: 7 [58880/60000 (98%)]	Loss: 0.189427

Test set: Avg. loss: 0.0973, Accuracy: 9693/10000 (97%)

Train Epoch: 8 [00000/60000 (0%)]	Loss: 0.237282
Train Epoch: 8 [02560/60000 (4%)]	Loss: 0.205023
Train Epoch: 8 [05120/60000 (9%)]	Loss: 0.178199
Train Epoch: 8 [07680/60000 (13%)]	Loss: 0.335434
Train Epoch: 8 [10240/60000 (17%)]	Loss: 0.274767
Train Epoch: 8 [12800/60000 (21%)]	Loss: 0.170677
Train Epoch: 8 [15360/60000 (26%)]	Loss: 0.401550
Train Epoch: 8 [17920/60000 (30%)]	Loss: 0.360215
Train Epoch: 8 [20480/60000 (34%)]	Loss: 0.329762
Train Epoch: 8 [23040/60000 (38%)]	Loss: 0.311394
Train Epoch: 8 [25600/60000 (43%)]	Loss: 0.187782
Train Epoch: 8 [28160/60000 (47%)]	Loss: 0.242610
Train Epoch: 8 [30720/60000 (51%)]	Loss: 0.327457
Train Epoch: 8 [33280/60000 (55%)]	Loss: 0.249692
Train Epoch: 8 [35840/60000 (60%)]	Loss: 0.325073
Train Epoch: 8 [38400/60000 (64%)]	Loss: 0.210571
Train Epoch: 8 [40960/60000 (68%)]	Loss: 0.244188
Train Epoch: 8 [43520/60000 (72%)]	Loss: 0.258391
Train Epoch: 8 [46080/60000 (77%)]	Loss: 0.228774
Train Epoch: 8 [48640/60000 (81%)]	Loss: 0.253648
Train Epoch: 8 [51200/60000 (85%)]	Loss: 0.387448
Train Epoch: 8 [53760/60000 (90%)]	Loss: 0.223092
Train Epoch: 8 [56320/60000 (94%)]	Loss: 0.155341
Train Epoch: 8 [58880/60000 (98%)]	Loss: 0.233621

Test set: Avg. loss: 0.0875, Accuracy: 9730/10000 (97%)

Train Epoch: 9 [00000/60000 (0%)]	Loss: 0.356232
Train Epoch: 9 [02560/60000 (4%)]	Loss: 0.228741
Train Epoch: 9 [05120/60000 (9%)]	Loss: 0.237508
Train Epoch: 9 [07680/60000 (13%)]	Loss: 0.200507
Train Epoch: 9 [10240/60000 (17%)]	Loss: 0.180248
Train Epoch: 9 [12800/60000 (21%)]	Loss: 0.316860
Train Epoch: 9 [15360/60000 (26%)]	Loss: 0.251897
Train Epoch: 9 [17920/60000 (30%)]	Loss: 0.341667
Train Epoch: 9 [20480/60000 (34%)]	Loss: 0.256542
Train Epoch: 9 [23040/60000 (38%)]	Loss: 0.272902
Train Epoch: 9 [25600/60000 (43%)]	Loss: 0.207665
Train Epoch: 9 [28160/60000 (47%)]	Loss: 0.134075
Train Epoch: 9 [30720/60000 (51%)]	Loss: 0.198930
Train Epoch: 9 [33280/60000 (55%)]	Loss: 0.218460
Train Epoch: 9 [35840/60000 (60%)]	Loss: 0.372338
Train Epoch: 9 [38400/60000 (64%)]	Loss: 0.207677
Train Epoch: 9 [40960/60000 (68%)]	Loss: 0.274011
Train Epoch: 9 [43520/60000 (72%)]	Loss: 0.158897
Train Epoch: 9 [46080/60000 (77%)]	Loss: 0.236698
Train Epoch: 9 [48640/60000 (81%)]	Loss: 0.238054
Train Epoch: 9 [51200/60000 (85%)]	Loss: 0.264504
Train Epoch: 9 [53760/60000 (90%)]	Loss: 0.246547
Train Epoch: 9 [56320/60000 (94%)]	Loss: 0.181960
Train Epoch: 9 [58880/60000 (98%)]	Loss: 0.182172

Test set: Avg. loss: 0.0810, Accuracy: 9751/10000 (98%)

Train Epoch: 10 [00000/60000 (0%)]	Loss: 0.223450
Train Epoch: 10 [02560/60000 (4%)]	Loss: 0.146772
Train Epoch: 10 [05120/60000 (9%)]	Loss: 0.448847
Train Epoch: 10 [07680/60000 (13%)]	Loss: 0.192779
Train Epoch: 10 [10240/60000 (17%)]	Loss: 0.180810
Train Epoch: 10 [12800/60000 (21%)]	Loss: 0.256196
Train Epoch: 10 [15360/60000 (26%)]	Loss: 0.248770
Train Epoch: 10 [17920/60000 (30%)]	Loss: 0.243000
Train Epoch: 10 [20480/60000 (34%)]	Loss: 0.274350
Train Epoch: 10 [23040/60000 (38%)]	Loss: 0.256088
Train Epoch: 10 [25600/60000 (43%)]	Loss: 0.326276
Train Epoch: 10 [28160/60000 (47%)]	Loss: 0.192222
Train Epoch: 10 [30720/60000 (51%)]	Loss: 0.163986
Train Epoch: 10 [33280/60000 (55%)]	Loss: 0.261450
Train Epoch: 10 [35840/60000 (60%)]	Loss: 0.263229
Train Epoch: 10 [38400/60000 (64%)]	Loss: 0.278172
Train Epoch: 10 [40960/60000 (68%)]	Loss: 0.245064
Train Epoch: 10 [43520/60000 (72%)]	Loss: 0.298991
Train Epoch: 10 [46080/60000 (77%)]	Loss: 0.334879
Train Epoch: 10 [48640/60000 (81%)]	Loss: 0.306450
Train Epoch: 10 [51200/60000 (85%)]	Loss: 0.208277
Train Epoch: 10 [53760/60000 (90%)]	Loss: 0.211731
Train Epoch: 10 [56320/60000 (94%)]	Loss: 0.186170
Train Epoch: 10 [58880/60000 (98%)]	Loss: 0.219588

Test set: Avg. loss: 0.0760, Accuracy: 9754/10000 (98%)
fig = plt.figure()
plt.plot(train_counter, train_losses, color="blue")
plt.scatter(test_counter, test_losses, color="red")
plt.legend(["Train Loss", "Test Loss"], loc="upper right")
plt.xlabel("number of training examples seen")
plt.ylabel("negative log likelihood loss")
plt.show()

[Figure: training loss curve (blue line) with per-epoch test loss (red points)]
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)
with torch.no_grad():
    example_data = example_data.to(device)
    output = network(example_data)
fig = plt.figure()
for i in range(6):
    plt.subplot(2, 3, i + 1)
    plt.tight_layout()
    plt.imshow(
        example_data.detach().to("cpu").numpy()[i][0], cmap="gray", interpolation="none"
    )
    plt.title(f"prediction: {output.data.max(1, keepdim=True)[1][i].item()}")
    plt.xticks([])
    plt.yticks([])
plt.show()

[Figure: six test digits with the model's predicted labels]

Model retraining

continued_network = Net().to(device)
continued_optimizer = optim.SGD(
    continued_network.parameters(),
    lr=learning_rate,
    momentum=momentum,
)
latest_model = "/workspace/disk1/datasets/models/mnist/model_latest.pth"
latest_optimizer = "/workspace/disk1/datasets/models/mnist/optimizer_latest.pth"

network_state_dict = torch.load(latest_model)
continued_network.load_state_dict(network_state_dict)
optimizer_state_dict = torch.load(latest_optimizer)
continued_optimizer.load_state_dict(optimizer_state_dict)

# train() and test() read the global `network` and `optimizer`, so point them
# at the restored copies before continuing.
network = continued_network
optimizer = continued_optimizer

m = 5
for i in range(n_epochs + 1, n_epochs + m):  # epochs 11 through 14
    test_counter.append(i * len(train_loader.dataset))
    train(i)
    test()
Train Epoch: 11 [00000/60000 (0%)]	Loss: 0.208258
Train Epoch: 11 [02560/60000 (4%)]	Loss: 0.274365
Train Epoch: 11 [05120/60000 (9%)]	Loss: 0.438731
Train Epoch: 11 [07680/60000 (13%)]	Loss: 0.192449
Train Epoch: 11 [10240/60000 (17%)]	Loss: 0.333653
Train Epoch: 11 [12800/60000 (21%)]	Loss: 0.243886
Train Epoch: 11 [15360/60000 (26%)]	Loss: 0.403167
Train Epoch: 11 [17920/60000 (30%)]	Loss: 0.287619
Train Epoch: 11 [20480/60000 (34%)]	Loss: 0.169019
Train Epoch: 11 [23040/60000 (38%)]	Loss: 0.212516
Train Epoch: 11 [25600/60000 (43%)]	Loss: 0.202806
Train Epoch: 11 [28160/60000 (47%)]	Loss: 0.217581
Train Epoch: 11 [30720/60000 (51%)]	Loss: 0.230557
Train Epoch: 11 [33280/60000 (55%)]	Loss: 0.378703
Train Epoch: 11 [35840/60000 (60%)]	Loss: 0.189265
Train Epoch: 11 [38400/60000 (64%)]	Loss: 0.302335
Train Epoch: 11 [40960/60000 (68%)]	Loss: 0.132773
Train Epoch: 11 [43520/60000 (72%)]	Loss: 0.222291
Train Epoch: 11 [46080/60000 (77%)]	Loss: 0.321381
Train Epoch: 11 [48640/60000 (81%)]	Loss: 0.150813
Train Epoch: 11 [51200/60000 (85%)]	Loss: 0.249176
Train Epoch: 11 [53760/60000 (90%)]	Loss: 0.245743
Train Epoch: 11 [56320/60000 (94%)]	Loss: 0.212701
Train Epoch: 11 [58880/60000 (98%)]	Loss: 0.363840

Test set: Avg. loss: 0.0722, Accuracy: 9776/10000 (98%)

Train Epoch: 12 [00000/60000 (0%)]	Loss: 0.189078
Train Epoch: 12 [02560/60000 (4%)]	Loss: 0.184256
Train Epoch: 12 [05120/60000 (9%)]	Loss: 0.143418
Train Epoch: 12 [07680/60000 (13%)]	Loss: 0.161733
Train Epoch: 12 [10240/60000 (17%)]	Loss: 0.238975
Train Epoch: 12 [12800/60000 (21%)]	Loss: 0.200676
Train Epoch: 12 [15360/60000 (26%)]	Loss: 0.163866
Train Epoch: 12 [17920/60000 (30%)]	Loss: 0.337974
Train Epoch: 12 [20480/60000 (34%)]	Loss: 0.160897
Train Epoch: 12 [23040/60000 (38%)]	Loss: 0.156017
Train Epoch: 12 [25600/60000 (43%)]	Loss: 0.188498
Train Epoch: 12 [28160/60000 (47%)]	Loss: 0.272446
Train Epoch: 12 [30720/60000 (51%)]	Loss: 0.124439
Train Epoch: 12 [33280/60000 (55%)]	Loss: 0.131949
Train Epoch: 12 [35840/60000 (60%)]	Loss: 0.293010
Train Epoch: 12 [38400/60000 (64%)]	Loss: 0.187551
Train Epoch: 12 [40960/60000 (68%)]	Loss: 0.181151
Train Epoch: 12 [43520/60000 (72%)]	Loss: 0.270526
Train Epoch: 12 [46080/60000 (77%)]	Loss: 0.131309
Train Epoch: 12 [48640/60000 (81%)]	Loss: 0.261624
Train Epoch: 12 [51200/60000 (85%)]	Loss: 0.239715
Train Epoch: 12 [53760/60000 (90%)]	Loss: 0.163549
Train Epoch: 12 [56320/60000 (94%)]	Loss: 0.160421
Train Epoch: 12 [58880/60000 (98%)]	Loss: 0.160318

Test set: Avg. loss: 0.0687, Accuracy: 9787/10000 (98%)

Train Epoch: 13 [00000/60000 (0%)]	Loss: 0.140241
Train Epoch: 13 [02560/60000 (4%)]	Loss: 0.144069
Train Epoch: 13 [05120/60000 (9%)]	Loss: 0.323135
Train Epoch: 13 [07680/60000 (13%)]	Loss: 0.336287
Train Epoch: 13 [10240/60000 (17%)]	Loss: 0.107315
Train Epoch: 13 [12800/60000 (21%)]	Loss: 0.169032
Train Epoch: 13 [15360/60000 (26%)]	Loss: 0.162337
Train Epoch: 13 [17920/60000 (30%)]	Loss: 0.253107
Train Epoch: 13 [20480/60000 (34%)]	Loss: 0.166370
Train Epoch: 13 [23040/60000 (38%)]	Loss: 0.243374
Train Epoch: 13 [25600/60000 (43%)]	Loss: 0.160263
Train Epoch: 13 [28160/60000 (47%)]	Loss: 0.187129
Train Epoch: 13 [30720/60000 (51%)]	Loss: 0.348670
Train Epoch: 13 [33280/60000 (55%)]	Loss: 0.166424
Train Epoch: 13 [35840/60000 (60%)]	Loss: 0.184487
Train Epoch: 13 [38400/60000 (64%)]	Loss: 0.159097
Train Epoch: 13 [40960/60000 (68%)]	Loss: 0.110388
Train Epoch: 13 [43520/60000 (72%)]	Loss: 0.114675
Train Epoch: 13 [46080/60000 (77%)]	Loss: 0.193499
Train Epoch: 13 [48640/60000 (81%)]	Loss: 0.256665
Train Epoch: 13 [51200/60000 (85%)]	Loss: 0.204359
Train Epoch: 13 [53760/60000 (90%)]	Loss: 0.228794
Train Epoch: 13 [56320/60000 (94%)]	Loss: 0.229143
Train Epoch: 13 [58880/60000 (98%)]	Loss: 0.198778

Test set: Avg. loss: 0.0659, Accuracy: 9792/10000 (98%)

Train Epoch: 14 [00000/60000 (0%)]	Loss: 0.154132
Train Epoch: 14 [02560/60000 (4%)]	Loss: 0.174841
Train Epoch: 14 [05120/60000 (9%)]	Loss: 0.131765
Train Epoch: 14 [07680/60000 (13%)]	Loss: 0.163187
Train Epoch: 14 [10240/60000 (17%)]	Loss: 0.130205
Train Epoch: 14 [12800/60000 (21%)]	Loss: 0.230511
Train Epoch: 14 [15360/60000 (26%)]	Loss: 0.206032
Train Epoch: 14 [17920/60000 (30%)]	Loss: 0.209682
Train Epoch: 14 [20480/60000 (34%)]	Loss: 0.143732
Train Epoch: 14 [23040/60000 (38%)]	Loss: 0.247467
Train Epoch: 14 [25600/60000 (43%)]	Loss: 0.141316
Train Epoch: 14 [28160/60000 (47%)]	Loss: 0.156982
Train Epoch: 14 [30720/60000 (51%)]	Loss: 0.249250
Train Epoch: 14 [33280/60000 (55%)]	Loss: 0.252457
Train Epoch: 14 [35840/60000 (60%)]	Loss: 0.137284
Train Epoch: 14 [38400/60000 (64%)]	Loss: 0.212023
Train Epoch: 14 [40960/60000 (68%)]	Loss: 0.227320
Train Epoch: 14 [43520/60000 (72%)]	Loss: 0.200754
Train Epoch: 14 [46080/60000 (77%)]	Loss: 0.197454
Train Epoch: 14 [48640/60000 (81%)]	Loss: 0.200271
Train Epoch: 14 [51200/60000 (85%)]	Loss: 0.135254
Train Epoch: 14 [53760/60000 (90%)]	Loss: 0.137874
Train Epoch: 14 [56320/60000 (94%)]	Loss: 0.213711
Train Epoch: 14 [58880/60000 (98%)]	Loss: 0.286362

Test set: Avg. loss: 0.0631, Accuracy: 9794/10000 (98%)
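One caveat when retraining on different hardware: the checkpoints above were written from cuda:1, and torch.load restores tensors to the device they were saved on by default. If the machine doing the retraining lacks that GPU, pass map_location. A minimal sketch, assuming the names from the retraining cell above:

state = torch.load(latest_model, map_location=device)  # or map_location="cpu"
continued_network.load_state_dict(state)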
