pytorch sampler对数据进行采样的实现

yipeiwu_com5年前Python基础

PyTorch中还单独提供了一个sampler模块,用来对数据进行采样。常用的有随机采样器:RandomSampler,当dataloader的shuffle参数为True时,系统会自动调用这个采样器,实现打乱数据。默认的是采用SequentialSampler,它会按顺序一个一个进行采样。这里介绍另外一个很有用的采样方法: WeightedRandomSampler,它会根据每个样本的权重选取数据,在样本比例不均衡的问题中,可用它来进行重采样。

构建WeightedRandomSampler时需提供两个参数:每个样本的权重weights、共选取的样本总数num_samples,以及一个可选参数replacement。权重越大的样本被选中的概率越大,待选取的样本数目一般小于全部的样本数目。replacement用于指定是否可以重复选取某一个样本,默认为True,即允许在一个epoch中重复采样某一个数据。如果设为False,则当某一类的样本被全部选取完,但其样本数目仍未达到num_samples时,sampler将不会再从该类中选择数据,此时可能导致weights参数失效。

下面举例说明。

from dataSet import *
dataset = DogCat('data/dogcat/', transform=transform)

from torch.utils.data import DataLoader
# 狗的图片被取出的概率是猫的概率的两倍
# 两类图片被取出的概率与weights的绝对大小无关,只和比值有关
weights = [2 if label == 1 else 1 for data, label in dataset]

print(weights)

from torch.utils.data.sampler import WeightedRandomSampler
sampler = WeightedRandomSampler(weights,\
                num_samples=9,\
                replacement=True)
dataloader = DataLoader(dataset,
            batch_size=3,
            sampler=sampler)
for datas, labels in dataloader:
  print(labels.tolist())

输出:

[2, 2, 1, 1, 2, 1, 1, 2]
[1, 1, 0]
[1, 0, 0]
[0, 0, 1]

github 地址:

https://github.com/WebLearning17/CommonTool

以上这篇pytorch sampler对数据进行采样的实现就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Python实现SQL注入检测插件实例代码

Python实现SQL注入检测插件实例代码

扫描器需要实现的功能思维导图 爬虫编写思路 首先需要开发一个爬虫用于收集网站的链接,爬虫需要记录已经爬取的链接和待爬取的链接,并且去重,用 Python 的set()就可以解决,大概...

使用python3构建文件传输的方法

有时需要传输比较大的文件,通过聊天工具发送极其不方便,或者网络受限的情况下,只能另寻他法。用python就可以做一个简单的web服务,方便而且传输速率高。 步骤: 在cmd下,进入含有需...

python微元法计算函数曲线长度的方法

python微元法计算函数曲线长度的方法

计算曲线长度,根据线积分公式: ,令积分函数 f(x,y,z) 为1,即计算曲线的长度,将其微元化: 其中 根据此时便可在python编程实现,给出4个例子,代码中已有详细注释,不...

python能做什么 python的含义

python能做什么?是什么意思? Python是一种跨平台的计算机程序设计语言。是一种面向对象的动态类型语言,最初被设计用于编写自动化脚本(shell),随着版本的不断更新和语言新功能...

Python的Flask框架中集成CKeditor富文本编辑器的教程

CKeditor是目前最优秀的可见即可得网页编辑器之一,它采用JavaScript编写。具备功能强大、配置容易、跨浏览器、支持多种编程语言、开源等特点。它非常流行,互联网上很容易找到相关...