基于pytorch的保存和加载模型参数的方法

yipeiwu_com5年前Python基础

当我们花费大量的精力训练完网络,下次预测数据时不想再(有时也不必再)训练一次时,这时候torch.save(),torch.load()就要登场了。

保存和加载模型参数有两种方式:

方式一:

torch.save(net.state_dict(),path):

功能:保存训练完的网络的各层参数(即weights和bias)

其中:net.state_dict()获取各层参数,path是文件存放路径(通常保存文件格式为.pt或.pth)

net2.load_state_dict(torch.load(path)):

功能:加载保存到path中的各层参数到神经网络

注意:不可以直接为torch.load_state_dict(path),此函数不能直接接收字符串类型参数

方式二:

torch.save(net,path):

功能:保存训练完的整个网络模型(不止weights和bias)

net2=torch.load(path):

功能:加载保存到path中的整个神经网络

说明:官方推荐方式一,原因自然是保存的内容少,速度会更快。

以上这篇基于pytorch的保存和加载模型参数的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

python 定义给定初值或长度的list方法

python 定义给定初值或长度的list方法

1. 给定初值v,和长度l,定义list s 或者: 2. 产生一个数值递增list 2.1 从0开始以1递增 2.2 在[a,b)区间上以1递增 2.3 在[a,b)区间上以c...

Python获取命令实时输出-原样彩色输出并返回输出结果的示例

经试验显示效果不错。 #!/usr/bin/python3 # -*- coding: utf-8 -*- import os import subprocess # 与在命令窗...

python常用数据重复项处理方法

python常用数据重复项处理方法

在数据的处理过程中,一般都需要进行数据清洗工作,如数据集是否存在重复,是否存在缺失,数据是否具有完整性和一致性,数据中是否存在异常值等.发现诸如此类的问题都需要针对性地处理,下面我们一起...

Python中常用信号signal类型实例

本文研究的主要是Python中的Signal 信号的相关内容,具体如下。 常用信号类型 SIGINT 终止进程 中断进程,不可通过signal.signal()捕捉(相当于Ctrl...

PySide和PyQt加载ui文件的两种方法

本文实例为大家分享了PySide和PyQt加载ui文件的具体实现代码,供大家参考,具体内容如下 在用PySide或PyQt的时候,经常用到要将画好的ui文件导入到代码里使用,下面是两种调...