PyTorch 解决Dataset和Dataloader遇到的问题

yipeiwu_com6年前Python基础

今天在使用PyTorch中Dataset遇到了一个问题。先看代码

class psDataset(Dataset):
  def __init__(self, x, y, transforms = None):
    super(Dataset, self).__init__()
    self.x = x
    self.y = y
    if transforms == None:
      self.transforms = Compose([Resize((224, 224)), ToTensor()])
    else:
      self.transforms = transforms
    
  def __len__(self):
    return len(self.x)
  
  def __getitem__(self, idx):
    img = Image.open(self.x[idx])
    img = self.transforms(img)    
    return img, torch.tensor([[self.y[idx]]])

结果运行时报错:RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 3 and 1 in dimension 1 at /opt/conda/conda-bld/pytorch_1522182087074/work/torch/lib/TH/generic/THTensorMath.c:2897

Google了一下发现是这样的:读入的图片有些是灰度图(1个通道),绝大多数是RGB图片(3通道),也有些是带透明度的(4通道)

。这导致在读入后最后一个维度(通道数)不一致(可能是1、3或者4)。

Dataloader在制作batch data时,tensor的shape必须一样,就报了这个错误。解决的方法是:img = img.convert(“RGB”)。完

整代码如下:

class psDataset(Dataset):
  def __init__(self, x, y, transforms = None):
    super(Dataset, self).__init__()
    self.x = x
    self.y = y
    if transforms == None:
      self.transforms = Compose([Resize((224, 224)), ToTensor()])
    else:
      self.transforms = transforms
    
  def __len__(self):
    return len(self.x)
  
  def __getitem__(self, idx):
    img = Image.open(self.x[idx])
    img = img.convert("RGB")
    img = self.transforms(img)    
    return img, torch.tensor([[self.y[idx]]])

以上这篇PyTorch 解决Dataset和Dataloader遇到的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Python进阶之自定义对象实现切片功能

切片是 Python 中最迷人最强大最 Amazing 的语言特性(几乎没有之一),在《Python进阶:切片的误区与高级用法》中,我介绍了切片的基础用法、高级用法以及一些使用误区。这些...

对python 树状嵌套结构的实现思路详解

对python 树状嵌套结构的实现思路详解

原始数据 原始数据大致是这样子的: 每条数据中的四个数据分别是 当前节点名称,节点描述(指代一些需要的节点属性),源节点(即最顶层节点),父节点(当前节点上一层节点)。 datas...

使用python对文件中的单词进行提取的方法示例

使用python对文件中的单词进行提取的方法示例

由于需要使用一个纯单词组成的文件,在网上下载到了一个存放单词的文件,但是里面有中文的解释,那就需要做一下提取了。 文本的形式如下: 所见即所得,这个文本是有规律的,每个单词为一行,紧...

Python 记录日志的灵活性和可配置性介绍

Python 记录日志的灵活性和可配置性介绍

对一名开发者来说最糟糕的情况,莫过于要弄清楚一个不熟悉的应用为何不工作。有时候,你甚至不知道系统运行,是否跟原始设计一致。 在线运行的应用就是黑盒子,需要被跟踪监控。最简单也最重要的方式...

Pytorch在NLP中的简单应用详解

Pytorch在NLP中的简单应用详解

因为之前在项目中一直使用Tensorflow,最近需要处理NLP问题,对Pytorch框架还比较陌生,所以特地再学习一下pytorch在自然语言处理问题中的简单使用,这里做一个记录。 一...