TensorFlow用expand_dim()来增加维度的方法

yipeiwu_com6年前Python基础

TensorFlow中,想要维度增加一维,可以使用tf.expand_dims(input, dim, name=None)函数。当然,我们常用tf.reshape(input, shape=[])也可以达到相同效果,但是有些时候在构建图的过程中,placeholder没有被feed具体的值,这时就会包下面的错误:TypeError: Expected binary or unicode string, got 1

在这种情况下,我们就可以考虑使用expand_dims来将维度加1。比如我自己代码中遇到的情况,在对图像维度降到二维做特定操作后,要还原成四维[batch, height, width, channels],前后各增加一维。如果用reshape,则因为上述原因报错

one_img2 = tf.reshape(one_img, shape=[1, one_img.get_shape()[0].value, one_img.get_shape()[1].value, 1])

用下面的方法可以实现:

one_img = tf.expand_dims(one_img, 0)
one_img = tf.expand_dims(one_img, -1) #-1表示最后一维

在最后,给出官方的例子和说明

# 't' is a tensor of shape [2]
shape(expand_dims(t, 0)) ==> [1, 2]
shape(expand_dims(t, 1)) ==> [2, 1]
shape(expand_dims(t, -1)) ==> [2, 1]

# 't2' is a tensor of shape [2, 3, 5]
shape(expand_dims(t2, 0)) ==> [1, 2, 3, 5]
shape(expand_dims(t2, 2)) ==> [2, 3, 1, 5]
shape(expand_dims(t2, 3)) ==> [2, 3, 5, 1]

Args:

input: A Tensor.
dim: A Tensor. Must be one of the following types: int32, int64. 0-D (scalar). Specifies the dimension index at which to expand the shape of input.
name: A name for the operation (optional).

Returns:

A Tensor. Has the same type as input. Contains the same data as input, but its shape has an additional dimension of size 1 added.

以上这篇TensorFlow用expand_dim()来增加维度的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

python urllib urlopen()对象方法/代理的补充说明

python urllib urlopen()对象方法/代理的补充说明 urllib 是 python 自带的一个抓取网页信息一个接口,他最主要的方法是 urlopen(),是基于 py...

Python绘制堆叠柱状图的实例

Python绘制堆叠柱状图的实例

有个朋友要求帮忙绘制堆叠柱状图,查阅了一些文档之后也算是完成了,只是一个小demo,下面我就记录一下。 1.什么是堆叠柱状图 与并排显示分类的分组柱状图不同,堆叠柱状图将每个柱子进行分割...

对python借助百度云API对评论进行观点抽取的方法详解

对python借助百度云API对评论进行观点抽取的方法详解

通过百度云API接口抽取得到产品评论的观点,也掠去了很多评论中无用的内容以及符号,为后续进行文本主题挖掘或者规则的提取提供基础。 工具 1、百度云账号,申请应用接口(自然语言处理) 2...

python Web开发你要理解的WSGI & uwsgi详解

python Web开发你要理解的WSGI & uwsgi详解

WSGI协议 首先弄清下面几个概念: WSGI:全称是Web Server Gateway Interface,WSGI不是服务器,python模块,框架,API或者任何软件,只是一种...

matplotlib.pyplot画图并导出保存的实例

我就废话不多说了,直接上代码吧! import pandas as pd import numpy as np import matplotlib.pyplot as plt fig...