tensorflow 只恢复部分模型参数的实例

yipeiwu_com6年前Python基础

我就废话不多说了,直接上代码吧!

import tensorflow as tf

def model_1():
  with tf.variable_scope("var_a"):
    a = tf.Variable(initial_value=[1, 2, 3], name="a")

  vars = [var for var in tf.trainable_variables() if var.name.startswith("var_a")]
  print(len(vars))
  return vars

def model_2():

  vars1 = model_1()

  with tf.variable_scope("var_b"):
    a = tf.Variable(initial_value=[1, 2, 3], name="a")

  vars2 = [var for var in tf.trainable_variables() if var.name.startswith("var")]
  print(len(vars2))
  return vars1


def pretrain_model1():
  print("-------- model 1 ------")
  vars = model_1()

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.save(sess, "./model.ckpt")

def train_model2():
  print("-------- model 2 ------")

  model1_vars = model_2()

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver(var_list=model1_vars)
    saver.restore(sess, "./model.ckpt")
    vars = sess.run([model1_vars])
    for var in vars:
      print(var)

step = 2
if step == 1:
  pretrain_model1()
else:
  train_model2()

以上这篇tensorflow 只恢复部分模型参数的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Python脚本实现代码行数统计代码分享

Python脚本实现代码行数统计代码分享

之前用bash实现过(/post/61943.htm),不过那个不能在windows下使用,所以就写了个python版,也方便我以后使用……这里就不多介绍了,不懂的google下。 实现...

Python 实现的 Google 批量翻译功能

首先声明,没有什么不良动机,因为经常会用 translate.google.cn,就想着用 Python 模拟网页提交实现文档的批量翻译。据说有 API,可是要收费。 生成 Token...

Python批量发送post请求的实现代码

昨天学了一天的Python(我的生产语言是java,也可以写一些shell脚本,算有一点点基础),今天有一个应用场景,就正好练手了。 这个功能之前再java里写过,比较粗糙,原来是在我本...

python实现K最近邻算法

KNN核心算法函数,具体内容如下 #! /usr/bin/env python3 # -*- coding: utf-8 -*- # fileName : KNNdistance.p...

Python检查和同步本地时间(北京时间)的实现方法

背景 有时本地服务器的时间不准了,需要同步互联网上的时间。 解决方案 NTP时间同步,找到一些可用的NTP服务器进行同步即可。 通过获取一些大型网站的时间来同步为自己的时间...