关于pytorch处理类别不平衡的问题

yipeiwu_com6年前Python基础

当训练样本不均匀时,我们可以采用过采样、欠采样、数据增强等手段来避免过拟合。今天遇到一个3d点云数据集合,样本分布极不均匀,正例与负例相差4-5个数量级。数据增强效果就不会太好了,另外过采样也不太合适,因为是空间数据,新增的点有可能会对真实分布产生未知影响。所以采用欠采样来缓解类别不平衡的问题。

下面的代码展示了如何使用WeightedRandomSampler来完成抽样。

numDataPoints = 1000
data_dim = 5
bs = 100

# Create dummy data with class imbalance 9 to 1
data = torch.FloatTensor(numDataPoints, data_dim)
target = np.hstack((np.zeros(int(numDataPoints * 0.9), dtype=np.int32),
     np.ones(int(numDataPoints * 0.1), dtype=np.int32)))

print 'target train 0/1: {}/{}'.format(
 len(np.where(target == 0)[0]), len(np.where(target == 1)[0]))

class_sample_count = np.array(
 [len(np.where(target == t)[0]) for t in np.unique(target)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in target])

samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

target = torch.from_numpy(target).long()
train_dataset = torch.utils.data.TensorDataset(data, target)

train_loader = DataLoader(
 train_dataset, batch_size=bs, num_workers=1, sampler=sampler)

for i, (data, target) in enumerate(train_loader):
 print "batch index {}, 0/1: {}/{}".format(
  i,
  len(np.where(target.numpy() == 0)[0]),
  len(np.where(target.numpy() == 1)[0]))

核心部分为实际使用时替换下变量把sampler传递给DataLoader即可,注意使用了sampler就不能使用shuffle,另外需要指定采样点个数:

class_sample_count = np.array(
 [len(np.where(target == t)[0]) for t in np.unique(target)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in target])

samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

参考:https://discuss.pytorch.org/t/how-to-handle-imbalanced-classes/11264/2

以上这篇关于pytorch处理类别不平衡的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

python、java等哪一门编程语言适合人工智能?

python、java等哪一门编程语言适合人工智能?

谷歌的AI击败了一位围棋大师,是一种衡量人工智能突然的快速发展的方式,也揭示了这些技术如何发展而来和将来可以如何发展。 人工智能是一种未来性的技术,目前正在致力于研究自己的一套工具。一...

浅析PEP572: 海象运算符

现在已经是Python 3.8的最后一个alpha版本,接着就是本月底要发布的的3.8.0 beta 1了。按规定,3.8已经不会再添加(修改)功能了,之前非常有争议的PEP 572的实...

python实现m3u8格式转换为mp4视频格式

python实现m3u8格式转换为mp4视频格式

开发动机:最近用手机QQ浏览器下载了一些视频,视频越来越多,占用了手机内存,于是想把下载的视频传到电脑上保存,可后来发现这些视频都是m3u8格式的,且这个格式的视频都切成了碎片,存在电脑...

Python探索之静态方法和类方法的区别详解

面相对象程序设计中,类方法和静态方法是经常用到的两个术语。 逻辑上讲:类方法是只能由类名调用;静态方法可以由类名或对象名进行调用。 python staticmethod and cla...

使用Pyinstaller的最新踩坑实战记录

前言 将py编译成可执行文件需要使用PyInstaller,之前给大家介绍了关于利用PyInstaller将python程序.py转为.exe的方法,在开始本文之前推荐大家可以先看下这篇...