tensorflow 加载部分变量的实例讲解

yipeiwu_com5年前Python基础

tensorflow模型保存为saver = tf.train.Saver()函数,saver.save()保存模型,代码如下:

import tensorflow as tf
 
v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2= tf.Variable(tf.zeros([200]), name="v2")
saver = tf.train.Saver()
with tf.Session() as sess:
 init_op = tf.global_variables_initializer()
 sess.run(init_op)
 saver.save(sess,"checkpoint/model_test",global_step=1)

当我们保存模型后,我们可以通过saver.restore()来加载模型,初始化变量:

import tensorflow as tf
 
v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2= tf.Variable(tf.zeros([200]), name="v2")
saver = tf.train.Saver()
with tf.Session() as sess:
 # init_op = tf.global_variables_initializer()
 # sess.run(init_op)
 saver.restore(sess, "checkpoint/model_test-1")
 # saver.save(sess,"checkpoint/model_test",global_step=1)

神经网络训练时,有时候我们需要从预训练的模型中加载部分参数,初始化当前模型,例如加入CNN有6层,我们需要从已有的模型初始化CNN前5层参数.这可以通过saver.restore()实现.

之前我们已经介绍可以通过tf.train.Saver()的保存部分变量的方法,即需要保存的变量列表,同样的,在变量初始化的时候,我们可以对需要单独初始化的变量分别定义一个tf.train.Saver()函数,这样就可以单独对该部分变量初始化,例如下面代码,saver1用于初始化变量v1,saver2用于初始化变量v2,v3:

import tensorflow as tf
 
v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2= tf.Variable(tf.zeros([200]), name="v2")
v3= tf.Variable(tf.zeros([100]), name="v3")
#saver = tf.train.Saver()
saver1 = tf.train.Saver([v1])
saver2 = tf.train.Saver([v2]+[v3])
with tf.Session() as sess:
 # init_op = tf.global_variables_initializer()
 # sess.run(init_op)
 saver1.restore(sess, "checkpoint/model_test-1")
 saver2.restore(sess, "checkpoint/model_test-1")
 # saver.save(sess,"checkpoint/model_test",global_step=1)

以上这篇tensorflow 加载部分变量的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Pyhthon中使用compileall模块编译源文件为pyc文件

有的时候我们需要把项目中.py的python所有源文件编译成.pyc文件,只保留.pyc文件然后发布给别人(虽然说可以反编译,但也算是一种保护把). 这个时候就可以使用compileal...

解决Python获取字典dict中不存在的值时出错问题

描述:Python2.7中如果想要获取字典中的一个值,但是这个值可能不存在,此时应该加上判断: 举个例子: t= {} if t.get('1'): # right:这种通过key来...

python如何重载模块实例解析

本文首先介绍了Python中的模块的概念,谈到了一个模块往往由多个模块组成,然后通过具体实例,分析了模块重载的相关内容,具体介绍如下。 模块是Python程序架构的一个核心概念,较大的程...

Python实现Youku视频批量下载功能

Python实现Youku视频批量下载功能

前段时间由于收集视频数据的需要,自己捣鼓了一个YouKu视频批量下载的程序。东西虽然简单,但还挺实用的,拿出来分享给大家。   版本:Python2.7+BeautifulSoup3.2...

Python实现删除当前目录下除当前脚本以外的文件和文件夹实例

本文实例讲述了Python实现删除当前目录下除当前脚本以外的文件和文件夹。分享给大家供大家参考。具体如下: import os,sys import shutil cur_file...