12

tensorflow学习,MNIST数据集的使用,识别手写数字实战项目(二)

 3 years ago
source link: https://blog.popkx.com/tensorflow-study-use-mnist-data-recognize-handwrite-digit/
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
neoserver,ios ssh client

tensorflow学习,MNIST数据集的使用,识别手写数字实战项目(二)

发表于 2018-06-20 23:06:00   |   已被 访问: 1,163 次   |   分类于:   tensorflow   |   9 条评论

在学习各种编程语言时,最经典的入门例子就是打印出 "hello world" 了。对于 tensorflow 而言,与之对应地位的入门实战项目就是使用 MNIST数据集 实现手写数字识别了。

MNIST 数据集


1. MNIST 数据集的下载

MNIST 数据集的官方网站:http://yann.lecun.com/exdb/mnist/,非常知名的网站,看着却非常简陋,不过能提供好数据就是好网站。

手动下载,也不慢。也可使用 python 代码直接下载:

#encoding=utf8
import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

上面的 input_data 的代码如下,记得文件名为 input_data.py:

"""Functions for downloading and reading MNIST data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=unused-import
import gzip
import os
import tempfile

import numpy
from six.moves import urllib
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
# pylint: enable=unused-import

执行python脚本,也可以下载,下载在 py 文件同目录的 MNIST_data 文件夹里,如下图:

2. MNIST 数据集简介

MNIST 数据集分为训练(mnist.train)和测试(mnist.test)两部分。每一个数据单元分为图片数据标签两部分,图片数据即为手写数字图片,标签则对应着手写结果。

每一个手写数字是 28X28=784 的图片。所以,mnist.train.images 是一个 [60000, 784] 的张量,60000 表示图片的数量,784 表示每一张图片的数据点数。标签则为 0~9 的数字,用 1 个 10 维向量表示 10 个数字,例如,用 [1,0,0,0,0,0,0,0,0,0,0] 表示 0,用 [0,0,1,0,0,0,0,0,0,0,0] 表示 2,那么,mnist.train.labels 是一个 [60000, 10] 的张量。

手写数字识别


1. 搭建识别模型

这里手写数字识别使用 softmax回归 方法。

已经知道的事实是:手写数字图片 x 对应的标签(正确结果)为 y,那么,我们设置一个系数 w (权重系数),和一个偏置 b,肯定可以满足如下关系:

y = wx + b
使用 softmax 函数,则有
y = softmax(wx+b)

作为入门,暂且不提为啥使用 softmax 函数。如果提升到多维,则有:

那么,x w b y 的 tensorflow 的 python 描述代码可以如下写:

import tensorflow as tf
x = tf.placeholder("float", [None, 784])    # None 表示不关心该纬度
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x,W) + b)

上面的 y 是计算得到的值,下面引入 y_ 作为正确值:

y_ = tf.placeholder("float", [None,10])

引入 y_ 是为了估计计算值与实际值的接近程度,这个接近程度可以用 交叉熵 表示,估计值和实际值越接近,二者的 交叉熵 越小。交叉熵的公式如下:

cross_entropy = -tf.reduce_sum(y_*tf.log(y))    # 计算交叉熵

2. 训练模型

训练方式采取经典的反向传播法,tensorflow 使用一行代码就可以描述

train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

即,使用 0.01 的学习因子,使 cross_entropy 尽可能小。

下面就可以初始化,训练了,要在 session 里训练,这点可以参考上一节:几个基本概念:图,会话,feed,fetch

init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for i in range(1000):           # 训练 1000 次
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) # feed 数据

3. 测试模型

训练完成后,当然需要测试其准确性,直接上代码,主要利用了 tf.argmax 函数,它返回给出某个tensor对象在某一维上
的其数据最大值所在的索引值,因为本节使用的索引是特殊的 10 维向量,所以下面的代码应该非常好理解才对。

correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})



有了上面的解释,下面的代码应该很好理解:

#encoding=utf8
import input_data

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
import tensorflow as tf
x = tf.placeholder("float", [None, 784])
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x,W) + b)
y_ = tf.placeholder("float", [None,10])
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for i in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})

最终,程序输出的结果是 0.9137,不算高的正确率,下面几节将提高正确率。

阅读更多:   tensorflow


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK