PyTorch加载预训练模型实例(pretrained)

yipeiwu_com6年前Python基础

使用预训练模型的代码如下:

# 加载预训练模型
 resNet50 = models.resnet50(pretrained=True)
 ResNet50 = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=2)

 # 读取参数
 pretrained_dict = resNet50.state_dict()
 model_dict = ResNet50.state_dict()

 # 将pretained_dict里不属于model_dict的键剔除掉
 pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}

 # 更新现有的model_dict
 model_dict.update(pretrained_dict)

 # 加载真正需要的state_dict
 ResNet50.load_state_dict(model_dict)

以上这篇PyTorch加载预训练模型实例(pretrained)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

python实现文件批量编码转换及注意事项

起因:大三做日本交换生期间在修一门C语言图像处理的编程课,在配套书籍的网站上下载了sample,但是由于我用的ubuntu18.04系统默认用utf-8编码,而文件源码是Shift_JI...

Python数据类型之Tuple元组实例详解

本文实例讲述了Python数据类型之Tuple元组。分享给大家供大家参考,具体如下: tuple元组 1.概述 本质上是一种有序的集合,和列表非常的相似,列表使用[]表示,元组使用()表...

Python实现的微信公众号群发图片与文本消息功能实例详解

本文实例讲述了Python实现的微信公众号群发图片与文本消息功能。分享给大家供大家参考,具体如下: 在微信公众号开发中,使用api都要附加access_token内容。因此,首先需要获取...

Django中使用session保持用户登陆连接的例子

使用session保持用户登陆连接 在 view 中 login() 视图函数里增加如下语句 不允许重复登录语句 if request.session.get('is_login',...

Python引用模块和查找模块路径

模块间相互独立相互引用是任何一种编程语言的基础能力。对于“模块”这个词在各种编程语言中或许是不同的,但我们可以简单认为一个程序文件是一个模块,文件里包含了类或者方法的定义。对于编译型的语...