基于pytorch的保存和加载模型参数的方法

yipeiwu_com6年前Python基础

当我们花费大量的精力训练完网络,下次预测数据时不想再(有时也不必再)训练一次时,这时候torch.save(),torch.load()就要登场了。

保存和加载模型参数有两种方式:

方式一:

torch.save(net.state_dict(),path):

功能:保存训练完的网络的各层参数(即weights和bias)

其中:net.state_dict()获取各层参数,path是文件存放路径(通常保存文件格式为.pt或.pth)

net2.load_state_dict(torch.load(path)):

功能:加载保存到path中的各层参数到神经网络

注意:不可以直接为torch.load_state_dict(path),此函数不能直接接收字符串类型参数

方式二:

torch.save(net,path):

功能:保存训练完的整个网络模型(不止weights和bias)

net2=torch.load(path):

功能:加载保存到path中的整个神经网络

说明:官方推荐方式一,原因自然是保存的内容少,速度会更快。

以上这篇基于pytorch的保存和加载模型参数的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Python切片工具pillow用法示例

本文实例讲述了Python切片工具pillow用法。分享给大家供大家参考,具体如下: 切片:使用切片将源图像分成许多的功能区域 因为要对图片进行切片裁剪,所以用到切片工具必不可少,在ub...

使用Python压缩和解压缩zip文件的教程

python 的 zipfile 提供了非常便捷的方法来压缩和解压 zip 文件。 例如,在py脚本所在目录中,有如下文件: 复制代码 代码如下:readability/readabil...

修改Python的pyxmpp2中的主循环使其提高性能

引子 之前clubot使用的pyxmpp2的默认mainloop也就是一个poll的主循环,但是clubot上线后资源占用非常厉害,使用strace跟踪发现clubot在不停的poll,...

Python开发微信公众平台的方法详解【基于weixin-knife】

本文实例讲述了Python开发微信公众平台的方法。分享给大家供大家参考,具体如下: 这两天将之前基于微信公众平台的代码重构了下,基础功能以库的方式提供,提供了demo使用的是django...

Python实现合并同一个文件夹下所有PDF文件的方法示例

Python实现合并同一个文件夹下所有PDF文件的方法示例

本文实例讲述了Python实现合并同一个文件夹下所有PDF文件的方法。分享给大家供大家参考,具体如下: 一、需求说明 下载了网易云课堂的吴恩达免费的深度学习的pdf文档,但是每一节是一个...