PyTorch加载预训练模型实例(pretrained)

yipeiwu_com6年前Python基础

使用预训练模型的代码如下:

# 加载预训练模型
 resNet50 = models.resnet50(pretrained=True)
 ResNet50 = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=2)

 # 读取参数
 pretrained_dict = resNet50.state_dict()
 model_dict = ResNet50.state_dict()

 # 将pretained_dict里不属于model_dict的键剔除掉
 pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}

 # 更新现有的model_dict
 model_dict.update(pretrained_dict)

 # 加载真正需要的state_dict
 ResNet50.load_state_dict(model_dict)

以上这篇PyTorch加载预训练模型实例(pretrained)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Python 中 Meta Classes详解

接触过 Django 的同学都应该十分熟悉它的 ORM 系统。对于 python 新手而言,这是一项几乎可以被称作“黑科技”的特性:只要你在models.py中随便定义一个Model的子...

python 如何将数据写入本地txt文本文件的实现方法

一、读写txt文件 1、打开txt文件 file_handle=open('1.txt',mode='w') 上述函数参数有(1.文件名,mode模式) mode模式有以下几种...

如何用Python来理一理红楼梦里的那些关系

如何用Python来理一理红楼梦里的那些关系

前言 今天,一起用 Python 来理一理红楼梦里的那些关系 不要问我为啥是红楼梦,而不是水浒三国或西游,因为我也鉴定的认为,红楼才是无可争议的中国古典小说只巅峰,且不接受反驳!而红楼...

python logging日志模块的详解

python logging日志模块的详解 日志级别 日志一共分成5个等级,从低到高分别是:DEBUG INFO WARNING ERROR CRITICAL。 DEBUG:详细的信...

Django Celery异步任务队列的实现

背景 在开发中,我们常常会遇到一些耗时任务,举个例子: 上传并解析一个 1w 条数据的 Excel 文件,最后持久化至数据库。 在我的程序中,这个任务耗时大约 6s,对于用户来说,...