关于Pytorch的MLP模块实现方式

yipeiwu_com6年前Python基础

MLP分类效果一般好于线性分类器,即将特征输入MLP中再经过softmax来进行分类。

具体实现为将原先线性分类模块:

self.classifier = nn.Linear(config.hidden_size, num_labels)

替换为:

self.classifier = MLP(config.hidden_size, num_labels)

并且添加MLP模块:

  class MLP(nn.Module):
    def __init__(self, input_size, common_size):
      super(MLP, self).__init__()
      self.linear = nn.Sequential(
        nn.Linear(input_size, input_size // 2),
        nn.ReLU(inplace=True),
        nn.Linear(input_size // 2, input_size // 4),
        nn.ReLU(inplace=True),
        nn.Linear(input_size // 4, common_size)
      )
 
    def forward(self, x):
      out = self.linear(x)
      return out

看一下模块结构:

mlp = MLP(1000,3)
print(mlp)

以上这篇关于Pytorch的MLP模块实现方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

python使用psutil模块获取系统状态

获取操作系统的当前运行状态和负载情况,是一个系统管理员的基本技能,因为这对我们日常排查故障,定位问题有着非常紧密的联系,比如查看当前系统的基本信息,例如cpu,内存,网络接收包情况,磁盘...

Python实现Mysql数据库连接池实例详解

Python实现Mysql数据库连接池实例详解

python连接Mysql数据库: Python编程中可以使用MySQLdb进行数据库的连接及诸如查询/插入/更新等操作,但是每次连接MySQL数据库请求时,都是独立的去请求访问,相当...

Python空间数据处理之GDAL读写遥感图像

GDAL是空间数据处理的开源包,支持多种数据格式的读写。遥感图像是一种带大地坐标的栅格数据,遥感图像的栅格模型包含以下两部分的内容: 栅格矩阵:由正方形或者矩形栅格点组成,每个栅格点所对...

Python + OpenCV 实现LBP特征提取的示例代码

Python + OpenCV 实现LBP特征提取的示例代码

背景 看了些许的纹理特征提取的paper,想自己实现其中部分算法,看看特征提取之后的效果是怎样 运行环境 Mac OS Python3.0 Anaconda3(集成了很多包...

人脸识别经典算法一 特征脸方法(Eigenface)

人脸识别经典算法一 特征脸方法(Eigenface)

这篇文章是撸主要介绍人脸识别经典方法的第一篇,后续会有其他方法更新。特征脸方法基本是将人脸识别推向真正可用的第一种方法,了解一下还是很有必要的。特征脸用到的理论基础PCA在另一篇博客里:...