9

百度飞桨(PaddlePaddle)-数字识别 - VipSoft

 1 year ago
source link: https://www.cnblogs.com/vipsoft/p/17359174.html
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
neoserver,ios ssh client

手写数字识别任务 用于对 0 ~ 9 的十类数字进行分类,即输入手写数字的图片,可识别出这个图片中的数字。

使用 pip 工具安装 matplotlib 和 numpy

python -m pip install matplotlib numpy -i https://mirror.baidu.com/pypi/simple
python -m pip install paddlepaddle==2.4.2 -i https://pypi.tuna.tsinghua.edu.cn/simple

D:\OpenSource\PaddlePaddle>python -m pip install matplotlib numpy -i https://mirror.baidu.com/pypi/simple
Looking in indexes: https://mirror.baidu.com/pypi/simple
Collecting matplotlib
  Downloading https://mirror.baidu.com/pypi/packages/92/01/2c04d328db6955d77f8f60c17068dde8aa66f153b2c599ca03c2cb0d5567/matplotlib-3.7.1-cp38-cp38-win_amd64.whl (7.6 MB)
     |████████████████████████████████| 7.6 MB ...
Requirement already satisfied: numpy in d:\program files\python38\lib\site-packages (1.24.3)
Collecting packaging>=20.0
  Downloading https://mirror.baidu.com/pypi/packages/ab/c3/57f0601a2d4fe15de7a553c00adbc901425661bf048f2a22dfc500caf121/packaging-23.1-py3-none-any.whl (48 kB)
     |████████████████████████████████| 48 kB 1.2 MB/s
Collecting cycler>=0.10
  Downloading https://mirror.baidu.com/pypi/packages/5c/f9/695d6bedebd747e5eb0fe8fad57b72fdf25411273a39791cde838d5a8f51/cycler-0.11.0-py3-none-any.whl (6.4 kB)
Requirement already satisfied: pillow>=6.2.0 in d:\program files\python38\lib\site-packages (from matplotlib) (9.5.0)
Collecting python-dateutil>=2.7
  Downloading https://mirror.baidu.com/pypi/packages/36/7a/87837f39d0296e723bb9b62bbb257d0355c7f6128853c78955f57342a56d/python_dateutil-2.8.2-py2.py3-none-any.whl (247 kB)
     |████████████████████████████████| 247 kB ...
Collecting importlib-resources>=3.2.0
  Downloading https://mirror.baidu.com/pypi/packages/38/71/c13ea695a4393639830bf96baea956538ba7a9d06fcce7cef10bfff20f72/importlib_resources-5.12.0-py3-none-any.whl (36 kB)
Collecting fonttools>=4.22.0
  Downloading https://mirror.baidu.com/pypi/packages/16/07/1c7547e27f559ec078801d522cc4d5127cdd4ef8e831c8ddcd9584668a07/fonttools-4.39.3-py3-none-any.whl (1.0 MB)
     |████████████████████████████████| 1.0 MB ...
Collecting pyparsing>=2.3.1
  Downloading https://mirror.baidu.com/pypi/packages/6c/10/a7d0fa5baea8fe7b50f448ab742f26f52b80bfca85ac2be9d35cdd9a3246/pyparsing-3.0.9-py3-none-any.whl (98 kB)
     |████████████████████████████████| 98 kB 862 kB/s
Collecting contourpy>=1.0.1
  Downloading https://mirror.baidu.com/pypi/packages/08/ce/9bfe9f028cb5a8ee97898da52f4905e0e2d9ca8203ffdcdbe80e1769b549/contourpy-1.0.7-cp38-cp38-win_amd64.whl (162 kB)
     |████████████████████████████████| 162 kB ...
Collecting kiwisolver>=1.0.1
  Downloading https://mirror.baidu.com/pypi/packages/4f/05/59b34e788bf2b45c7157c3d898d567d28bc42986c1b6772fb1af329eea0d/kiwisolver-1.4.4-cp38-cp38-win_amd64.whl (55 kB)
     |████████████████████████████████| 55 kB 784 kB/s
Collecting zipp>=3.1.0
  Downloading https://mirror.baidu.com/pypi/packages/5b/fa/c9e82bbe1af6266adf08afb563905eb87cab83fde00a0a08963510621047/zipp-3.15.0-py3-none-any.whl (6.8 kB)
Requirement already satisfied: six>=1.5 in d:\program files\python38\lib\site-packages (from python-dateutil>=2.7->matplotlib) (1.16.0)
Installing collected packages: zipp, python-dateutil, pyparsing, packaging, kiwisolver, importlib-resources, fonttools, cycler, contourpy, matplotlib
Successfully installed contourpy-1.0.7 cycler-0.11.0 fonttools-4.39.3 importlib-resources-5.12.0 kiwisolver-1.4.4 matplotlib-3.7.1 packaging-23.1 pyparsing-3.0.9 python-dateutil-2.8.2 zipp-3.15.0
WARNING: You are using pip version 21.1.1; however, version 23.1.2 is available.
You should consider upgrading via the 'D:\Program Files\Python38\python.exe -m pip install --upgrade pip' command.

D:\OpenSource\PaddlePaddle>

创建 DigitalRecognition.py

官网代码少了 plt.show() # 要加上这句,才会显示图片

import paddle
import numpy as np
from paddle.vision.transforms import Normalize

transform = Normalize(mean=[127.5], std=[127.5], data_format='CHW')
# 下载数据集并初始化 DataSet
'''
飞桨在 paddle.vision.datasets 下内置了计算机视觉(Computer Vision,CV)领域常见的数据集,
如 MNIST、Cifar10、Cifar100、FashionMNIST 和 VOC2012 等。在本任务中,
先后加载了 MNIST 训练集(mode='train')和测试集(mode='test'),训练集用于训练模型,测试集用于评估模型效果。
'''
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
# 打印数据集里图片数量 60000 images in train_dataset, 10000 images in test_dataset
# print('{} images in train_dataset, {} images in test_dataset'.format(len(train_dataset), len(test_dataset)))

# 模型组网并初始化网络
lenet = paddle.vision.models.LeNet(num_classes=10)
model = paddle.Model(lenet)

# 模型训练的配置准备,准备损失函数,优化器和评价指标
model.prepare(paddle.optimizer.Adam(parameters=model.parameters()),
              paddle.nn.CrossEntropyLoss(),
              paddle.metric.Accuracy())

# 模型训练
model.fit(train_dataset, epochs=5, batch_size=64, verbose=1)
# 模型评估
model.evaluate(test_dataset, batch_size=64, verbose=1)

# 保存模型
model.save('./output/mnist')
# 加载模型
model.load('output/mnist')

# 从测试集中取出一张图片
img, label = test_dataset[0]
# 将图片shape从1*28*28变为1*1*28*28,增加一个batch维度,以匹配模型输入格式要求
img_batch = np.expand_dims(img.astype('float32'), axis=0)

# 执行推理并打印结果,此处predict_batch返回的是一个list,取出其中数据获得预测结果
out = model.predict_batch(img_batch)[0]
pred_label = out.argmax()
print('true label: {}, pred label: {}'.format(label[0], pred_label))
# 可视化图片
from matplotlib import pyplot as plt
plt.imshow(img[0])
plt.show()  # 要加上这句,才会显示图片

PyCharm运行(推荐,有错误能显示出来)

Python MatplotlibDeprecationWarning Matplotlib 3.6 and will be removed two minor releases later
File -> Settings -> Tools -> Python Scientific -> 取消 Show plots in tool window,
取消后,将不会看到红字警告提示

image

CMD 运行

D:\OpenSource\PaddlePaddle>python DigitalRecognition.py

image
image

如果碰到下列错误,需要加上 plt.show()
Python MatplotlibDeprecationWarning Matplotlib 3.6 and will be removed two minor releases later

MatplotlibDeprecationWarning: Support for FigureCanvases without a required_interactive_framework attribute was deprecated in Matplotlib 3.6 and will be removed two minor releases later.
  plt.imshow(img[0])

数据集定义与加载

飞桨在 paddle.vision.datasets 下内置了计算机视觉(Computer Vision,CV)领域常见的数据集,如 MNIST、Cifar10、Cifar100、FashionMNIST 和 VOC2012 等。在本任务中,先后加载了 MNIST 训练集(mode='train')和测试集(mode='test'),训练集用于训练模型,测试集用于评估模型效果。
飞桨除了内置了 CV 领域常见的数据集,还在 paddle.text 下内置了自然语言处理(Natural Language Processing,NLP)领域常见的数据集,并提供了自定义数据集与加载功能的 paddle.io.Dataset 和 paddle.io.DataLoader API,详细使用方法可参考『数据集定义与加载』 章节。

另外在 paddle.vision.transforms 下提供了一些常用的图像变换操作,如对图像的翻转、裁剪、调整亮度等处理,可实现数据增强,以增加训练样本的多样性,提升模型的泛化能力。本任务在初始化 MNIST 数据集时通过 transform 字段传入了 Normalize 变换对图像进行归一化,对图像进行归一化可以加快模型训练的收敛速度。该功能的具体使用方法可参考『数据预处理』 章节。

https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/beginner/quick_start_cn.html#moxingzuwang

模型训练与评估

https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/beginner/quick_start_cn.html#moxingxunlianyupinggu

https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/beginner/quick_start_cn.html#moxingtuili

image

About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK