Pytorch之保存读取模型实例

yipeiwu_com5年前Python基础

pytorch保存数据

pytorch保存数据的格式为.t7文件或者.pth文件,t7文件是沿用torch7中读取模型权重的方式。而pth文件是python中存储文件的常用格式。而在keras中则是使用.h5文件。

# 保存模型示例代码
print('===> Saving models...')
state = {
  'state': model.state_dict(),
  'epoch': epoch          # 将epoch一并保存
}
if not os.path.isdir('checkpoint'):
  os.mkdir('checkpoint')
torch.save(state, './checkpoint/autoencoder.t7')

保存用到torch.save函数,注意该函数第一个参数可以是单个值也可以是字典,字典可以存更多你要保存的参数(不仅仅是权重数据)。

pytorch读取数据

pytorch读取数据使用的方法和我们平时使用预训练参数所用的方法是一样的,都是使用load_state_dict这个函数。

下方的代码和上方的保存代码可以搭配使用。

print('===> Try resume from checkpoint')
if os.path.isdir('checkpoint'):
  try:
    checkpoint = torch.load('./checkpoint/autoencoder.t7')
    model.load_state_dict(checkpoint['state'])    # 从字典中依次读取
    start_epoch = checkpoint['epoch']
    print('===> Load last checkpoint data')
  except FileNotFoundError:
    print('Can\'t found autoencoder.t7')
else:
  start_epoch = 0
  print('===> Start from scratch')

以上是pytorch读取的方法汇总,但是要注意,在使用官方的预处理模型进行读取时,一般使用的格式是pth,使用官方的模型读取命令会检查你模型的格式是否正确,如果不是使用官方提供模型通过下面的函数强行读取模型(将其他模型例如caffe模型转过来的模型放到指定目录下)会发生错误。

def vgg19(pretrained=False, **kwargs):
  """VGG 19-layer model (configuration "E")
 
  Args:
    pretrained (bool): If True, returns a model pre-trained on ImageNet
  """
  model = VGG(make_layers(cfg['E']), **kwargs)
  if pretrained:
    model.load_state_dict(model_zoo.load_url(model_urls['vgg19']))
  return model

假如我们有从caffe模型转过来的pytorch模型([0-255,BGR]),我们可以使用:

model_dir = '自己的模型地址'
model = VGG()
model.load_state_dict(torch.load(model_dir + 'vgg_conv.pth'))

也就是pytorch的读取函数进行读取即可。

以上这篇Pytorch之保存读取模型实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Django中利用filter与simple_tag为前端自定义函数的实现方法

前言 Django的模板引擎提供了一般性的功能函数,通过前端可以实现多数的代码逻辑功能,这里称之为一般性,是因为它仅支持大多数常见情况下的函数功能,例如if判断,ifequal对比返回值...

Python 调用Java实例详解

Python 调用Java实例详解 前言: Python 对服务器端编程不如Java 所以这方面可能要调用Java代码 前提: Linux 环境  1 安装 jpype1 安...

对于Python编程中一些重用与缩减的建议

返璞归真 许多流行的玩具都以这样一个概念为基础:简单的积木。这些简单的积木可通过多种方式组合在一起构造出全新的作品 —— 有时甚至完全令人出乎意料。这一概念同样适用于现实生活中的建筑领域...

一文秒懂python读写csv xml json文件各种骚操作

Python优越的灵活性和易用性使其成为最受欢迎的编程语言之一,尤其是对数据科学家而言。 这在很大程度上是因为使用Python处理大型数据集是很简单的一件事情。 如今,每家科技公司都在制...

10款最好的Web开发的 Python 框架

10款最好的Web开发的 Python 框架

  Python 是一门动态、面向对象语言。其最初就是作为一门面向对象语言设计的,并且在后期又加入了一些更高级的特性。除了语言本身的设计目的之外,Python标准 库也是值得大家称赞的,...