基于pytorch的保存和加载模型参数的方法

yipeiwu_com6年前Python基础

当我们花费大量的精力训练完网络,下次预测数据时不想再(有时也不必再)训练一次时,这时候torch.save(),torch.load()就要登场了。

保存和加载模型参数有两种方式:

方式一:

torch.save(net.state_dict(),path):

功能:保存训练完的网络的各层参数(即weights和bias)

其中:net.state_dict()获取各层参数,path是文件存放路径(通常保存文件格式为.pt或.pth)

net2.load_state_dict(torch.load(path)):

功能:加载保存到path中的各层参数到神经网络

注意:不可以直接为torch.load_state_dict(path),此函数不能直接接收字符串类型参数

方式二:

torch.save(net,path):

功能:保存训练完的整个网络模型(不止weights和bias)

net2=torch.load(path):

功能:加载保存到path中的整个神经网络

说明:官方推荐方式一,原因自然是保存的内容少,速度会更快。

以上这篇基于pytorch的保存和加载模型参数的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Flask模板引擎之Jinja2语法介绍

Jinja是组成Flask的模板引擎。可能你还不太了解它是干嘛的,但你对下面这些百分号和大括号肯定不陌生: {% block body %} <ul> {% for...

详解python3中socket套接字的编码问题解决

一、TCP 1、tcp服务器创建 #创建服务器 from socket import * from time import ctime #导入ctime HOST = ''...

跟老齐学Python之关于循环的小伎俩

不是说while就不用,比如前面所列举而得那个猜数字游戏,在业务逻辑上,用while就更容易理解(当然是限于那个游戏的业务需要而言)。另外,在某些情况下,for也不是简单地把对象中的元素...

对python数据切割归并算法的实例讲解

当一个 .txt 文件的数据过于庞大,此时想要对数据进行排序就需要先将数据进行切割,然后通过归并排序,最终实现对整体数据的排序。要实现这个过程我们需要进行以下几步:获取总数据行数;根据行...

使用Python制作自动推送微信消息提醒的备忘录功能

使用Python制作自动推送微信消息提醒的备忘录功能

日常工作生活中,事情一多,就会忘记一些该做未做的事情。即使有时候把事情记录在了小本本上或者手机、电脑端备忘录上,也总会有查看不及时,导致错过的尴尬。如果有一款小工具,可以及时提醒,而不用...