pytorch forward两个参数实例

yipeiwu_com6年前Python基础

以channel Attention Block为例子

class CAB(nn.Module):
 
  def __init__(self, in_channels, out_channels):
    super(CAB, self).__init__()
    self.global_pooling = nn.AdaptiveAvgPool2d(output_size=1)
    self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
    self.relu = nn.ReLU()
    self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
    self.sigmod = nn.Sigmoid()
 
  def forward(self, x):
    x1, x2 = x # high, low
    x = torch.cat([x1,x2],dim=1)
    x = self.global_pooling(x)
    x = self.conv1(x)
    x = self.relu(x)
    x = self.conv2(x)
    x = self.sigmod(x)
    x2 = x * x2
    res = x2 + x1
    return res

以上这篇pytorch forward两个参数实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

python查看列的唯一值方法

查看某一列中有多少中取值: 数据集名.drop_duplicates(['列名']) #实际为删除重复项,删除后对原数据集不修改 输入:data.drop_duplicates(['na...

Python3实现带附件的定时发送邮件功能

本文实例为大家分享了Python3定时发送邮件功能的具体代码,供大家参考,具体内容如下 1、 导入模块 import os import datetime #定时发送,以及日期 i...

通过pycharm使用git的步骤(图文详解)

通过pycharm使用git的步骤(图文详解)

前言 使用git+pycharm有一段时间了,算是稍有点心得,这边整理一下,可能有的方法不是最优,欢迎交流,可能还是习惯敲命令去使用git,不过其实pycharm已经帮忙做了很多了,我们...

Python缓存技术实现过程详解

一段非常简单代码 普通调用方式 def console1(a, b): print("进入函数") return (a, b) print(console1(3, 'a...

Python嵌套式数据结构实例浅析

本文实例讲述了Python嵌套式数据结构。分享给大家供大家参考,具体如下: 嵌套式数据结构指的是:字典存储在列表中, 或者列表作为值存储在字典中。甚至还可以在字典中嵌套字典。 1 字典列...