对Pytorch中nn.ModuleList 和 nn.Sequential详解

yipeiwu_com5年前Python基础

简而言之就是,nn.Sequential类似于Keras中的贯序模型,它是Module的子类,在构建数个网络层之后会自动调用forward()方法,从而有网络模型生成。而nn.ModuleList仅仅类似于pytho中的list类型,只是将一系列层装入列表,并没有实现forward()方法,因此也不会有网络模型产生的副作用。

需要注意的是,nn.ModuleList接受的必须是subModule类型,例如:

nn.ModuleList(
      [nn.ModuleList([Conv(inp_dim + j * increase, oup_dim, 1, relu=False, bn=False) for j in range(5)]) for i in
       range(nstack)])

其中,二次嵌套的list内部也必须额外使用一个nn.ModuleList修饰实例化,否则会无法识别类型而报错!

摘录自

nn.ModuleList is just like a Python list. It was designed to store any desired number of nn.Module's. It may be useful, for instance, if you want to design a neural network whose number of layers is passed as input:

class LinearNet(nn.Module):
 def __init__(self, input_size, num_layers, layers_size, output_size):
   super(LinearNet, self).__init__()
 
   self.linears = nn.ModuleList([nn.Linear(input_size, layers_size)])
   self.linears.extend([nn.Linear(layers_size, layers_size) for i in range(1, self.num_layers-1)])
   self.linears.append(nn.Linear(layers_size, output_size)

nn.Sequential allows you to build a neural net by specifying sequentially the building blocks (nn.Module's) of that net. Here's an example:

class Flatten(nn.Module):
 def forward(self, x):
  N, C, H, W = x.size() # read in N, C, H, W
  return x.view(N, -1)
 
simple_cnn = nn.Sequential(
      nn.Conv2d(3, 32, kernel_size=7, stride=2),
      nn.ReLU(inplace=True),
      Flatten(), 
      nn.Linear(5408, 10),
     )

In nn.Sequential, the nn.Module's stored inside are connected in a cascaded way. For instance, in the example that I gave, I define a neural network that receives as input an image with 3 channels and outputs 10 neurons. That network is composed by the following blocks, in the following order: Conv2D -> ReLU -> Linear layer. Moreover, an object of type nn.Sequential has a forward() method, so if I have an input image x I can directly call y = simple_cnn(x) to obtain the scores for x. When you define an nn.Sequential you must be careful to make sure that the output size of a block matches the input size of the following block. Basically, it behaves just like a nn.Module

On the other hand, nn.ModuleList does not have a forward() method, because it does not define any neural network, that is, there is no connection between each of the nn.Module's that it stores. You may use it to store nn.Module's, just like you use Python lists to store other types of objects (integers, strings, etc). The advantage of using nn.ModuleList's instead of using conventional Python lists to store nn.Module's is that Pytorch is “aware” of the existence of the nn.Module's inside an nn.ModuleList, which is not the case for Python lists. If you want to understand exactly what I mean, just try to redefine my class LinearNet using a Python list instead of a nn.ModuleList and train it. When defining the optimizer() for that net, you'll get an error saying that your model has no parameters, because PyTorch does not see the parameters of the layers stored in a Python list. If you use a nn.ModuleList instead, you'll get no error.

以上这篇对Pytorch中nn.ModuleList 和 nn.Sequential详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

基于Django的乐观锁与悲观锁解决订单并发问题详解

前言 订单并发这个问题我想大家都是有一定认识的,这里我说一下我的一些浅见,我会尽可能的让大家了解如何解决这类问题。 在解释如何解决订单并发问题之前,需要先了解一下什么是数据库的事务。(我...

详解python pandas 分组统计的方法

详解python pandas 分组统计的方法

首先,看看本文所面向的应用场景:我们有一个数据集df,现在想统计数据中某一列每个元素的出现次数。这个在我们前面文章《如何画直方图》中已经介绍了方法,利用value_counts()就可以...

浅谈flask源码之请求过程

Flask Flask是什么? Flask是一个使用 Python 编写的轻量级 Web 应用框架, 让我们可以使用Python语言快速搭建Web服务, Flask也被称为 "m...

Django 接收Post请求数据,并保存到数据库的实现方法

要说基本操作,大家基本都会,但是有时候,有些操作使用小技巧会节省很多时间。 本篇描述的就是使用dict小技巧,保存到数据库,用来节省大家编码的工作量。 主要内容:通过for循环拿到pos...

Python通过matplotlib画双层饼图及环形图简单示例

Python通过matplotlib画双层饼图及环形图简单示例

(1) 饼图(pie),即在一个圆圈内分成几块,显示不同数据系列的占比大小,这也是我们在日常数据的图形展示中最常用的图形之一。 在python中常用matplotlib的pie来绘制,基...