pytorch 实现在预训练模型的 input上增减通道

yipeiwu_com6年前Python基础

如何把imagenet预训练的模型,输入层的通道数随心所欲的修改,从而来适应自己的任务

#增加一个通道
w = layers[0].weight
layers[0] = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
layers[0].weight = torch.nn.Parameter(torch.cat((w, w[:, :1, :, :]), dim=1))
 
#方式2
w = layers[0].weight
layers[0] = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
layers[0].weight = torch.nn.Parameter(torch.cat((w, torch.zeros(64, 1, 7, 7)), dim=1))
 
 
#单通道输入
layers[0] = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
layers[0].weight = torch.nn.Parameter(w[:, :1, :, :])

以上这篇pytorch 实现在预训练模型的 input上增减通道就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

在Python文件中指定Python解释器的方法

以下针对Ubuntu系统,Windows系统没有测试过。 Ubuntu中默认就安装有Python 2.x和Python 3.x,默认情况下python命令指的是Python 2.x。因此...

django自定义模板标签过程解析

django自定义模板标签过程解析

这篇文章主要介绍了django自定义模板标签过程解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下 代码布局 自定义模板标签必须位于...

Python中property属性实例解析

本文主要讲述的是对Python中property属性(特性)的理解,具体如下。 定义及作用: 在property类中,有三个成员方法和三个装饰器函数。 三个成员方法分别是:fget、f...

Python字符串格式化

在许多编程语言中都包含有格式化字符串的功能,比如C和Fortran语言中的格式化输入输出。Python中内置有对字符串进行格式化的操作%。 模板 格式化字符串时,Python使用一个字符...

python获取代理IP的实例分享

平时当我们需要爬取一些我们需要的数据时,总是有些网站禁止同一IP重复访问,这时候我们就应该使用代理IP,每次访问前伪装自己,让“敌人”无法察觉。 oooooooooooooooOK,让我...