热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

python分层抽样

importnumpyasnpimportpandasaspdPATH_DESUserslinxianliDesktopdfpd.read_excel(PATH_DES+工作簿1.
import numpy as np
import pandas as pd

PATH_DES = '/Users/linxianli/Desktop/'
df = pd.read_excel(PATH_DES + '工作簿1.xlsx')

df.head()
python 分层抽样
# 使用 sklearn 进行分层抽样
from sklearn.model_selection import train_test_split

# data['TYPE']是在data中的某一个属性列
X_train, X_test, y_train, y_test = train_test_split(df,df['TYPE'], test_size=0.2, stratify=df['TYPE']) # test_size 测试集占比

print(X_train.shape)
print(X_test.shape)
'''
(885, 4)
(222, 4)
'''


# 普通方法进行分层抽样
test = pd.DataFrame()              # 划分出的test集合
train = pd.DataFrame()             # 剩余的train集合
tags = df['TYPE'].unique().tolist() # 按照该标签进行等比例抽取

for tag in tags:
    # 随机选取0.2的数据
    data = df[(df['TYPE'] == tag)]
    sample = data.sample(int(0.2*len(data)))
    sample_index = sample.index
    
    # 剩余数据
    all_index = data.index
    residue_index = all_index.difference(sample_index) # 去除sample之后剩余的数据
    residue = data.loc[residue_index]  # 这里要使用.loc而非.iloc
    
    # 保存
    test = pd.concat([test, sample], ignore_index=True)
    train = pd.concat([train, residue], ignore_index=True)

print(test.shape)
print(train.shape)
'''
(221, 4)
(886, 4)
'''

 


推荐阅读
author-avatar
小辛牛牛123牛牛小辛321
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有