PyTorch加载预训练模型实例(pretrained)

yipeiwu_com6年前Python基础

使用预训练模型的代码如下:

# 加载预训练模型
 resNet50 = models.resnet50(pretrained=True)
 ResNet50 = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=2)

 # 读取参数
 pretrained_dict = resNet50.state_dict()
 model_dict = ResNet50.state_dict()

 # 将pretained_dict里不属于model_dict的键剔除掉
 pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}

 # 更新现有的model_dict
 model_dict.update(pretrained_dict)

 # 加载真正需要的state_dict
 ResNet50.load_state_dict(model_dict)

以上这篇PyTorch加载预训练模型实例(pretrained)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

对Python中DataFrame按照行遍历的方法

在做分类模型时候,需要在DataFrame中按照行获取数据以便于进行训练和测试。 import pandas as pd dict=[[1,2,3,4,5,6],[2,3,4,5,6...

python之信息加密题目详解

1.贴题 题目来自PythonTip 信息加密 给你个小写英文字符串a和一个非负数b(0<=b<26), 将a中的每个小写字符替换成字母表中比它大b的字母。这里将字母表...

Python解决两个整数相除只得到整数部分的实例

在python中进行两个整数相除的时候,在默认情况下都是只能够得到整数的值 解决方法: 1. 修改被除数的值为带小数点的形式即可得到浮点值 2.在文件头部引入 from __futu...

python 时间戳与格式化时间的转化实现代码

python 里面与时间有关的模块主要是 time 和 datetime 如果想获取系统当前时间戳:time.time() ,是一个float型的数据 获取系统当前的时间信息 : tim...

转换科学计数法的数值字符串为decimal类型的方法

在操作数据库时,需要将字符串转换成decimal类型。 代码如下: select cast('0.12' as decimal(18,2)); select convert(dec...