pytorch 在sequential中使用view来reshape的例子

yipeiwu_com5年前Python基础

pytorch中view是tensor方法,然而在sequential中包装的是nn.module的子类,

因此需要自己定义一个方法:

import torch.nn as nn
class Reshape(nn.Module):
 def __init__(self, *args):
  super(Reshape, self).__init__()
  self.shape = args

 def forward(self, x):
  # 如果数据集最后一个batch样本数量小于定义的batch_batch大小,会出现mismatch问题。可以自己修改下,如只传入后面的shape,然后通过x.szie(0),来输入。
  return x.view(self.shape)
class Reshape(nn.Module):
 def __init__(self, *args):
  super(Reshape, self).__init__()
  self.shape = args
 def forward(self, x):
  return x.view((x.size(0),)+self.shape)

以上这篇pytorch 在sequential中使用view来reshape的例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Python实现FTP文件传输的实例

Python实现FTP文件传输的实例

FTP一般流程 FTP对应PASV和PORT两种访问方式,分别为被动和主动,是针对FTP服务器端进行区分的,正常传输过程中21号端口用于指令传输,数据传输端口使用其他端口。 PASV:...

python Spyder界面无法打开的解决方法

python Spyder界面无法打开的解决方法

Spyder本来还用得好好的,能正常使用,后来再关闭打开时,出现下面的蜘蛛网界面后,就无法显示操作界面了: 后来在网上搜索了多种方法,甚至还将Adaconda2重装了都没有用。 后来找...

python SQLAlchemy 中的Engine详解

python SQLAlchemy 中的Engine详解

先看这张图,这是从官方网站扒下来的。 Engine 翻译过来就是引擎的意思,汽车通过引擎来驱动,而 SQLAlchemy 是通过 Engine 来驱动,Engine 维护了一个连接池(...

python print 按逗号或空格分隔的方法

1)按,分隔 a, b = 0, 1 while b < 1000: print(b, end=',') a, b = b, a+b 1,1,2,3,5,8,13,...

DJANGO-ALLAUTH社交用户系统的安装配置

DJANGO-ALLAUTH是github上面排名较高的django user系统.本来通过对比是想选用django-userea的,可是博主智商不够看懂它的安装配置文档.....搞乱了...