PyTorch加载预训练模型实例(pretrained)

yipeiwu_com5年前Python基础

使用预训练模型的代码如下:

# 加载预训练模型
 resNet50 = models.resnet50(pretrained=True)
 ResNet50 = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=2)

 # 读取参数
 pretrained_dict = resNet50.state_dict()
 model_dict = ResNet50.state_dict()

 # 将pretained_dict里不属于model_dict的键剔除掉
 pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}

 # 更新现有的model_dict
 model_dict.update(pretrained_dict)

 # 加载真正需要的state_dict
 ResNet50.load_state_dict(model_dict)

以上这篇PyTorch加载预训练模型实例(pretrained)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Python字符遍历的艺术

比如,将一个字符串转换为一个字符数组: theList = list(theString) 同时,我们可以方便的通过for语句进行遍历: for c in theString: do_s...

Python中实现两个字典(dict)合并的方法

本文实例讲述了Python中实现两个字典(dict)合并的方法,分享给大家供大家参考。具体方法如下: 现有两个字典dict如下: dict1={1:[1,11,111],2:[2,2...

numpy.delete删除一列或多列的方法

基础介绍: numpy.delete numpy.delete(arr, obj, axis=None)[source] Return a new array with sub-a...

python调用百度REST API实现语音识别

目前,语音识别,即将语音内容转换为文字的技术已经比较成熟,遥想当时锤子发布会上展示的讯飞输入法语音识别,着实让讯飞火了一把。由于此类语音识别需要采集大量的样本,才能达到一定的准确度,个人...

python导入模块交叉引用的方法

实际项目中遇到python模块相互引用问题,查资料,终于算是弄明白了。 首先交叉引用或是相互引用,实际上就是导入循环,关于导入循环的详细说明,可见我摘自《python核心编程》第二版的摘...