tensorflow 用矩阵运算替换for循环 用tf.tile而不写for的方法

yipeiwu_com6年前Python基础

如下所示:

# u [32,30,200]
# u_logits [400,32,30]
q_j_400 = [] 
for j in range(400):
 q_j_400.append(tf.squeeze(tf.matmul(tf.transpose(u,[0,2,1]),tf.expand_dims(tf.nn.softmax(u_logits[j]),-1)),[2])) # tf.matmul [32,200,30],[32,30,1]
test_result = tf.stack(q_j_400)
test_result = tf.transpose(test_result,[1,0,2])

可以通过tf.tile实现更高速的版本

# u [32,30,200]
# u_logits [32,400,30]
u_tile = tf.tile(tf.expand_dims(u,1),[1,400,1,1])
u_logits = tf.expand_dims(tf.nn.softmax(u_logits,-1),-1)
test_result = tf.reduce_sum(u_logits * u_tile,-2) # [32,400,30,1]*[32,400,30,200]

以上这篇tensorflow 用矩阵运算替换for循环 用tf.tile而不写for的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

python 写入csv乱码问题解决方法

需求背景 最近为公司开发了一套邮件日报程序,邮件一般就是表格,图片,然后就是附件。附件一般都是默认写到txt文件里,但是PM希望邮件里的附件能直接用Excel这种软件打开,最开始想保存...

python实现翻转棋游戏(othello)

python实现翻转棋游戏(othello)

利用上一篇的框架,再写了个翻转棋的程序,为了调试minimax算法,花了两天的时间。 几点改进说明: 拆分成四个文件:board.py,player.py,ai.py,othell...

设置python3为默认python的方法

设置python3为默认python的方法

我们知道在Windows下多版本共存的配置方法就是改可执行文件的名字,配置环境变量。 Linux中的配置原理差不多,思路就是生成软链接,配置到环境变量。 在没配置之前,我的Ubuntu中...

python 标准差计算的实现(std)

numpy.std() 求标准差的时候默认是除以 n 的,即是有偏的,np.std无偏样本标准差方式为加入参数 ddof = 1; pandas.std() 默认是除以n-1 的,即是...

Python循环实现n的全排列功能

描述: 输入一个大于0的整数n,输出1到n的全排列: 例如: n=3,输出[[3, 2, 1], [2, 3, 1], [2, 1, 3], [3, 1, 2], [1, 3, 2]...