pytorch 预训练层的使用方法

yipeiwu_com6年前Python基础

pytorch 预训练层的使用方法

将其他地方训练好的网络,用到新的网络里面

加载预训练网络

1.原先已经训练好一个网络 AutoEncoder_FC()

2.首先加载该网络,读取其存储的参数

3.设置一个参数集

cnnpre = AutoEncoder_FC()
cnnpre.load_state_dict(torch.load('autoencoder_FC.pkl')['state_dict'])
cnnpre_dict =cnnpre.state_dict()

加载新网络

1.设置新的网络

2.设置新网络参数集

cnn= AutoEncoder()
cnn_dict = cnn.state_dict()

更新新网络参数

1.将两个参数集比对,存在的网络参数保留

2.使用保留下的参数更新新网络参数集

3.加载新网络参数集到新网络中

cnnpre_dict = {k: v for k, v in cnnpre_dict.items() if k in cnn_dict}
cnn_dict.update(cnnpre_dict)
cnn.load_state_dict(cnn_dict)

以上这篇pytorch 预训练层的使用方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Python迭代器模块itertools使用原理解析

这篇文章主要介绍了Python迭代器模块itertools使用原理解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下 介绍 今天介绍...

python网页请求urllib2模块简单封装代码

对python网页请求模块urllib2进行简单的封装。 例子: 复制代码 代码如下:#!/usr/bin/python#coding: utf-8import base64import...

Python使用type关键字创建类步骤详解

Python使用type关键字创建类步骤详解

Python使用type关键字创建类 打开命令行窗口,输入python,进入python交互环境 python 一般创建类使用class关键字即可,测试命令如下: class Co...

Python StringIO模块实现在内存缓冲区中读写数据

模块是用类编写的,只有一个StringIO类,所以它的可用方法都在类中。 此类中的大部分函数都与对文件的操作方法类似。 例: 复制代码 代码如下: #coding=gbk  ...

Python列表和元组的定义与使用操作示例

Python列表和元组的定义与使用操作示例

本文实例讲述了Python列表和元组的定义与使用操作。分享给大家供大家参考,具体如下: #coding=utf8 print ''''' 可以将列表和元组当成普通的“数组”,它能保存...