pytorch加载自定义网络权重的实现

yipeiwu_com6年前Python基础

在将自定义的网络权重加载到网络中时,报错:

AttributeError: 'dict' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead.

我们一步一步分析。

模型网络权重保存额代码是:torch.save(net.state_dict(),'net.pkl')

(1)查看获取模型权重的源码:

pytorch源码:net.state_dict()

def state_dict(self, destination=None, prefix='', keep_vars=False):
  r"""Returns a dictionary containing a whole state of the module.

  Both parameters and persistent buffers (e.g. running averages) are
  included. Keys are corresponding parameter and buffer names.

  Returns:
    dict:
      a dictionary containing a whole state of the module

  Example::

    >>> module.state_dict().keys()
    ['bias', 'weight']

  """

将网络中所有的状态保存到一个字典中了,我自己构建的就是一个字典,没问题!

(2)查看保存模型权重的源码:

pytorch源码:torch.save()

def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
  """Saves an object to a disk file.

  See also: :ref:`recommend-saving-models`

  Args:
    obj: saved object
    f: a file-like object (has to implement write and flush) or a string
      containing a file name
    pickle_module: module used for pickling metadata and objects
    pickle_protocol: can be specified to override the default protocol

  .. warning::
    If you are using Python 2, torch.save does NOT support StringIO.StringIO
    as a valid file-like object. This is because the write method should return
    the number of bytes written; StringIO.write() does not do this.

    Please use something like io.BytesIO instead.

函数功能是将字典保存为磁盘文件(二进制数据),那么我们在torch.load()时,就是在内存中加载二进制数据,这就是报错点。

解决方案:将字典保存为BytesIO文件之后,模型再net.load_state_dict()

#b为自定义的字典
torch.save(b,'new.pkl')
net.load_state_dict(torch.load(b))

解决方法很简单,主要记录解决思路。

以上这篇pytorch加载自定义网络权重的实现就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

python基础教程之类class定义使用方法

面对对象(oop)中的对象,是一个非常重要的知识点,我们可以把它简单看做是数据以及由存取、操作这些数据的方法所组成的一个集合。我们在学习函数(function)之后,知道了如果重用代码,...

Python图像处理实现两幅图像合成一幅图像的方法【测试可用】

Python图像处理实现两幅图像合成一幅图像的方法【测试可用】

本文实例讲述了Python图像处理实现两幅图像合成一幅图像的方法。分享给大家供大家参考,具体如下: 将两幅图像合成一幅图像,是图像处理中常用的一种操作,python图像处理库PIL中提供...

使用 Python 实现微信群友统计器的思路详解

使用 Python 实现微信群友统计器的思路详解

基于微信可以做很多有意思的练手项目,看了这张速查表你就会发现,可以做的事情超过你的想象。 有一次我想要统计微信群里哪些同学在北京,但发现直接问是很难得到准确结果的…… 这时候不如运用...

python数据结构之二叉树的统计与转换实例

python数据结构之二叉树的统计与转换实例

一、获取二叉树的深度就是二叉树最后的层次,如下图: 实现代码:复制代码 代码如下:def getheight(self):     &n...

详解appium+python 启动一个app步骤

详解appium+python 启动一个app步骤

询问度娘搭好appium和python环境,开启移动app自动化的探索(基于Android),首先来记录下如何启动待测的app吧! 如何启动APP?1.获取包名;2.获取launcher...