pytorch 在sequential中使用view来reshape的例子

yipeiwu_com6年前Python基础

pytorch中view是tensor方法,然而在sequential中包装的是nn.module的子类,

因此需要自己定义一个方法:

import torch.nn as nn
class Reshape(nn.Module):
 def __init__(self, *args):
  super(Reshape, self).__init__()
  self.shape = args

 def forward(self, x):
  # 如果数据集最后一个batch样本数量小于定义的batch_batch大小,会出现mismatch问题。可以自己修改下,如只传入后面的shape,然后通过x.szie(0),来输入。
  return x.view(self.shape)
class Reshape(nn.Module):
 def __init__(self, *args):
  super(Reshape, self).__init__()
  self.shape = args
 def forward(self, x):
  return x.view((x.size(0),)+self.shape)

以上这篇pytorch 在sequential中使用view来reshape的例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

对django的User模型和四种扩展/重写方法小结

User模型 User模型是这个框架的核心部分。他的完整的路径是在django.contrib.auth.models.User。以下对这个User对象做一个简单了解: 字段: 内置的U...

Python面向对象类的继承实例详解

本文实例讲述了Python面向对象类的继承。分享给大家供大家参考,具体如下: 一、概述 面向对象编程 (OOP) 语言的一个主要功能就是“继承”。继承是指这样一种能力:它可以使用现有类的...

快速解决vue.js 模板和jinja 模板冲突的问题

快速解决vue.js 模板和jinja 模板冲突的问题

jinjia和vue.js默认的模板转义符都是{{}} 目前的解决办法是修改vue.js的转义符,将原来的{{}}替换为其他标签,我改为{[]} 版本1.x和2.x方法如下 //...

python opencv 读取本地视频文件 修改ffmpeg的方法

Python + opencv 读取视频的三种情况: 情况一:通过摄像头采集视频 情况二:通过本地视频文件获取视频 情况三:通过摄像头录制视频,再读取录制的视频 摄像头采集、本地视频文件...

在Mac上删除自己安装的Python方法

推荐使用 Homebrew 来安装第三方工具。自己安装的python散落在电脑各处,删除起来比较麻烦。今天在此记录一下删除的过程(本人以Python3.6为例)。 删除Python 3....