pytorch 在sequential中使用view来reshape的例子

yipeiwu_com6年前Python基础

pytorch中view是tensor方法,然而在sequential中包装的是nn.module的子类,

因此需要自己定义一个方法:

import torch.nn as nn
class Reshape(nn.Module):
 def __init__(self, *args):
  super(Reshape, self).__init__()
  self.shape = args

 def forward(self, x):
  # 如果数据集最后一个batch样本数量小于定义的batch_batch大小,会出现mismatch问题。可以自己修改下,如只传入后面的shape,然后通过x.szie(0),来输入。
  return x.view(self.shape)
class Reshape(nn.Module):
 def __init__(self, *args):
  super(Reshape, self).__init__()
  self.shape = args
 def forward(self, x):
  return x.view((x.size(0),)+self.shape)

以上这篇pytorch 在sequential中使用view来reshape的例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

selenium3+python3环境搭建教程图解

selenium3+python3环境搭建教程图解

1、首先安装火狐浏览器 有单独文章分享怎么安装 2、搭建python环境 安装python,安装的时候把path选好,就不用自己在配置,安装方法有单独文档分享 安装好以后cmd打开输入p...

跟老齐学Python之再深点,更懂list

list解析 先看下面的例子,这个例子是想得到1到9的每个整数的平方,并且将结果放在list中打印出来 >>> power2 = [] >>> f...

Python实现的科学计算器功能示例

本文实例讲述了Python实现的科学计算器功能。分享给大家供大家参考,具体如下: import wx import re import math # begin wxGlade: e...

Python序列之list和tuple常用方法以及注意事项

sequence 序列 sequence(序列)是一组有顺序的对象的集合。序列可以包含一个或多个元素,也可以没有任何元素。 我们之前所说的基本数据类型,都可以作为序列的对象。对象还可以是...

解决python中使用plot画图,图不显示的问题

解决python中使用plot画图,图不显示的问题

对以下数据画图结果图不显示,修改过程如下 df3 = {'chinese':109, 'American':88, 'German': 66, 'Korea':23, 'Japan'...