pytorch 预训练层的使用方法

yipeiwu_com5年前Python基础

pytorch 预训练层的使用方法

将其他地方训练好的网络,用到新的网络里面

加载预训练网络

1.原先已经训练好一个网络 AutoEncoder_FC()

2.首先加载该网络,读取其存储的参数

3.设置一个参数集

cnnpre = AutoEncoder_FC()
cnnpre.load_state_dict(torch.load('autoencoder_FC.pkl')['state_dict'])
cnnpre_dict =cnnpre.state_dict()

加载新网络

1.设置新的网络

2.设置新网络参数集

cnn= AutoEncoder()
cnn_dict = cnn.state_dict()

更新新网络参数

1.将两个参数集比对,存在的网络参数保留

2.使用保留下的参数更新新网络参数集

3.加载新网络参数集到新网络中

cnnpre_dict = {k: v for k, v in cnnpre_dict.items() if k in cnn_dict}
cnn_dict.update(cnnpre_dict)
cnn.load_state_dict(cnn_dict)

以上这篇pytorch 预训练层的使用方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Python引用(import)文件夹下的py文件的方法

Python引用(import)文件夹下的py文件的方法

Python的import包含文件功能就跟PHP的include类似,但更确切的说应该更像是PHP中的require,因为Python里的import只要目标不存在就报错程序无法往下执行...

python微信跳一跳系列之色块轮廓定位棋盘

python微信跳一跳系列之色块轮廓定位棋盘

在前几篇博文中,我们分别采用颜色识别,模板匹配,像素遍历等方法实现了棋子和棋盘的定位,具体内容可以参见我的前面的文章内容,在这一篇中,我们来探索一种定位棋盘的新方法。 分析 经过...

Python的Flask框架中使用Flask-SQLAlchemy管理数据库的教程

Python的Flask框架中使用Flask-SQLAlchemy管理数据库的教程

使用Flask-SQLAlchemy管理数据库 Flask-SQLAlchemy是一个Flask扩展,它简化了在Flask应用程序中对SQLAlchemy的使用。SQLAlchemy是一...

对Pytorch中Tensor的各种池化操作解析

AdaptiveAvgPool1d(N) 对一个C*H*W的三维输入Tensor, 池化输出为C*H*N, 即按照H轴逐行对W轴平均池化 >>> a = torch...

10分钟教你用python动画演示深度优先算法搜寻逃出迷宫的路径

10分钟教你用python动画演示深度优先算法搜寻逃出迷宫的路径

深度优先算法(DFS 算法)是什么? 寻找起始节点与目标节点之间路径的算法,常用于搜索逃出迷宫的路径。主要思想是,从入口开始,依次搜寻周围可能的节点坐标,但不会重复经过同一个节点,且不能...