tensorflow 加载部分变量的实例讲解

yipeiwu_com6年前Python基础

tensorflow模型保存为saver = tf.train.Saver()函数,saver.save()保存模型,代码如下:

import tensorflow as tf
 
v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2= tf.Variable(tf.zeros([200]), name="v2")
saver = tf.train.Saver()
with tf.Session() as sess:
 init_op = tf.global_variables_initializer()
 sess.run(init_op)
 saver.save(sess,"checkpoint/model_test",global_step=1)

当我们保存模型后,我们可以通过saver.restore()来加载模型,初始化变量:

import tensorflow as tf
 
v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2= tf.Variable(tf.zeros([200]), name="v2")
saver = tf.train.Saver()
with tf.Session() as sess:
 # init_op = tf.global_variables_initializer()
 # sess.run(init_op)
 saver.restore(sess, "checkpoint/model_test-1")
 # saver.save(sess,"checkpoint/model_test",global_step=1)

神经网络训练时,有时候我们需要从预训练的模型中加载部分参数,初始化当前模型,例如加入CNN有6层,我们需要从已有的模型初始化CNN前5层参数.这可以通过saver.restore()实现.

之前我们已经介绍可以通过tf.train.Saver()的保存部分变量的方法,即需要保存的变量列表,同样的,在变量初始化的时候,我们可以对需要单独初始化的变量分别定义一个tf.train.Saver()函数,这样就可以单独对该部分变量初始化,例如下面代码,saver1用于初始化变量v1,saver2用于初始化变量v2,v3:

import tensorflow as tf
 
v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2= tf.Variable(tf.zeros([200]), name="v2")
v3= tf.Variable(tf.zeros([100]), name="v3")
#saver = tf.train.Saver()
saver1 = tf.train.Saver([v1])
saver2 = tf.train.Saver([v2]+[v3])
with tf.Session() as sess:
 # init_op = tf.global_variables_initializer()
 # sess.run(init_op)
 saver1.restore(sess, "checkpoint/model_test-1")
 saver2.restore(sess, "checkpoint/model_test-1")
 # saver.save(sess,"checkpoint/model_test",global_step=1)

以上这篇tensorflow 加载部分变量的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

对Python3中dict.keys()转换成list类型的方法详解

对Python3中dict.keys()转换成list类型的方法详解

在python3中使用dict.keys()返回的不在是list类型了,也不支持索引,我们可以看一下下面这张图片 那么我们应该怎么办呢,其实解决的方法也是非常简单的,只需要使用list...

python找出列表中大于某个阈值的数据段示例

python找出列表中大于某个阈值的数据段示例

该算法实现对列表中大于某个阈值(比如level=5)的连续数据段的提取,具体效果如下: 找出list里面大于5的连续数据段: list = [1,2,3,4,2,3,4,5,6,7,...

使用Python的Flask框架构建大型Web应用程序的结构示例

虽然小型web应用程序用单个脚本可以很方便,但这种方法却不能很好地扩展。随着应用变得复杂,在单个大的源文件中处理会变得问题重重。 与大多数其他web框架不同,Flask对大型项目没有特定...

Python中GeoJson和bokeh-1的使用讲解

Python中GeoJson和bokeh-1的使用讲解

GeoJson 文档 { "type": "FeatureCollection", "features": [ { "geometry": { "type":...

基于python的Tkinter实现一个简易计算器

基于python的Tkinter实现一个简易计算器

本文实例介绍了基于python的Tkinter实现简易计算器的详细代码,分享给大家供大家参考,具体内容如下 第一种:使用python 的 Tkinter实现一个简易计算器 #codi...