对TensorFlow中的variables_to_restore函数详解

yipeiwu_com6年前Python基础

variables_to_restore函数,是TensorFlow为滑动平均值提供。之前,也介绍过通过使用滑动平均值可以让神经网络模型更加的健壮。我们也知道,其实在TensorFlow中,变量的滑动平均值都是由影子变量所维护的,如果你想要获取变量的滑动平均值需要获取的是影子变量而不是变量本身。

1、滑动平均值模型文件的保存

import tensorflow as tf
 
if __name__ == "__main__":
 v = tf.Variable(0.,name="v")
 #设置滑动平均模型的系数
 ema = tf.train.ExponentialMovingAverage(0.99)
 #设置变量v使用滑动平均模型,tf.all_variables()设置所有变量
 op = ema.apply([v])
 #获取变量v的名字
 print(v.name)
 #v:0
 #创建一个保存模型的对象
 save = tf.train.Saver()
 sess = tf.Session()
 #初始化所有变量
 init = tf.initialize_all_variables()
 sess.run(init)
 #给变量v重新赋值
 sess.run(tf.assign(v,10))
 #应用平均滑动设置
 sess.run(op)
 #保存模型文件
 save.save(sess,"./model.ckpt")
 #输出变量v之前的值和使用滑动平均模型之后的值
 print(sess.run([v,ema.average(v)]))
 #[10.0, 0.099999905]

上面的代码,是如何来保存一个滑动平均值的模型文件,之前有介绍过滑动平均值和模型文件的保存,所以这里就不再重复了。

2、滑动平均值模型文件的读取

 v = tf.Variable(1.,name="v")
 #定义模型对象
 saver = tf.train.Saver({"v/ExponentialMovingAverage":v})
 sess = tf.Session()
 saver.restore(sess,"./model.ckpt")
 print(sess.run(v))
 #0.0999999

对于模型文件的读取,在上一篇博客中有介绍过,这里特别需要注意的一个地方就是,在使用tf.train.Saver函数中,所传递的模型参数是{"v/ExponentialMovingAverage":v}而不是{"v":v},如果你使用的是后面的参数,那么你得到的结果将是10而不是0.09,那是因为后者获取的是变量本身而不是影子变量。是不是感觉使用这种方式来读取模型文件的时候,还需要输入一大串的变量名称。

3、variables_to_restore函数的使用

 v = tf.Variable(1.,name="v")
 #滑动模型的参数的大小并不会影响v的值
 ema = tf.train.ExponentialMovingAverage(0.99)
 print(ema.variables_to_restore())
 #{'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>}
 sess = tf.Session()
 saver = tf.train.Saver(ema.variables_to_restore())
 saver.restore(sess,"./model.ckpt")
 print(sess.run(v))
 #0.0999999

通过使用variables_to_restore函数,可以使在加载模型的时候将影子变量直接映射到变量的本身,所以我们在获取变量的滑动平均值的时候只需要获取到变量的本身值而不需要去获取影子变量。

以上这篇对TensorFlow中的variables_to_restore函数详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Python3使用Matplotlib 绘制精美的数学函数图形

Python3使用Matplotlib 绘制精美的数学函数图形

一个最最简单的例子: 绘制一个从 0 到 360 度完整的 SIN 函数图形 import numpy as np import matplotlib.pyplot as pt x...

Python初学者常见错误详解

前言 Python 以其简单易懂的语法格式与其它语言形成鲜明对比,初学者遇到最多的问题就是不按照 Python 的规则来写,即便是有编程经验的程序员,也容易按照固有的思维和语法格式来写...

Django2.1.3 中间件使用详解

Django2.1.3 中间件使用详解

环境 Win10 Python3.6.6 Django2.1.3 中间件作用 中间件用于全局修改Django的输入或输出。 中间件常见用途 缓存 会话认证...

NumPy 基本切片和索引的具体使用方法

索引和切片是NumPy中最重要最常用的操作。熟练使用NumPy切片操作是数据处理和机器学习的前提,所以一定要掌握好。 文档:https://docs.scipy.org/doc/num...

python使用opencv在Windows下调用摄像头实现解析

python使用opencv在Windows下调用摄像头实现解析

这篇文章主要介绍了python使用opencv在Windows下调用摄像头实现解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下 环境...