tensorflow获取变量维度信息

yipeiwu_com5年前Python基础

tensorflow版本1.4

获取变量维度是一个使用频繁的操作,在tensorflow中获取变量维度主要用到的操作有以下三种:

  • Tensor.shape
  • Tensor.get_shape()
  • tf.shape(input,name=None,out_type=tf.int32)

对上面三种操作做一下简单分析:(这三种操作先记作A、B、C)

A 和 B 基本一样,只不过前者是Tensor的属性变量,后者是Tensor的函数。
A 和 B 均返回TensorShape类型,而 C 返回一个1D的out_type类型的Tensor。
A 和 B 可以在任意位置使用,而 C 必须在Session中使用。
A 和 B 获取的是静态shape,可以返回不完整的shape; C 获取的是动态的shape,必须是完整的shape。

另外,补充从TenaorShape变量中获取具体维度数值的方法

# 直接获取TensorShape变量的第i个维度值
x.shape[i].value
x.get_shape()[i].value

# 将TensorShape变量转化为list类型,然后直接按照索引取值
x.get_shape().as_list()

下面给出全部的示例程序:

import tensorflow as tf

x1 = tf.constant([[1,2,3],[4,5,6]])
# 占位符创建变量,第一个维度初始化为None,表示暂不指定维度
x2 = tf.placeholder(tf.float32,[None, 2,3])
print('x1.shape:',x1.shape)
print('x2.shape:',x2.shape)
print('x2.shape[1].value:',x2.shape[1].value)
print('tf.shape(x1):',tf.shape(x1))
print('tf.shape(x2):',tf.shape(x2))
print('x1.get_shape():',x1.get_shape())
print('x2.get_shape():',x2.get_shape())
print('x2.get_shape.as_list[1]:',x2.get_shape().as_list()[1])
shapeOP1 = tf.shape(x1)
shapeOP2 = tf.shape(x2)
with tf.Session() as sess:
 print('Within session, tf.shape(x1):',sess.run(shapeOP1))
 # 由于x2未进行完整的变量填充,其维度不完整,因此执行下面的命令将会报错
 # print('Within session, tf.shape(x2):',sess.run(shapeOP2)) # 此命令将会报错

输出结果为:

x1.shape: (2, 3)
x2.shape: (?, 2, 3)
x2.shape[1].value: 2
tf.shape(x1): Tensor("Shape:0", shape=(2,), dtype=int32)
tf.shape(x2): Tensor("Shape_1:0", shape=(3,), dtype=int32)
x1.get_shape(): (2, 3)
x2.get_shape(): (?, 2, 3)
x2.get_shape.as_list[1]: 2
Within session, tf.shape(x1): [2 3]

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

python区分不同数据类型的方法

python区分不同数据类型的方法

python怎么区分不同数据类型? Python判断变量的数据类型的两种方法 一、Python中的数据类型有数字、字符串,列表、元组、字典、集合等。有两种方法判断一个变量的数据类型 1、...

Python 实现购物商城,含有用户入口和商家入口的示例

这是模拟淘宝的一个简易的购物商城程序。 用户入口具有以下功能: 登录认证 可以锁定用户 密码输入次数大于3次,锁定用户名 连续三次输错用户名退出程序 可以选择直接购买,也可以选择加入购物...

python生成密码字典的方法

python生成密码字典的方法

这里我使用的是python27 主要用的是我之前博文里提到的itertools循环迭代的模块,用这个模块可以省不少事 首先要调用itertools import itertools...

利用Python对文件夹下图片数据进行批量改名的代码实例

利用Python对文件夹下图片数据进行批量改名的代码实例

1. 前言 我们最近在做一个使用flask 模拟 instagram 的图片分享网站, 需要一些基本的图片数据, 我们这里采用的是本地提供, 但是,使用爬虫从网上爬下来的图片,名字都是乱...

python opencv minAreaRect 生成最小外接矩形的方法

python opencv minAreaRect 生成最小外接矩形的方法

使用python opencv返回点集cnt的最小外接矩形,所用函数为 cv2.minAreaRect(cnt) ,cnt是点集数组或向量(里面存放的是点的坐标),并且这个点集不定个数。...