PyTorch 解决Dataset和Dataloader遇到的问题

yipeiwu_com5年前Python基础

今天在使用PyTorch中Dataset遇到了一个问题。先看代码

class psDataset(Dataset):
  def __init__(self, x, y, transforms = None):
    super(Dataset, self).__init__()
    self.x = x
    self.y = y
    if transforms == None:
      self.transforms = Compose([Resize((224, 224)), ToTensor()])
    else:
      self.transforms = transforms
    
  def __len__(self):
    return len(self.x)
  
  def __getitem__(self, idx):
    img = Image.open(self.x[idx])
    img = self.transforms(img)    
    return img, torch.tensor([[self.y[idx]]])

结果运行时报错:RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 3 and 1 in dimension 1 at /opt/conda/conda-bld/pytorch_1522182087074/work/torch/lib/TH/generic/THTensorMath.c:2897

Google了一下发现是这样的:读入的图片有些是灰度图(1个通道),绝大多数是RGB图片(3通道),也有些是带透明度的(4通道)

。这导致在读入后最后一个维度(通道数)不一致(可能是1、3或者4)。

Dataloader在制作batch data时,tensor的shape必须一样,就报了这个错误。解决的方法是:img = img.convert(“RGB”)。完

整代码如下:

class psDataset(Dataset):
  def __init__(self, x, y, transforms = None):
    super(Dataset, self).__init__()
    self.x = x
    self.y = y
    if transforms == None:
      self.transforms = Compose([Resize((224, 224)), ToTensor()])
    else:
      self.transforms = transforms
    
  def __len__(self):
    return len(self.x)
  
  def __getitem__(self, idx):
    img = Image.open(self.x[idx])
    img = img.convert("RGB")
    img = self.transforms(img)    
    return img, torch.tensor([[self.y[idx]]])

以上这篇PyTorch 解决Dataset和Dataloader遇到的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

深入理解Python中变量赋值的问题

前言 在Python中变量名规则与其他大多数高级语言一样,都是受C语言影响的,另外变量名是大小写敏感的。 Python是动态类型语言,也就是说不需要预先声明变量类型,变量的类型和值在赋值...

python导入csv文件出现SyntaxError问题分析

背景 np.loadtxt()用于从文本加载数据。 文本文件中的每一行必须含有相同的数据。 *** loadtxt(fname,dtype=<class'float'>,co...

python3实现指定目录下文件sha256及文件大小统计

python3实现指定目录下文件sha256及文件大小统计

有时会统计某个目录下有哪些文件,每个文件的sha256及文件大小等相关信息,这里用python3写了个脚本用来实现此功能,此脚本可跨平台,同时支持windows和linux,脚本(get...

tensorflow1.0学习之模型的保存与恢复(Saver)

将训练好的模型参数保存起来,以便以后进行验证或测试,这是我们经常要做的事情。tf里面提供模型保存的是tf.train.Saver()模块。 模型保存,先要创建一个Saver对象:如...

python实现顺序表的简单代码

python实现顺序表的简单代码

 顺序表即线性表的顺序存储结构。它是通过一组地址连续的存储单元对线性表中的数据进行存储的,相邻的两个元素在物理位置上也是相邻的。比如,第1个元素是存储在线性表的起始位置LOC(...