tensorflow tf.train.batch之数据批量读取方式

yipeiwu_com5年前Python基础

在进行大量数据训练神经网络的时候,可能需要批量读取数据。于是参考了这篇文章的代码,结果发现数据一直批量循环输出,不会在数据的末尾自动停止。

然后发现这篇博文说slice_input_producer()这个函数有一个形参num_epochs,通过设置它的值就可以控制全部数据循环输出几次。

于是我设置之后出现以下的报错:

tensorflow.python.framework.errors_impl.FailedPreconditionError: Attempting to use uninitialized value input_producer/input_producer/limit_epochs/epochs

     [[Node: input_producer/input_producer/limit_epochs/CountUpTo = CountUpTo[T=DT_INT64, _class=["loc:@input_producer/input_producer/limit_epochs/epochs"], limit=2, _device="/job:localhost/replica:0/task:0/cpu:0"](input_producer/input_producer/limit_epochs/epochs)]]

找了好久,都不知道为什么会错,于是只好去看看slice_input_producer()函数的源码,结果在源码中发现作者说这个num_epochs如果不是空的话,就是一个局部变量,需要先调用global_variables_initializer()函数初始化。

于是我调用了之后,一切就正常了,特此记录下来,希望其他人遇到的时候能够及时找到原因。

哈哈,这是笔者第一次通过阅读源码解决了问题,心情还是有点小激动。啊啊,扯远了,上最终成功的代码:

import pandas as pd
import numpy as np
import tensorflow as tf


def generate_data():
  num = 25
  label = np.asarray(range(0, num))
  images = np.random.random([num, 5])
  print('label size :{}, image size {}'.format(label.shape, images.shape))
  return images,label

def get_batch_data():
  label, images = generate_data()
  input_queue = tf.train.slice_input_producer([images, label], shuffle=False,num_epochs=2)
  image_batch, label_batch = tf.train.batch(input_queue, batch_size=5, num_threads=1, capacity=64,allow_smaller_final_batch=False)
  return image_batch,label_batch


images,label = get_batch_data()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())#就是这一行
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess,coord)
try:
  while not coord.should_stop():
    i,l = sess.run([images,label])
    print(i)
    print(l)
except tf.errors.OutOfRangeError:
  print('Done training')
finally:
  coord.request_stop()
coord.join(threads)
sess.close()

以上这篇tensorflow tf.train.batch之数据批量读取方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

python 实现网上商城,转账,存取款等功能的信用卡系统

python 实现网上商城,转账,存取款等功能的信用卡系统

一、要求 二、思路 1.购物类buy 接收 信用卡类 的信用卡可用可用余额, 返回消费金额 2.信用卡(ATM)类 接收上次操作后,信用卡可用余额,总欠款,剩余欠款,存款 其中: 1...

python实现登录密码重置简易操作代码

需求: 1.用户输入密码正确登录 2.用户输入密码错误退出并调用函数继续输入 3.用户输入密码符合原先给定的一个值时,允许用户重置密码,并且可以用新密码登录 4.输入三次后禁止输入 虽然...

Python中asyncio与aiohttp入门教程

Python中asyncio与aiohttp入门教程

很多朋友对异步编程都处于“听说很强大”的认知状态。鲜有在生产项目中使用它。而使用它的同学,则大多数都停留在知道如何使用 Tornado、Twisted、Gevent 这类异步框架上,出现...

Pytorch 计算误判率,计算准确率,计算召回率的例子

无论是官方文档还是各位大神的论文或搭建的网络很多都是计算准确率,很少有计算误判率, 下面就说说怎么计算准确率以及误判率、召回率等指标 1.计算正确率 获取每批次的预判正确个数 train...

python使用xlrd模块读取xlsx文件中的ip方法

程序中经常需要使用excel文件,批量读取文件中的数据 python读取excel文件可以使用xlrd模块 pip install xlrd安装模块 示例: #coding=utf8...