关于Pytorch的MLP模块实现方式

yipeiwu_com6年前Python基础

MLP分类效果一般好于线性分类器,即将特征输入MLP中再经过softmax来进行分类。

具体实现为将原先线性分类模块:

self.classifier = nn.Linear(config.hidden_size, num_labels)

替换为:

self.classifier = MLP(config.hidden_size, num_labels)

并且添加MLP模块:

  class MLP(nn.Module):
    def __init__(self, input_size, common_size):
      super(MLP, self).__init__()
      self.linear = nn.Sequential(
        nn.Linear(input_size, input_size // 2),
        nn.ReLU(inplace=True),
        nn.Linear(input_size // 2, input_size // 4),
        nn.ReLU(inplace=True),
        nn.Linear(input_size // 4, common_size)
      )
 
    def forward(self, x):
      out = self.linear(x)
      return out

看一下模块结构:

mlp = MLP(1000,3)
print(mlp)

以上这篇关于Pytorch的MLP模块实现方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Python 列表(List) 的三种遍历方法实例 详解

Python 列表(List) 的三种遍历方法实例 详解

Python 遍历 最近学习python这门语言,感觉到其对自己的工作效率有很大的提升,下面废话不多说,直接贴代码 #!/usr/bin/env python # -*- codin...

Python subprocess库的使用详解

介绍 使用subprocess模块的目的是用于替换os.system等一些旧的模块和方法。 运行python的时候,我们都是在创建并运行一个进程。像Linux进程那样,一个进程可以f...

python访问mysql数据库的实现方法(2则示例)

本文实例讲述了python访问mysql数据库的实现方法。分享给大家供大家参考,具体如下: 首先安装与Python版本匹配的MySQLdb 示例一 import MySQLdb co...

简单讲解Python中的字符串与字符串的输入输出

字符串 字符串用''或者""括起来,如果字符串内部有‘或者",需要使用\进行转义 >>> print 'I\'m ok.' I'm ok. 转义字符\可以转义...

Python实现对PPT文件进行截图操作的方法

本文实例讲述了Python实现对PPT文件进行截图操作的方法。分享给大家供大家参考。具体分析如下: 下面的代码可以为powerpoint文件ppt进行截图,可以指定要截取的幻灯片页面,需...