1
【TensorFlow6】持久化
source link: https://www.guofei.site/2018/10/29/tf6.html
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
【TensorFlow6】持久化
2018年10月29日Author: Guofei
文章归类: 2-3-神经网络与TF ,文章编号: 286
版权声明:本文作者是郭飞。转载随意,但需要标明原文链接,并通知本人
原文链接:https://www.guofei.site/2018/10/29/tf6.html
保存模型
saver=tf.train.Saver()
saver.save(sess,'ckpt/mnist.ckpt')
会在ckpt目录中生成4个文件:checkpoint, munist.ckpt.index, mnist.ckpt.meta, mnist.ckpt.data-00000-of-00001
参数
max_to_keep
saver=tf.train.Saver(max_to_keep=3)
saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
max_to_keep 的作用:
- 可以多次调用
saver.save
来存储sess,以保存最近的max_to_keep
个 sess - max_to_keep 默认为 5. 为None或0时,每次
saver.save
都会保存 - 所以写在迭代中,外加一些代码,可以保存迭代过程中,最好的
max_to_keep
个 sess
恢复模型
import tensorflow as tf
sess=tf.Session()
saver=tf.train.import_meta_graph('ckpt/mnist.ckpt.meta')
saver.restore(sess,'ckpt/mnist.ckpt')
# 取 tensor
x=tf.get_default_graph().get_tensor_by_name('x:0')
keep_prob=tf.get_default_graph().get_tensor_by_name('keep_prob:0')
y_hat=tf.get_default_graph().get_tensor_by_name('Softmax:0')
# 然后模型就可以运行了
Y_test_predict=sess.run(y_hat,feed_dict={x:X_test,keep_prob:1})
您的支持将鼓励我继续创作!
Recommend
About Joyk
Aggregate valuable and interesting links.
Joyk means Joy of geeK