PyTorch 解决Dataset和Dataloader遇到的问题

yipeiwu_com6年前Python基础

今天在使用PyTorch中Dataset遇到了一个问题。先看代码

class psDataset(Dataset):
  def __init__(self, x, y, transforms = None):
    super(Dataset, self).__init__()
    self.x = x
    self.y = y
    if transforms == None:
      self.transforms = Compose([Resize((224, 224)), ToTensor()])
    else:
      self.transforms = transforms
    
  def __len__(self):
    return len(self.x)
  
  def __getitem__(self, idx):
    img = Image.open(self.x[idx])
    img = self.transforms(img)    
    return img, torch.tensor([[self.y[idx]]])

结果运行时报错:RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 3 and 1 in dimension 1 at /opt/conda/conda-bld/pytorch_1522182087074/work/torch/lib/TH/generic/THTensorMath.c:2897

Google了一下发现是这样的:读入的图片有些是灰度图(1个通道),绝大多数是RGB图片(3通道),也有些是带透明度的(4通道)

。这导致在读入后最后一个维度(通道数)不一致(可能是1、3或者4)。

Dataloader在制作batch data时,tensor的shape必须一样,就报了这个错误。解决的方法是:img = img.convert(“RGB”)。完

整代码如下:

class psDataset(Dataset):
  def __init__(self, x, y, transforms = None):
    super(Dataset, self).__init__()
    self.x = x
    self.y = y
    if transforms == None:
      self.transforms = Compose([Resize((224, 224)), ToTensor()])
    else:
      self.transforms = transforms
    
  def __len__(self):
    return len(self.x)
  
  def __getitem__(self, idx):
    img = Image.open(self.x[idx])
    img = img.convert("RGB")
    img = self.transforms(img)    
    return img, torch.tensor([[self.y[idx]]])

以上这篇PyTorch 解决Dataset和Dataloader遇到的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

python处理PHP数组文本文件实例

需求: 对一个配置文件进行处理,拿出可用的字符来拼接,下面是原始文本,我们要得到这样的结果, 复制代码 代码如下: redis -h 127.0.0.1 -p 6379 | select...

pandas 数据结构之Series的使用方法

1. Series Series 是一个类数组的数据结构,同时带有标签(lable)或者说索引(index)。 1.1 下边生成一个最简单的Series对象,因为没有给Series指定索...

Python字典遍历操作实例小结

本文实例讲述了Python字典遍历操作。分享给大家供大家参考,具体如下: 1 遍历键值对 可以使用一个 for 循环以及方法 items() 来遍历这个字典的键值对。 dict =...

详解利用Python scipy.signal.filtfilt() 实现信号滤波

本文将以实战的形式基于scipy模块使用Python实现简单滤波处理,包括内容有1.低通滤波,2.高通滤波,3.带通滤波,4.带阻滤波器。具体的含义大家可以查阅大学课程,信号与系统。简单...

python验证码识别的示例代码

python验证码识别的示例代码

写爬虫有一个绕不过去的问题就是验证码,现在验证码分类大概有4种: 图像类 滑动类 点击类 语音类 今天先来看看图像类,这类验证码大多是数字、字母的组合,国内也有使用汉...