对python中数据集划分函数StratifiedShuffleSplit的使用详解

yipeiwu_com5年前Python基础

文章开始先讲下交叉验证,这个概念同样适用于这个划分函数

1.交叉验证(Cross-validation)

交叉验证是指在给定的建模样本中,拿出其中的大部分样本进行模型训练,生成模型,留小部分样本用刚建立的模型进行预测,并求这小部分样本的预测误差,记录它们的平方加和。这个过程一直进行,直到所有的样本都被预测了一次而且仅被预测一次,比较每组的预测误差,选取误差最小的那一组作为训练模型。

下图所示

python StratifiedShuffleSplit的使用

2.StratifiedShuffleSplit函数的使用

官方文档

用法:

from sklearn.model_selection import StratifiedShuffleSplit
StratifiedShuffleSplit(n_splits=10,test_size=None,train_size=None, random_state=None)

2.1 参数说明

参数 n_splits是将训练数据分成train/test对的组数,可根据需要进行设置,默认为10

参数test_size和train_size是用来设置train/test对中train和test所占的比例。例如:

1.提供10个数据num进行训练和测试集划分

2.设置train_size=0.8 test_size=0.2

3.train_num=num*train_size=8 test_num=num*test_size=2

4.即10个数据,进行划分以后8个是训练数据,2个是测试数据

注*:train_num≥2,test_num≥2 ;test_size+train_size可以小于1*

参数 random_state控制是将样本随机打乱

2.2 函数作用描述

1.其产生指定数量的独立的train/test数据集划分数据集划分成n组。

2.首先将样本随机打乱,然后根据设置参数划分出train/test对。

3.其创建的每一组划分将保证每组类比比例相同。即第一组训练数据类别比例为2:1,则后面每组类别都满足这个比例

2.3 具体实现

from sklearn.model_selection import StratifiedShuffleSplit
import numpy as np
X = np.array([[1, 2], [3, 4], [1, 2], [3, 4],
 [1, 2],[3, 4], [1, 2], [3, 4]])#训练数据集8*2
y = np.array([0, 0, 1, 1,0,0,1,1])#类别数据集8*1

ss=StratifiedShuffleSplit(n_splits=5,test_size=0.25,train_size=0.75,random_state=0)#分成5组,测试比例为0.25,训练比例是0.75

for train_index, test_index in ss.split(X, y):
 print("TRAIN:", train_index, "TEST:", test_index)#获得索引值
 X_train, X_test = X[train_index], X[test_index]#训练集对应的值
 y_train, y_test = y[train_index], y[test_index]#类别集对应的值

运行结果:

python StratifiedShuffleSplit的使用

从结果看出,1.训练集是6个,测试集是2,与设置的所对应;2.五组中每组对应的类别比例相同

以上这篇对python中数据集划分函数StratifiedShuffleSplit的使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

python从入门到精通(DAY 3)

python从入门到精通(DAY 3)

要求:编写登陆接口 输入用户名密码 认证成功后显示欢迎信息 输错三次后锁定 针对此实例写了有二种类型的脚本,略有不同,具体如下: 帐号文件account.txt内容如下: sam 12...

Java分治归并排序算法实例详解

Java分治归并排序算法实例详解

本文实例讲述了Java分治归并排序算法。分享给大家供大家参考,具体如下: 1、分治法 许多有用的算法在结构上是递归的:为了解决一个给定的问题,算法一次或多次递归地调用其自身以解决紧密相关...

解决python Markdown模块乱码的问题

解决python Markdown模块乱码的问题

有个需求需要把markdown转成html模块,查询了一下刚好有这个模块 安装 pip install amrkdown 安装完成直接转换并保存为html时,发现出现中文乱码的情况 用...

Python 实现大整数乘法算法的示例代码

我们平时接触的长乘法,按位相乘,是一种时间复杂度为 O(n ^ 2) 的算法。今天,我们来介绍一种时间复杂度为 O (n ^ log 3) 的大整数乘法(log 表示以 2 为底的对数)...

用uWSGI和Nginx部署Flask项目的方法示例

用uWSGI和Nginx部署Flask项目的方法示例

概况 在开发过程中,我们一般直接用Python命令直接运行Flask程序。这样的运行只适合我们开发,方便我们调试。一旦程序部署到线上,这样运行的Flask程序性能会比较低。可以采用uW...