关于Pytorch的MLP模块实现方式

yipeiwu_com6年前Python基础

MLP分类效果一般好于线性分类器,即将特征输入MLP中再经过softmax来进行分类。

具体实现为将原先线性分类模块:

self.classifier = nn.Linear(config.hidden_size, num_labels)

替换为:

self.classifier = MLP(config.hidden_size, num_labels)

并且添加MLP模块:

  class MLP(nn.Module):
    def __init__(self, input_size, common_size):
      super(MLP, self).__init__()
      self.linear = nn.Sequential(
        nn.Linear(input_size, input_size // 2),
        nn.ReLU(inplace=True),
        nn.Linear(input_size // 2, input_size // 4),
        nn.ReLU(inplace=True),
        nn.Linear(input_size // 4, common_size)
      )
 
    def forward(self, x):
      out = self.linear(x)
      return out

看一下模块结构:

mlp = MLP(1000,3)
print(mlp)

以上这篇关于Pytorch的MLP模块实现方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Python 基于Twisted框架的文件夹网络传输源码

Python 基于Twisted框架的文件夹网络传输源码

由于文件夹可能有多层目录,因此需要对其进行递归遍历。 本文采取了简单的协议定制,定义了五条命令,指令Head如下: Sync:标识开始同步文件夹 End:标识结束同步 File:标识传输...

Python中pow()和math.pow()函数用法示例

本文实例讲述了Python中pow()和math.pow()函数用法。分享给大家供大家参考,具体如下: 1. 内置函数pow() >>> help(pow) Hel...

对python 各种删除文件失败的处理方式分享

调用python提供的各种删除文件的操作均失败 返回值5,拒绝访问,但是多次确认文件没有被打开,文件是从一个zip包中解压出来后,没有任何打开读写等操作 最后调用windows的强制删除...

pyspark 读取csv文件创建DataFrame的两种方法

方法一:用pandas辅助 from pyspark import SparkContext from pyspark.sql import SQLContext import...

Python多线程和队列操作实例

Python3,开一个线程,间隔1秒把一个递增的数字写入队列,再开一个线程,从队列中取出数字并打印到终端 复制代码 代码如下: #! /usr/bin/env python3 impor...