pytorch 在sequential中使用view来reshape的例子

yipeiwu_com6年前Python基础

pytorch中view是tensor方法,然而在sequential中包装的是nn.module的子类,

因此需要自己定义一个方法:

import torch.nn as nn
class Reshape(nn.Module):
 def __init__(self, *args):
  super(Reshape, self).__init__()
  self.shape = args

 def forward(self, x):
  # 如果数据集最后一个batch样本数量小于定义的batch_batch大小,会出现mismatch问题。可以自己修改下,如只传入后面的shape,然后通过x.szie(0),来输入。
  return x.view(self.shape)
class Reshape(nn.Module):
 def __init__(self, *args):
  super(Reshape, self).__init__()
  self.shape = args
 def forward(self, x):
  return x.view((x.size(0),)+self.shape)

以上这篇pytorch 在sequential中使用view来reshape的例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

python创建线程示例

复制代码 代码如下:import threadingfrom time import sleep def test_func(id):    for i i...

Python Django简单实现session登录注销过程详解

Python Django简单实现session登录注销过程详解

开发工具:pycharm 简单实现session的登录注销功能 Django配置好路由分发功能 默认session在Django里面的超时时间是两周 使用request.session....

Python3 使用map()批量的转换数据类型,如str转float的实现

我们知道map() 会根据提供的函数对指定序列做映射。 第一个参数 function 以参数序列中的每一个元素调用 function 函数,返回包含每次 function 函数返回值的新...

浅析python3字符串格式化format()函数的简单用法

 format()函数 """ 测试 format()函数 """ def testFormat(): # format()函数中有几个元素,前面格式化的字符串中就要有...

Python中MYSQLdb出现乱码的解决方法

本文实例讲述了Python中MYSQLdb出现乱码的解决方法,分享给大家供大家参考。具体方法如下: 一般来说,在使用mysql最麻烦的问题在于乱码。 查看mysql的编码: 命令:&nb...