1

【TensorFlow6】持久化

 3 years ago
source link: https://www.guofei.site/2018/10/29/tf6.html
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
neoserver,ios ssh client

【TensorFlow6】持久化

2018年10月29日

Author: Guofei

文章归类: 2-3-神经网络与TF ,文章编号: 286


版权声明:本文作者是郭飞。转载随意,但需要标明原文链接,并通知本人
原文链接:https://www.guofei.site/2018/10/29/tf6.html

Edit

保存模型

saver=tf.train.Saver()
saver.save(sess,'ckpt/mnist.ckpt')

会在ckpt目录中生成4个文件:checkpoint, munist.ckpt.index, mnist.ckpt.meta, mnist.ckpt.data-00000-of-00001

参数

max_to_keep

saver=tf.train.Saver(max_to_keep=3)
saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)

max_to_keep 的作用:

  1. 可以多次调用saver.save来存储sess,以保存最近的 max_to_keep 个 sess
  2. max_to_keep 默认为 5. 为None或0时,每次 saver.save 都会保存
  3. 所以写在迭代中,外加一些代码,可以保存迭代过程中,最好的 max_to_keep 个 sess

恢复模型

import tensorflow as tf
sess=tf.Session()
saver=tf.train.import_meta_graph('ckpt/mnist.ckpt.meta')
saver.restore(sess,'ckpt/mnist.ckpt')

# 取 tensor
x=tf.get_default_graph().get_tensor_by_name('x:0')
keep_prob=tf.get_default_graph().get_tensor_by_name('keep_prob:0')
y_hat=tf.get_default_graph().get_tensor_by_name('Softmax:0')

# 然后模型就可以运行了
Y_test_predict=sess.run(y_hat,feed_dict={x:X_test,keep_prob:1})

您的支持将鼓励我继续创作!

About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK