tensorflow 只恢复部分模型参数的实例

yipeiwu_com5年前Python基础

我就废话不多说了,直接上代码吧!

import tensorflow as tf

def model_1():
  with tf.variable_scope("var_a"):
    a = tf.Variable(initial_value=[1, 2, 3], name="a")

  vars = [var for var in tf.trainable_variables() if var.name.startswith("var_a")]
  print(len(vars))
  return vars

def model_2():

  vars1 = model_1()

  with tf.variable_scope("var_b"):
    a = tf.Variable(initial_value=[1, 2, 3], name="a")

  vars2 = [var for var in tf.trainable_variables() if var.name.startswith("var")]
  print(len(vars2))
  return vars1


def pretrain_model1():
  print("-------- model 1 ------")
  vars = model_1()

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.save(sess, "./model.ckpt")

def train_model2():
  print("-------- model 2 ------")

  model1_vars = model_2()

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver(var_list=model1_vars)
    saver.restore(sess, "./model.ckpt")
    vars = sess.run([model1_vars])
    for var in vars:
      print(var)

step = 2
if step == 1:
  pretrain_model1()
else:
  train_model2()

以上这篇tensorflow 只恢复部分模型参数的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

实例讲解python中的协程

python协程 线程和进程的操作是由程序触发系统接口,最后的执行者是系统;协程的操作则是程序员。 协程存在的意义:对于多线程应用,CPU通过切片的方式来切换线程间的执行,线程切换时需要...

python实现多进程按序号批量修改文件名的方法示例

本文实例讲述了python实现多进程按序号批量修改文件名的方法。分享给大家供大家参考,具体如下: 说明 文件名命名方式如图,是数字序号开头,但是中间有些文件删掉了,序号不连续,这里将序号...

Windows下安装Scrapy

Windows下安装Scrapy

这几天正好有需求实现一个爬虫程序,想到爬虫程序立马就想到了python,python相关的爬虫资料好像也特别多。于是就决定用python来实现爬虫程序了,正好发现了python有一个开源...

简单介绍Python中的round()方法

 round()方法返回 x 的小数点四舍五入到n个数字。 语法 以下是round()方法的语法: round( x [, n] ) 参数  &nbs...

Python 3 实现定义跨模块的全局变量和使用教程

尽管某些书籍上总是说避免使用全局变量,但是在实际的需求不断变化中,往往定义一个全局变量是最可靠的方法,但是又必须要避免变量名覆盖。 Python 中 global 关键字可以定义一个变量...