pytorch加载自定义网络权重的实现

yipeiwu_com6年前Python基础

在将自定义的网络权重加载到网络中时,报错:

AttributeError: 'dict' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead.

我们一步一步分析。

模型网络权重保存额代码是:torch.save(net.state_dict(),'net.pkl')

(1)查看获取模型权重的源码:

pytorch源码:net.state_dict()

def state_dict(self, destination=None, prefix='', keep_vars=False):
  r"""Returns a dictionary containing a whole state of the module.

  Both parameters and persistent buffers (e.g. running averages) are
  included. Keys are corresponding parameter and buffer names.

  Returns:
    dict:
      a dictionary containing a whole state of the module

  Example::

    >>> module.state_dict().keys()
    ['bias', 'weight']

  """

将网络中所有的状态保存到一个字典中了,我自己构建的就是一个字典,没问题!

(2)查看保存模型权重的源码:

pytorch源码:torch.save()

def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
  """Saves an object to a disk file.

  See also: :ref:`recommend-saving-models`

  Args:
    obj: saved object
    f: a file-like object (has to implement write and flush) or a string
      containing a file name
    pickle_module: module used for pickling metadata and objects
    pickle_protocol: can be specified to override the default protocol

  .. warning::
    If you are using Python 2, torch.save does NOT support StringIO.StringIO
    as a valid file-like object. This is because the write method should return
    the number of bytes written; StringIO.write() does not do this.

    Please use something like io.BytesIO instead.

函数功能是将字典保存为磁盘文件(二进制数据),那么我们在torch.load()时,就是在内存中加载二进制数据,这就是报错点。

解决方案:将字典保存为BytesIO文件之后,模型再net.load_state_dict()

#b为自定义的字典
torch.save(b,'new.pkl')
net.load_state_dict(torch.load(b))

解决方法很简单,主要记录解决思路。

以上这篇pytorch加载自定义网络权重的实现就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

在Python中用has_key()方法查找键是否存在的教程

 如果给定的键在字典可用,has_key()方法返回true,否则返回false。 语法 以下是has_key()方法的语法: dict.has_key(key) 参...

pytorch 在网络中添加可训练参数,修改预训练权重文件的方法

实践中,针对不同的任务需求,我们经常会在现成的网络结构上做一定的修改来实现特定的目的。 假如我们现在有一个简单的两层感知机网络: # -*- coding: utf-8 -*- im...

基于Python检测动态物体颜色过程解析

基于Python检测动态物体颜色过程解析

本篇文章将通过图片对比的方法检查视频中的动态物体,并将其中会动的物体定位用cv2矩形框圈出来。本次项目可用于树莓派或者单片机追踪做一些思路参考。寻找动态物体也可以用来监控是否有人进入房间...

两个使用Python脚本操作文件的小示例分享

1这是一个创建一个文件,并在控制台写入行到新建的文件中. #!/usr/bin/env python 'makeTextFile.py -- create text file'...

Django框架 信号调度原理解析

Django中提供了“信号调度”,用于在框架执行操作时解耦。通俗来讲,就是一些动作发生的时候,信号允许特定的发送者去提醒一些接受者。 Django内置信号 Model signal...