关于Pytorch的MLP模块实现方式

yipeiwu_com6年前Python基础

MLP分类效果一般好于线性分类器,即将特征输入MLP中再经过softmax来进行分类。

具体实现为将原先线性分类模块:

self.classifier = nn.Linear(config.hidden_size, num_labels)

替换为:

self.classifier = MLP(config.hidden_size, num_labels)

并且添加MLP模块:

  class MLP(nn.Module):
    def __init__(self, input_size, common_size):
      super(MLP, self).__init__()
      self.linear = nn.Sequential(
        nn.Linear(input_size, input_size // 2),
        nn.ReLU(inplace=True),
        nn.Linear(input_size // 2, input_size // 4),
        nn.ReLU(inplace=True),
        nn.Linear(input_size // 4, common_size)
      )
 
    def forward(self, x):
      out = self.linear(x)
      return out

看一下模块结构:

mlp = MLP(1000,3)
print(mlp)

以上这篇关于Pytorch的MLP模块实现方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Python安装tar.gz格式文件方法详解

Python安装tar.gz格式文件方法详解

这篇文章主要介绍了Python安装tar.gz格式文件方法详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下 有的库没有找到对应的.w...

django orm 通过related_name反向查询的方法

如下所示: class level(models.Model): l_name = models.CharField(max_length=50,verbose_name="等级名...

使用 Python 处理3万多条数据只要几秒钟

使用 Python 处理3万多条数据只要几秒钟

应用场景:工作中经常遇到大量的数据需要整合、去重、按照特定格式导出等情况。如果用 Excel 操作,不仅费时费力,还不准确,有么有更高效的解决方案呢? 本文以17个 txt 文本,3万多...

python 转换 Javascript %u 字符串为python unicode的代码

web采集的数据为 %u6B63%u5F0F%u4EBA%u5458,需要读取并转换为python对象,想了下不调用Javascript去eval,只能自己翻译了。 核心代码: i...

celery4+django2定时任务的实现代码

网上有很多celery + django实现定时任务的教程,不过它们大多数是基于djcelery + celery3的; 或者是使用django_celery_beat配置较为繁琐的。...