pytorch 实现在预训练模型的 input上增减通道

yipeiwu_com5年前Python基础

如何把imagenet预训练的模型,输入层的通道数随心所欲的修改,从而来适应自己的任务

#增加一个通道
w = layers[0].weight
layers[0] = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
layers[0].weight = torch.nn.Parameter(torch.cat((w, w[:, :1, :, :]), dim=1))
 
#方式2
w = layers[0].weight
layers[0] = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
layers[0].weight = torch.nn.Parameter(torch.cat((w, torch.zeros(64, 1, 7, 7)), dim=1))
 
 
#单通道输入
layers[0] = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
layers[0].weight = torch.nn.Parameter(w[:, :1, :, :])

以上这篇pytorch 实现在预训练模型的 input上增减通道就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Python使用QRCode模块生成二维码实例详解

Python使用QRCode模块生成二维码 QRCode官网 https://pypi.python.org/pypi/qrcode/5.1 简介 python-qrcode是个用来...

OpenCV实现人脸识别

主要有以下步骤: 1、人脸检测 2、人脸预处理 3、从收集的人脸训练机器学习算法 4、人脸识别 5、收尾工作 人脸检测算法: 基于Haar的脸部检测器的基本思想是,对于面部正面大部分区域...

python解析中国天气网的天气数据

使用方法:terminal中输入复制代码 代码如下:python weather.py http://www.weather.com.cn/weather/101010100.shtml...

Python3 修改默认环境的方法

Mac 环境中既有自带的 Python2.7 也有自己安装的 Python 3.5.1,默认想用 Python3 的环境 1. 添加 Python3 的环境变量 vi ~/.bash...

关于pycharm中pip版本10.0无法使用的解决办法

关于pycharm中pip版本10.0无法使用的解决办法

一、背景: 近期在利用 pycharm 安装第三方库时会提示 pip 不是最新版本, 因此对 pip 进行更新,但是生成最新版本之后, pip 中由于缺少 main 函...