pytorch forward两个参数实例

yipeiwu_com5年前Python基础

以channel Attention Block为例子

class CAB(nn.Module):
 
  def __init__(self, in_channels, out_channels):
    super(CAB, self).__init__()
    self.global_pooling = nn.AdaptiveAvgPool2d(output_size=1)
    self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
    self.relu = nn.ReLU()
    self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
    self.sigmod = nn.Sigmoid()
 
  def forward(self, x):
    x1, x2 = x # high, low
    x = torch.cat([x1,x2],dim=1)
    x = self.global_pooling(x)
    x = self.conv1(x)
    x = self.relu(x)
    x = self.conv2(x)
    x = self.sigmod(x)
    x2 = x * x2
    res = x2 + x1
    return res

以上这篇pytorch forward两个参数实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Anaconda多环境多版本python配置操作方法

conda测试指南 在开始这个conda测试之前,你应该已经下载并安装好了Anaconda或者Miniconda 注意:在安装之后,你应该关闭并重新打开windows命令行。 一、Con...

pandas系列之DataFrame 行列数据筛选实例

pandas系列之DataFrame 行列数据筛选实例

一、对DataFrame的认知 DataFrame的本质是行(index)列(column)索引+多列数据。 为了简化理解,我们不妨换个思路… 现实中,为了简化对一件事物的描述,我们会...

python将txt文档每行内容循环插入数据库的方法

如下所示: import pymysql import time import re def get_raw_label(rece): re1 = r'"([\s\S]*?...

windows下python之mysqldb模块安装方法

windows下python之mysqldb模块安装方法

之所以会写下这篇日志,是因为安装的过程有点虐心。目前这篇文章是针对windows操作系统上的mysqldb的安装。安装python的mysqldb模块,首先当然是找一些官方的网站去下载:...

python实现读取大文件并逐行写入另外一个文件

<pre name="code" class="python">creazy.txt文件有4G,逐行读取其内容并写入monday.txt文件里。 def crea...