pytorch 实现在预训练模型的 input上增减通道

yipeiwu_com6年前Python基础

如何把imagenet预训练的模型,输入层的通道数随心所欲的修改,从而来适应自己的任务

#增加一个通道
w = layers[0].weight
layers[0] = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
layers[0].weight = torch.nn.Parameter(torch.cat((w, w[:, :1, :, :]), dim=1))
 
#方式2
w = layers[0].weight
layers[0] = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
layers[0].weight = torch.nn.Parameter(torch.cat((w, torch.zeros(64, 1, 7, 7)), dim=1))
 
 
#单通道输入
layers[0] = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
layers[0].weight = torch.nn.Parameter(w[:, :1, :, :])

以上这篇pytorch 实现在预训练模型的 input上增减通道就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

python关键字and和or用法实例

python 中的and从左到右计算表达式,若所有值均为真,则返回最后一个值,若存在假,返回第一个假值。 or也是从左到有计算表达式,返回第一个为真的值。 复制代码 代码如下: IDLE...

对python 生成拼接xml报文的示例详解

最近临时工作要生成xml报名,通过MQ接口发送。简单小程序。 自增长拼成xml报文 Test_001.py # encoding=utf-8 import time orderI...

Python 3 实现定义跨模块的全局变量和使用教程

尽管某些书籍上总是说避免使用全局变量,但是在实际的需求不断变化中,往往定义一个全局变量是最可靠的方法,但是又必须要避免变量名覆盖。 Python 中 global 关键字可以定义一个变量...

python base64库给用户名或密码加密的流程

给明文密码加密的流程: import base64 pwd_after_encrypt = base64.b64encode(b'this is a scret!') pwd_bef...

Python3.7 dataclass使用指南小结

dataclass简介 dataclass的定义位于PEP-557,根据定义一个dataclass是指“一个带有默认值的可变的namedtuple”,广义的定义就是有一个类,它的属性均可...