tensorflow 加载部分变量的实例讲解

yipeiwu_com5年前Python基础

tensorflow模型保存为saver = tf.train.Saver()函数,saver.save()保存模型,代码如下:

import tensorflow as tf
 
v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2= tf.Variable(tf.zeros([200]), name="v2")
saver = tf.train.Saver()
with tf.Session() as sess:
 init_op = tf.global_variables_initializer()
 sess.run(init_op)
 saver.save(sess,"checkpoint/model_test",global_step=1)

当我们保存模型后,我们可以通过saver.restore()来加载模型,初始化变量:

import tensorflow as tf
 
v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2= tf.Variable(tf.zeros([200]), name="v2")
saver = tf.train.Saver()
with tf.Session() as sess:
 # init_op = tf.global_variables_initializer()
 # sess.run(init_op)
 saver.restore(sess, "checkpoint/model_test-1")
 # saver.save(sess,"checkpoint/model_test",global_step=1)

神经网络训练时,有时候我们需要从预训练的模型中加载部分参数,初始化当前模型,例如加入CNN有6层,我们需要从已有的模型初始化CNN前5层参数.这可以通过saver.restore()实现.

之前我们已经介绍可以通过tf.train.Saver()的保存部分变量的方法,即需要保存的变量列表,同样的,在变量初始化的时候,我们可以对需要单独初始化的变量分别定义一个tf.train.Saver()函数,这样就可以单独对该部分变量初始化,例如下面代码,saver1用于初始化变量v1,saver2用于初始化变量v2,v3:

import tensorflow as tf
 
v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2= tf.Variable(tf.zeros([200]), name="v2")
v3= tf.Variable(tf.zeros([100]), name="v3")
#saver = tf.train.Saver()
saver1 = tf.train.Saver([v1])
saver2 = tf.train.Saver([v2]+[v3])
with tf.Session() as sess:
 # init_op = tf.global_variables_initializer()
 # sess.run(init_op)
 saver1.restore(sess, "checkpoint/model_test-1")
 saver2.restore(sess, "checkpoint/model_test-1")
 # saver.save(sess,"checkpoint/model_test",global_step=1)

以上这篇tensorflow 加载部分变量的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

python装饰器常见使用方法分析

本文实例讲述了python装饰器常见使用方法。分享给大家供大家参考,具体如下: python 的装饰器,可以用来实现,类似spring AOP 类似的功能。一样可以用来记录某个方法执行前...

对pandas中两种数据类型Series和DataFrame的区别详解

对pandas中两种数据类型Series和DataFrame的区别详解

1. Series相当于数组numpy.array类似 s1=pd.Series([1,2,4,6,7,2]) s2=pd.Series([4,3,1,57,8],index=['a...

不要用强制方法杀掉python线程

前言:     不要试图用强制方法杀掉一个python线程,这从服务设计上就存在不合理性。 多线程本用来任务的协作并发,如果你使用强制手段干掉线程,那么很大...

Python实现插入排序和选择排序的方法

Python实现插入排序和选择排序的方法

话不多说,让我们从最基本的排序算法开始吧 插入排序 如下图所示,插入排序的实现思路顾名思义,就是 不断地在一个已经是有序的数组中,寻找合适位置并插入新元素 。 具体实现步骤为: 首先我...

深入解析Python中的__builtins__内建对象

如果你已经学习了包,模块这些知识了。 你会不会有好奇:Python为什么可以直接使用一些内建函数,不用显式的导入它们,比如 str() int() dir() ...? 原因是Pytho...