pytorch 在sequential中使用view来reshape的例子

yipeiwu_com6年前Python基础

pytorch中view是tensor方法,然而在sequential中包装的是nn.module的子类,

因此需要自己定义一个方法:

import torch.nn as nn
class Reshape(nn.Module):
 def __init__(self, *args):
  super(Reshape, self).__init__()
  self.shape = args

 def forward(self, x):
  # 如果数据集最后一个batch样本数量小于定义的batch_batch大小,会出现mismatch问题。可以自己修改下,如只传入后面的shape,然后通过x.szie(0),来输入。
  return x.view(self.shape)
class Reshape(nn.Module):
 def __init__(self, *args):
  super(Reshape, self).__init__()
  self.shape = args
 def forward(self, x):
  return x.view((x.size(0),)+self.shape)

以上这篇pytorch 在sequential中使用view来reshape的例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

详解Python中的装饰器、闭包和functools的教程

装饰器(Decorators) 装饰器是这样一种设计模式:如果一个类希望添加其他类的一些功能,而不希望通过继承或是直接修改源代码实现,那么可以使用装饰器模式。简单来说Python中的装饰...

基于python二叉树的构造和打印例子

写在最前面: 带你从最简单的二叉树构造开始,深入理解二叉树的数据结构,ps:不会数据结构的程序猿只能是三流的 首先,我们构造一个二叉树 这是最标准,也是最简单的二叉树构造方法 '''...

Python对象中__del__方法起作用的条件详解

对象的__del__是对象在被gc消除回收的时候起作用的一个方法,它的执行一般也就意味着对象不能够继续引用。 示范代码如下: class Demo: def __del__(sel...

python3+PyQt5+Qt Designer实现堆叠窗口部件

python3+PyQt5+Qt Designer实现堆叠窗口部件

本文是对《Python Qt GUI快速编程》的第9章的堆叠窗口例子Vehicle Rental用Python3+PyQt5+Qt Designer进行改写。 第一部分无借用Qt De...

Python面向对象程序设计类的封装与继承用法示例

本文实例讲述了Python面向对象程序设计类的封装与继承用法。分享给大家供大家参考,具体如下: 访问限制(封装) 1、概念 面向对象语言的三大特征:封装, 继承, 多态。 广义的封装:...