pytorch sampler对数据进行采样的实现

yipeiwu_com6年前Python基础

PyTorch中还单独提供了一个sampler模块,用来对数据进行采样。常用的有随机采样器:RandomSampler,当dataloader的shuffle参数为True时,系统会自动调用这个采样器,实现打乱数据。默认的是采用SequentialSampler,它会按顺序一个一个进行采样。这里介绍另外一个很有用的采样方法: WeightedRandomSampler,它会根据每个样本的权重选取数据,在样本比例不均衡的问题中,可用它来进行重采样。

构建WeightedRandomSampler时需提供两个参数:每个样本的权重weights、共选取的样本总数num_samples,以及一个可选参数replacement。权重越大的样本被选中的概率越大,待选取的样本数目一般小于全部的样本数目。replacement用于指定是否可以重复选取某一个样本,默认为True,即允许在一个epoch中重复采样某一个数据。如果设为False,则当某一类的样本被全部选取完,但其样本数目仍未达到num_samples时,sampler将不会再从该类中选择数据,此时可能导致weights参数失效。

下面举例说明。

from dataSet import *
dataset = DogCat('data/dogcat/', transform=transform)

from torch.utils.data import DataLoader
# 狗的图片被取出的概率是猫的概率的两倍
# 两类图片被取出的概率与weights的绝对大小无关,只和比值有关
weights = [2 if label == 1 else 1 for data, label in dataset]

print(weights)

from torch.utils.data.sampler import WeightedRandomSampler
sampler = WeightedRandomSampler(weights,\
                num_samples=9,\
                replacement=True)
dataloader = DataLoader(dataset,
            batch_size=3,
            sampler=sampler)
for datas, labels in dataloader:
  print(labels.tolist())

输出:

[2, 2, 1, 1, 2, 1, 1, 2]
[1, 1, 0]
[1, 0, 0]
[0, 0, 1]

github 地址:

https://github.com/WebLearning17/CommonTool

以上这篇pytorch sampler对数据进行采样的实现就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Django管理员账号和密码忘记的完美解决方法

Django管理员账号和密码忘记的完美解决方法

发现问题 看着Django的教程学习搭建网站,结果忘记第一次创建的账号和密码了。结果搭建成功以后,一直无法登陆到管理页面,进行不下去了。 如图所示: 在网上找了很多的方法都不行,最后使...

python模拟鼠标拖动操作的方法

python模拟鼠标拖动操作的方法

本文实例讲述了python模拟鼠标拖动操作的方法。分享给大家供大家参考。具体如下: pdf中的书签只有页码,准备把现有书签拖到一个目录中,然后添加自己页签。重复的拖动工作实在无趣,还是让...

请不要重复犯我在学习Python和Linux系统上的错误

请不要重复犯我在学习Python和Linux系统上的错误

本人已经在运维行业工作了将近十年,我最早接触Linux是在大二的样子,那时候只追求易懂,所以就选择了Ubuntu作为学习、使用的对象,它简单、易用、好操作、界面绚丽,对于想接触Linux...

python 实现矩阵上下/左右翻转,转置的示例

python中没有二维数组,用一个元素为list的list(matrix)保存矩阵,row为行数,col为列数 1. 上下翻转:只需要把每一行的list交换即可 for i in r...

python几种常用功能实现代码实例

这篇文章主要介绍了python几种常用功能实现代码实例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下 1、python 程序退出的几种...