解决Pytorch 加载训练好的模型 遇到的error问题

yipeiwu_com6年前Python基础

这是一个非常愚蠢的错误

debug的时候要好好看error信息

提醒自己切记好好对待error!切记!切记!

-----------------------分割线----------------

pytorch 已经非常友好了 保存模型和加载模型都只需要一条简单的命令

#保存整个网络和参数
torch.save(your_net, 'save_name.pkl')
#加载保存的模型
net = torch.load('save_name.pkl')

因为我比较懒我就想直接把整个网络都保存下来,然后在test文件中直接load一下不就好了?

就遭受了这样的错误。看错了error信息,把‘Net'看成‘net'。报错没有属性‘net'?这个不是我自己写的变量名么?

-----------------瞎捣鼓1h后(呵呵呵)----------------

回头看error,没有属性‘Net',Net???

我当下明白过来,应该是test文件中没有把它import进来,test中就没有任何关于Net的信息。我直接把定义的Net复制进了test.py,就顺利加载了训练好的模型。

但是我也有一个疑问,我理解的把整个模型保存难道不是把它的结构都保存下来了么?为什么还要再把这个网络import一次?来自python、pytorch、面向对象编程三次元小白的疑惑,先存个疑,搞懂了再来回答。

接下来试试只保存网络参数

#只保存网络参数
torch.save(your_net.state_dict(), 'save_name.pkl')
#加载保存的模型
net.load_state_dict(torch.load('save_name.pkl'))

保存网络参数

重新定义网络

报错

想死。。。

仔细看了报错信息,以我小白的理解,我感觉保存下来的可能只是单纯的数据,而不是一个对象(没有方法可以操作),或者该对象没有.copy()方法,所以没有办法进行.copy(),那肯定是保存哪里出错了。然后发现保存部分代码写错了,改成

print一下 net.state_dict和net.state_dict(),前者输出的是网络结构,后者才是网络的参数。

试着回答之前的问题,第二种保存模型的方法只保存了网络的参数(包括卷积层和全连接层每次的weight,bias),所以再加载模型的时候需要先定义网络无可厚非,就像训练时候定义网络那样定义就可以;而第一种保存整个网络的方法,保存了一个网络的实例(包括它的所有结构和参数),net是Net的一个实例,那为什么还要有Class Net的定义呢,还是回答不了。。

那就继续存疑,保持探究精神吧。。

以上这篇解决Pytorch 加载训练好的模型 遇到的error问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

对Python3之方法的覆盖与super函数详解

对Python3之方法的覆盖与super函数详解

#覆盖 覆盖:在继承关系中,子类实现了与基类同名的方法,在子类的实例调用该方法时,实例调用的是子类的覆盖版本。 通俗的讲,就是小明继承了他⑧的自行车,经过自己的改装,成了电动车,那么小明...

python+openCV利用摄像头实现人员活动检测

python+openCV利用摄像头实现人员活动检测

本文实例为大家分享了python+openCV利用摄像头实现人员活动检测的具体代码,供大家参考,具体内容如下 1.前言 最近在做个机器人比赛,其中一项要求是让机器人实现对是否有人员活动的...

python daemon守护进程实现

python daemon守护进程实现

假如写一段服务端程序,如果ctrl+c退出或者关闭终端,那么服务端程序就会退出,于是就想着让这个程序成为守护进程,像httpd一样,一直在后端运行,不会受终端影响。 守护进程英文为dae...

Python的Django框架中forms表单类的使用方法详解

Python的Django框架中forms表单类的使用方法详解

Form表单的功能 自动生成HTML表单元素 检查表单数据的合法性 如果验证错误,重新显示表单(数据不会重置) 数据类型转换(字符类型的数据转换成相应的Python类型...

python 画三维图像 曲面图和散点图的示例

用python画图很多是根据z=f(x,y)来画图的,本博文将三个对应的坐标点输入画图: 散点图: import matplotlib.pyplot as plt from mpl_...