解决Pytorch 训练与测试时爆显存(out of memory)的问题

yipeiwu_com5年前Python基础

Pytorch 训练时有时候会因为加载的东西过多而爆显存,有些时候这种情况还可以使用cuda的清理技术进行修整,当然如果模型实在太大,那也没办法。

使用torch.cuda.empty_cache()删除一些不需要的变量代码示例如下:

try:
  output = model(input)
except RuntimeError as exception:
  if "out of memory" in str(exception):
    print("WARNING: out of memory")
    if hasattr(torch.cuda, 'empty_cache'):
      torch.cuda.empty_cache()
  else:
    raise exception

测试的时候爆显存有可能是忘记设置no_grad, 示例代码如下:

  with torch.no_grad():
    for ii,(inputs,filelist) in tqdm(enumerate(test_loader), desc='predict'):
      if opt.use_gpu:
        inputs = inputs.cuda()
        if len(inputs.shape) < 4:
          inputs = inputs.unsqueeze(1)
 
      else:
        if len(inputs.shape) < 4:
          inputs = torch.transpose(inputs, 1, 2)
          inputs = inputs.unsqueeze(1)
 

以上这篇解决Pytorch 训练与测试时爆显存(out of memory)的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Python计算库numpy进行方差/标准方差/样本标准方差/协方差的计算

使用numpy可以做很多事情,在这篇文章中简单介绍一下如何使用numpy进行方差/标准方差/样本标准方差/协方差的计算。 variance: 方差 方差(Variance)是概率论中最基...

python找出列表中大于某个阈值的数据段示例

python找出列表中大于某个阈值的数据段示例

该算法实现对列表中大于某个阈值(比如level=5)的连续数据段的提取,具体效果如下: 找出list里面大于5的连续数据段: list = [1,2,3,4,2,3,4,5,6,7,...

pyhton中__pycache__文件夹的产生与作用详解

用python编写了一个工程,但在第一次运行后,发现工程根目录下生成了一个__pycache__文件夹,里面是和py文件同名的各种以.cpython-35.pyc结尾的文件。cpytho...

python并发编程多进程之守护进程原理解析

守护进程 主进程创建子进程目的是:主进程有一个任务需要并发执行,那开启子进程帮我并发执行任务 主进程创建子进程,然后将该进程设置成守护自己的进程 关于守护进程需要强调两点: 其一:守护...

Python面向对象程序设计类的多态用法详解

本文实例讲述了Python面向对象程序设计类的多态用法。分享给大家供大家参考,具体如下: 多态 1、多态使用 一种事物的多种体现形式,举例:动物有很多种 注意: 继承是多态的前提 函数重...