关于Pytorch的MLP模块实现方式

yipeiwu_com6年前Python基础

MLP分类效果一般好于线性分类器,即将特征输入MLP中再经过softmax来进行分类。

具体实现为将原先线性分类模块:

self.classifier = nn.Linear(config.hidden_size, num_labels)

替换为:

self.classifier = MLP(config.hidden_size, num_labels)

并且添加MLP模块:

  class MLP(nn.Module):
    def __init__(self, input_size, common_size):
      super(MLP, self).__init__()
      self.linear = nn.Sequential(
        nn.Linear(input_size, input_size // 2),
        nn.ReLU(inplace=True),
        nn.Linear(input_size // 2, input_size // 4),
        nn.ReLU(inplace=True),
        nn.Linear(input_size // 4, common_size)
      )
 
    def forward(self, x):
      out = self.linear(x)
      return out

看一下模块结构:

mlp = MLP(1000,3)
print(mlp)

以上这篇关于Pytorch的MLP模块实现方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

python 数据清洗之数据合并、转换、过滤、排序

python 数据清洗之数据合并、转换、过滤、排序

前面我们用pandas做了一些基本的操作,接下来进一步了解数据的操作, 数据清洗一直是数据分析中极为重要的一个环节。 数据合并 在pandas中可以通过merge对数据进行合并操作。...

Python OpenCV中的resize()函数的使用

Python OpenCV中的resize()函数的使用

改变图像大小意味着改变尺寸,无论是单独的高或宽,还是两者。也可以按比例调整图像大小。 这里将介绍resize()函数的语法及实例。 语法 函数原型 cv2.resize(src, d...

Python 实现异步调用函数的示例讲解

async_call.py #coding:utf-8 from threading import Thread def async_call(fn): def wrapper...

Pytorch中的VGG实现修改最后一层FC

https://discuss.pytorch.org/t/how-to-modify-the-final-fc-layer-based-on-the-torch-model/766/1...

详解Python列表赋值复制深拷贝及5种浅拷贝

概述 在列表复制这个问题,看似简单的复制却有着许多的学问,尤其是对新手来说,理所当然的事情却并不如意,比如列表的赋值、复制、浅拷贝、深拷贝等绕口的名词到底有什么区别和作用呢? 列表赋值...