关于Pytorch的MLP模块实现方式

yipeiwu_com6年前Python基础

MLP分类效果一般好于线性分类器,即将特征输入MLP中再经过softmax来进行分类。

具体实现为将原先线性分类模块:

self.classifier = nn.Linear(config.hidden_size, num_labels)

替换为:

self.classifier = MLP(config.hidden_size, num_labels)

并且添加MLP模块:

  class MLP(nn.Module):
    def __init__(self, input_size, common_size):
      super(MLP, self).__init__()
      self.linear = nn.Sequential(
        nn.Linear(input_size, input_size // 2),
        nn.ReLU(inplace=True),
        nn.Linear(input_size // 2, input_size // 4),
        nn.ReLU(inplace=True),
        nn.Linear(input_size // 4, common_size)
      )
 
    def forward(self, x):
      out = self.linear(x)
      return out

看一下模块结构:

mlp = MLP(1000,3)
print(mlp)

以上这篇关于Pytorch的MLP模块实现方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Python CVXOPT模块安装及使用解析

Python CVXOPT模块安装及使用解析

Python中支持Convex Optimization(凸规划)的模块为CVXOPT,其安装方式为: 卸载原Pyhon中的Numpy 安装CVXOPT的whl文件,链接为:https...

Python3实现发送QQ邮件功能(附件)

本文实例为大家分享了Python3实现发送QQ邮件功能:附件,供大家参考,具体内容如下 可以成功发送邮件附件,但是邮件主要内容无法发送,有空再去找找原因 import smtplib...

python机器学习库scikit-learn:SVR的基本应用

python机器学习库scikit-learn:SVR的基本应用

scikit-learn是python的第三方机器学习库,里面集成了大量机器学习的常用方法。例如:贝叶斯,svm,knn等。 scikit-learn的官网 : http://sciki...

使用k8s部署Django项目的方法步骤

接触了一下docker和k8s,感觉是非常不错的东西。能够方便的部署线上环境,而且还能够更好的利用机器的资源,感觉是以后的大趋势。最近刚好有一个基于django的项目,所以就把这个项目打...

python读取各种文件数据方法解析

python读取各种文件数据方法解析

python读取.txt(.log)文件 、.xml 文件 、excel文件数据,并将数据类型转换为需要的类型,添加到list中详解 1.读取文本文件数据(.txt结尾的文件)或日志文件...