pytorch 实现在预训练模型的 input上增减通道

yipeiwu_com6年前Python基础

如何把imagenet预训练的模型,输入层的通道数随心所欲的修改,从而来适应自己的任务

#增加一个通道
w = layers[0].weight
layers[0] = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
layers[0].weight = torch.nn.Parameter(torch.cat((w, w[:, :1, :, :]), dim=1))
 
#方式2
w = layers[0].weight
layers[0] = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
layers[0].weight = torch.nn.Parameter(torch.cat((w, torch.zeros(64, 1, 7, 7)), dim=1))
 
 
#单通道输入
layers[0] = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
layers[0].weight = torch.nn.Parameter(w[:, :1, :, :])

以上这篇pytorch 实现在预训练模型的 input上增减通道就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

python斐波那契数列的计算方法

题目: 计算斐波那契数列。具体什么是斐波那契数列,那就是0,1,1,2,3,5,8,13,21,34,55,89,144,233。 要求: 时间复杂度尽可能少 分析: 给出了...

浅谈Series和DataFrame中的sort_index方法

Series 的 sort_index(ascending=True) 方法可以对 index 进行排序操作,ascending 参数用于控制升序或降序,默认为升序。 若要按值对 Ser...

WIn10+Anaconda环境下安装PyTorch(避坑指南)

WIn10+Anaconda环境下安装PyTorch(避坑指南)

这些天安装 PyTorch,遇到了一些坑,特此总结一下,以免忘记。分享给大家。 首先,安装环境是:操作系统 Win10,已经预先暗转了 Anaconda。 1. 为 PyTorch 创建...

python实现2048小游戏

python实现2048小游戏

2048的python实现。修改自某网友的代码,解决了原网友版本的两个小bug: 1. 原版游戏每次只消除一次,而不是递归消除。如 [2 ,2 ,2 ,2] 左移动的话应该是 [4,...

python实现12306火车票查询器

python实现12306火车票查询器

12306火车票购票软件大家都用过,怎么用Python写一个命令行的火车票查看器,要求在命令行敲一行命令来获得你想要的火车票信息,下面通过本文学习吧。 Python火车票查询器 接口...