pytorch 自定义数据集加载方法

yipeiwu_com5年前Python基础

pytorch 官网给出的例子中都是使用了已经定义好的特殊数据集接口来加载数据,而且其使用的数据都是官方给出的数据。如果我们有自己收集的数据集,如何用来训练网络呢?此时需要我们自己定义好数据处理接口。幸运的是pytroch给出了一个数据集接口类(torch.utils.data.Dataset),可以方便我们继承并实现自己的数据集接口。

torch.utils.data

torch的这个文件包含了一些关于数据集处理的类。

class torch.utils.data.Dataset: 一个抽象类, 所有其他类的数据集类都应该是它的子类。而且其子类必须重载两个重要的函数:len(提供数据集的大小)、getitem(支持整数索引)。

class torch.utils.data.TensorDataset: 封装成tensor的数据集,每一个样本都通过索引张量来获得。

class torch.utils.data.ConcatDataset: 连接不同的数据集以构成更大的新数据集。

class torch.utils.data.Subset(dataset, indices): 获取指定一个索引序列对应的子数据集。

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None): 数据加载器。组合了一个数据集和采样器,并提供关于数据的迭代器。

torch.utils.data.random_split(dataset, lengths): 按照给定的长度将数据集划分成没有重叠的新数据集组合。

class torch.utils.data.Sampler(data_source):所有采样的器的基类。每个采样器子类都需要提供 __iter__ 方法以方便迭代器进行索引 和一个 len方法 以方便返回迭代器的长度。

class torch.utils.data.SequentialSampler(data_source):顺序采样样本,始终按照同一个顺序。

class torch.utils.data.RandomSampler(data_source):无放回地随机采样样本元素。

class torch.utils.data.SubsetRandomSampler(indices):无放回地按照给定的索引列表采样样本元素。

class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True): 按照给定的概率来采样样本。

class torch.utils.data.BatchSampler(sampler, batch_size, drop_last): 在一个batch中封装一个其他的采样器。

class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None):采样器可以约束数据加载进数据集的子集。

自定义数据集

自己定义的数据集需要继承抽象类class torch.utils.data.Dataset,并且需要重载两个重要的函数:__len__ 和__getitem__。

整个代码仅供参考。在__init__中是初始化了该类的一些基本参数;__getitem__中是真正读取数据的地方,迭代器通过索引来读取数据集中数据,因此只需要这一个方法中加入读取数据的相关功能即可;__len__给出了整个数据集的尺寸大小,迭代器的索引范围是根据这个函数得来的。

import torch

class myDataset(torch.nn.data.Dataset):
 def __init__(self, dataSource)
  self.dataSource = dataSource

 def __getitem__(self, index):
  element = self.dataSource[index]
  return element
 def __len__(self):
  return len(self.dataSource)

train_data = myDataset(dataSource)

自定义数据集加载器

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None): 数据加载器。组合了一个数据集和采样器,并提供关于数据的迭代器。

dataset (Dataset) – 需要加载的数据集(可以是自定义或者自带的数据集)。

batch_size – batch的大小(可选项,默认值为1)。

shuffle – 是否在每个epoch中shuffle整个数据集, 默认值为False。

sampler – 定义从数据中抽取样本的策略. 如果指定了, shuffle参数必须为False。

num_workers – 表示读取样本的线程数, 0表示只有主线程。

collate_fn – 合并一个样本列表称为一个batch。

pin_memory – 是否在返回数据之前将张量拷贝到CUDA。

drop_last (bool, optional) – 设置是否丢弃最后一个不完整的batch,默认为False。

timeout – 用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。应该为非负整数。

train_loader=torch.utils.data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)

以上这篇pytorch 自定义数据集加载方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

python去除所有html标签的方法

本文实例讲述了python去除所有html标签的方法。分享给大家供大家参考。具体分析如下: 这段代码可以用于去除文本里的字符串标签,不包括标签里面的内容 import re html...

详解Python多线程Selenium跨浏览器测试

详解Python多线程Selenium跨浏览器测试

前言 在web测试中,不可避免的一个测试就是浏览器兼容性测试,在没有自动化测试前,我们总是苦逼的在一台或多台机器上安装N种浏览器,然后手工在不同的浏览器上验证主业务流程和...

Python实现的选择排序算法原理与用法实例分析

Python实现的选择排序算法原理与用法实例分析

本文实例讲述了Python实现的选择排序算法。分享给大家供大家参考,具体如下: 选择排序(Selection sort)是一种简单直观的排序算法。它的工作原理是每一次从待排序的数据元素中...

Sanic框架配置操作分析

本文实例讲述了Sanic框架配置操作。分享给大家供大家参考,具体如下: 简介 Sanic是一个类似Flask的Python 3.5+ Web服务器,它的写入速度非常快。除了Flask之外...

python设置随机种子实例讲解

对于原生的random模块 import random random.seed(1) 如果不设置,则python根据系统时间自己定一个。 也可以自己根据时间定一个随机种子,如:...