关于pytorch处理类别不平衡的问题

yipeiwu_com6年前Python基础

当训练样本不均匀时,我们可以采用过采样、欠采样、数据增强等手段来避免过拟合。今天遇到一个3d点云数据集合,样本分布极不均匀,正例与负例相差4-5个数量级。数据增强效果就不会太好了,另外过采样也不太合适,因为是空间数据,新增的点有可能会对真实分布产生未知影响。所以采用欠采样来缓解类别不平衡的问题。

下面的代码展示了如何使用WeightedRandomSampler来完成抽样。

numDataPoints = 1000
data_dim = 5
bs = 100

# Create dummy data with class imbalance 9 to 1
data = torch.FloatTensor(numDataPoints, data_dim)
target = np.hstack((np.zeros(int(numDataPoints * 0.9), dtype=np.int32),
     np.ones(int(numDataPoints * 0.1), dtype=np.int32)))

print 'target train 0/1: {}/{}'.format(
 len(np.where(target == 0)[0]), len(np.where(target == 1)[0]))

class_sample_count = np.array(
 [len(np.where(target == t)[0]) for t in np.unique(target)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in target])

samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

target = torch.from_numpy(target).long()
train_dataset = torch.utils.data.TensorDataset(data, target)

train_loader = DataLoader(
 train_dataset, batch_size=bs, num_workers=1, sampler=sampler)

for i, (data, target) in enumerate(train_loader):
 print "batch index {}, 0/1: {}/{}".format(
  i,
  len(np.where(target.numpy() == 0)[0]),
  len(np.where(target.numpy() == 1)[0]))

核心部分为实际使用时替换下变量把sampler传递给DataLoader即可,注意使用了sampler就不能使用shuffle,另外需要指定采样点个数:

class_sample_count = np.array(
 [len(np.where(target == t)[0]) for t in np.unique(target)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in target])

samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

参考:https://discuss.pytorch.org/t/how-to-handle-imbalanced-classes/11264/2

以上这篇关于pytorch处理类别不平衡的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Python科学画图代码分享

Python科学画图代码分享

Python画图主要用到matplotlib这个库。Matplotlib 是一个 Python 的 2D绘图库,它以各种硬拷贝格式和跨平台的交互式环境生成出版质量级别的图形。 这里有一本...

pytorch中的embedding词向量的使用方法

Embedding 词嵌入在 pytorch 中非常简单,只需要调用 torch.nn.Embedding(m, n) 就可以了,m 表示单词的总数目,n 表示词嵌入的维度,其实词嵌入就...

Python Pandas 箱线图的实现

Python Pandas 箱线图的实现

各国家用户消费分布 import numpy as np import pandas as pd import matplotlib.pyplot as plt data...

flask的orm框架SQLAlchemy查询实现解析

这篇文章主要介绍了flask的orm框架SQLAlchemy查询实现解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下 一对多,多对多...

Django项目使用ckeditor详解(不使用admin)

Django项目使用ckeditor详解(不使用admin)

效果图: 1.安装django-ckeditor pip install django-ckeditor 如果需要上传图片或者文件,还需要安装pillow pip insta...