tensorflow 加载部分变量的实例讲解

yipeiwu_com5年前Python基础

tensorflow模型保存为saver = tf.train.Saver()函数,saver.save()保存模型,代码如下:

import tensorflow as tf
 
v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2= tf.Variable(tf.zeros([200]), name="v2")
saver = tf.train.Saver()
with tf.Session() as sess:
 init_op = tf.global_variables_initializer()
 sess.run(init_op)
 saver.save(sess,"checkpoint/model_test",global_step=1)

当我们保存模型后,我们可以通过saver.restore()来加载模型,初始化变量:

import tensorflow as tf
 
v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2= tf.Variable(tf.zeros([200]), name="v2")
saver = tf.train.Saver()
with tf.Session() as sess:
 # init_op = tf.global_variables_initializer()
 # sess.run(init_op)
 saver.restore(sess, "checkpoint/model_test-1")
 # saver.save(sess,"checkpoint/model_test",global_step=1)

神经网络训练时,有时候我们需要从预训练的模型中加载部分参数,初始化当前模型,例如加入CNN有6层,我们需要从已有的模型初始化CNN前5层参数.这可以通过saver.restore()实现.

之前我们已经介绍可以通过tf.train.Saver()的保存部分变量的方法,即需要保存的变量列表,同样的,在变量初始化的时候,我们可以对需要单独初始化的变量分别定义一个tf.train.Saver()函数,这样就可以单独对该部分变量初始化,例如下面代码,saver1用于初始化变量v1,saver2用于初始化变量v2,v3:

import tensorflow as tf
 
v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2= tf.Variable(tf.zeros([200]), name="v2")
v3= tf.Variable(tf.zeros([100]), name="v3")
#saver = tf.train.Saver()
saver1 = tf.train.Saver([v1])
saver2 = tf.train.Saver([v2]+[v3])
with tf.Session() as sess:
 # init_op = tf.global_variables_initializer()
 # sess.run(init_op)
 saver1.restore(sess, "checkpoint/model_test-1")
 saver2.restore(sess, "checkpoint/model_test-1")
 # saver.save(sess,"checkpoint/model_test",global_step=1)

以上这篇tensorflow 加载部分变量的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

pygame游戏之旅 载入小车图片、更新窗口

pygame游戏之旅 载入小车图片、更新窗口

本文为大家分享了pygame游戏之旅的第3篇,供大家参考,具体内容如下 载入car图片(我自己画的),需要用到pygame.image模块,定义carImg用于接收载入的图片 car...

django数据库自动重连的方法实例

简介 Django数据库连接超过wait_timeout导致连接丢失时自动重新连接数据库 https://github.com/zhanghaofe...(本地下载) 安装 pi...

Python3生成手写体数字方法

Python3生成手写体数字方法

0.引言   平时上网干啥的基本上都会接触验证码,或者在机器学习学习过程中,大家或许会接触过手写体识别/验证码识别之类问题,会用到手写体的数据集;   自己尝试写了一个生成手写体图片的p...

浅谈Python中的可迭代对象、迭代器、For循环工作机制、生成器

浅谈Python中的可迭代对象、迭代器、For循环工作机制、生成器

1.iterable iterator区别 要了解两者区别,先要了解一下迭代器协议: 迭代器协议是指:对象需要提供__next__()方法,它返回迭代中的元素,在没有更多元素后,抛出St...

使用python检测主机存活端口及检查存活主机

使用python检测主机存活端口及检查存活主机

监测主机存活的端口 #!/usr/bin/env python # coding-utf import argparse import socket import sys #auth...