热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

2.机器学习之为什么要用交叉验证

本文结构:什么是交叉验证法?为什么用交叉验证法?主要有哪些方法?优缺点?各方法应用举例?什么是交叉验证法?它的基本思想就是将原始数据(dataset)进行分组,一部分做为训练集来训

本文结构:

  • 什么是交叉验证法?
  • 为什么用交叉验证法?
  • 主要有哪些方法?优缺点?
  • 各方法应用举例?

什么是交叉验证法?

它的基本思想就是将原始数据(dataset)进行分组,一部分做为训练集来训练模型,另一部分做为测试集来评价模型。

为什么用交叉验证法?

  1. 交叉验证用于评估模型的预测性能,尤其是训练好的模型在新数据上的表现,可以在一定程度上减小过拟合。
  2. 还可以从有限的数据中获取尽可能多的有效信息。

主要有哪些方法?

1. 留出法 (holdout cross validation)

在机器学习任务中,拿到数据后,我们首先会将原始数据集分为三部分:训练集、验证集和测试集。 
训练集用于训练模型,验证集用于模型的参数选择配置,测试集对于模型来说是未知数据,用于评估模型的泛化能力。

《2.机器学习之 为什么要用交叉验证》

这个方法操作简单,只需随机把原始数据分为三组即可。 
不过如果只做一次分割,它对训练集、验证集和测试集的样本数比例,还有分割后数据的分布是否和原始数据集的分布相同等因素比较敏感,不同的划分会得到不同的最优模型,而且分成三个集合后,用于训练的数据更少了。

于是有了 2. k 折交叉验证(k-fold cross validation)加以改进:

《2.机器学习之 为什么要用交叉验证》

k 折交叉验证通过对 k 个不同分组训练的结果进行平均来减少方差,因此模型的性能对数据的划分就不那么敏感。

  • 第一步,不重复抽样将原始数据随机分为 k 份。
  • 第二步,每一次挑选其中 1 份作为测试集,剩余 k-1 份作为训练集用于模型训练。
  • 第三步,重复第二步 k 次,这样每个子集都有一次机会作为测试集,其余机会作为训练集。
  • 在每个训练集上训练后得到一个模型,
  • 用这个模型在相应的测试集上测试,计算并保存模型的评估指标,
  • 第四步,计算 k 组测试结果的平均值作为模型精度的估计,并作为当前 k 折交叉验证下模型的性能指标。

k 一般取 10, 
数据量小的时候,k 可以设大一点,这样训练集占整体比例就比较大,不过同时训练的模型个数也增多。 
数据量大的时候,k 可以设小一点。

当 k=m 即样本总数时,叫做 3. 留一法(Leave one out cross validation),每次的测试集都只有一个样本,要进行 m 次训练和预测。 
这个方法用于训练的数据只比整体数据集少了一个样本,因此最接近原始样本的分布。 
但是训练复杂度增加了,因为模型的数量与原始数据样本数量相同。 
一般在数据缺乏时使用。

此外:

  1. 多次 k 折交叉验证再求均值,例如:10 次 10 折交叉验证,以求更精确一点。
  2. 划分时有多种方法,例如对非平衡数据可以用分层采样,就是在每一份子集中都保持和原始数据集相同的类别比例。
  3. 模型训练过程的所有步骤,包括模型选择,特征选择等都是在单个折叠 fold 中独立执行的。

还有一种比较特殊的交叉验证方式,Bootstrapping: 通过自助采样法,即在含有 m 个样本的数据集中,每次随机挑选一个样本,再放回到数据集中,再随机挑选一个样本,这样有放回地进行抽样 m 次,组成了新的数据集作为训练集。

这里会有重复多次的样本,也会有一次都没有出现的样本,原数据集中大概有 36.8% 的样本不会出现在新组数据集中。

优点是训练集的样本总数和原数据集一样都是 m,并且仍有约 1/3 的数据不被训练而可以作为测试集。 
缺点是这样产生的训练集的数据分布和原数据集的不一样了,会引入估计偏差。 
此种方法不是很常用,除非数据量真的很少。

各方法应用举例?

1. 留出法 (holdout cross validation)

下面例子,一共有 150 条数据:

>>> import numpy as np
>>> from sklearn.model_selection import train_test_split
>>> from sklearn import datasets
>>> from sklearn import svm
>>> iris = datasets.load_iris()
>>> iris.data.shape, iris.target.shape
((150, 4), (150,))

  • 用 train_test_split 来随机划分数据集,其中 40% 用于测试集,有 60 条数据,60% 为训练集,有 90 条数据:

>>> X_train, X_test, y_train, y_test = train_test_split(
... iris.data, iris.target, test_size=0.4, random_state=0)
>>> X_train.shape, y_train.shape
((90, 4), (90,))
>>> X_test.shape, y_test.shape
((60, 4), (60,))

用 train 来训练,用 test 来评价模型的分数。

>>> clf = svm.SVC(kernel='linear', C=1).fit(X_train, y_train)
>>> clf.score(X_test, y_test)
0.96...

2. k 折交叉验证(k-fold cross validation)

最简单的方法是直接调用 cross_val_score,这里用了 5 折交叉验证:

>>> from sklearn.model_selection import cross_val_score
>>> clf = svm.SVC(kernel='linear', C=1)
>>> scores = cross_val_score(clf, iris.data, iris.target, cv=5)
>>> scores
array([ 0.96..., 1. ..., 0.96..., 0.96..., 1. ])

得到最后平均分为 0.98,以及它的 95% 置信区间:

>>> print("Accuracy: %0.2f (+/- %0.2f)" % (scores.mean(), scores.std() * 2))
Accuracy: 0.98 (+/- 0.03)

我们可以直接看一下 K-fold 是怎样划分数据的: 
X 有四个数据,把它分成 2 折, 
结果中最后一个集合是测试集,前面的是训练集, 
每一行为 1 折:

  • 样的数据 X,我们看 LeaveOneOut 后是什么样子, 

那就是把它分成 4 折, 
结果中最后一个集合是测试集,只有一个元素,前面的是训练集, 
每一行为 1 折:

>>> from sklearn.model_selection import LeaveOneOut
>>> X = [1, 2, 3, 4]
>>> loo = LeaveOneOut()
>>> for train, test in loo.split(X):
... print("%s %s" % (train, test))
[1 2 3] [0]
[0 2 3] [1]
[0 1 3] [2]
[0 1 2] [3]

推荐阅读
  • 在本教程中,我们将看到如何使用FLASK制作第一个用于机器学习模型的RESTAPI。我们将从创建机器学习模型开始。然后,我们将看到使用Flask创建AP ... [详细]
  • sklearn数据集库中的常用数据集类型介绍
    本文介绍了sklearn数据集库中常用的数据集类型,包括玩具数据集和样本生成器。其中详细介绍了波士顿房价数据集,包含了波士顿506处房屋的13种不同特征以及房屋价格,适用于回归任务。 ... [详细]
  • 本文讨论了一个关于cuowu类的问题,作者在使用cuowu类时遇到了错误提示和使用AdjustmentListener的问题。文章提供了16个解决方案,并给出了两个可能导致错误的原因。 ... [详细]
  • 本文介绍了在处理不规则数据时如何使用Python自动提取文本中的时间日期,包括使用dateutil.parser模块统一日期字符串格式和使用datefinder模块提取日期。同时,还介绍了一段使用正则表达式的代码,可以支持中文日期和一些特殊的时间识别,例如'2012年12月12日'、'3小时前'、'在2012/12/13哈哈'等。 ... [详细]
  • 本文介绍了在Python张量流中使用make_merged_spec()方法合并设备规格对象的方法和语法,以及参数和返回值的说明,并提供了一个示例代码。 ... [详细]
  • 超级简单加解密工具的方案和功能
    本文介绍了一个超级简单的加解密工具的方案和功能。该工具可以读取文件头,并根据特定长度进行加密,加密后将加密部分写入源文件。同时,该工具也支持解密操作。加密和解密过程是可逆的。本文还提到了一些相关的功能和使用方法,并给出了Python代码示例。 ... [详细]
  • tcpdump 4.5.1 crash 深入分析
    tcpdump 4.5.1 crash 深入分析 ... [详细]
  • 前言:拿到一个案例,去分析:它该是做分类还是做回归,哪部分该做分类,哪部分该做回归,哪部分该做优化,它们的目标值分别是什么。再挑影响因素,哪些和分类有关的影响因素,哪些和回归有关的 ... [详细]
  • 基于词向量计算文本相似度1.测试数据:链接:https:pan.baidu.coms1fXJjcujAmAwTfsuTg2CbWA提取码:f4vx2.实验代码:imp ... [详细]
  • 语义分割系列3SegNet(pytorch实现)
    SegNet手稿最早是在2015年12月投出,和FCN属于同时期作品。稍晚于FCN,既然属于后来者,又是与FCN同属于语义分割网络 ... [详细]
  • Linux重启网络命令实例及关机和重启示例教程
    本文介绍了Linux系统中重启网络命令的实例,以及使用不同方式关机和重启系统的示例教程。包括使用图形界面和控制台访问系统的方法,以及使用shutdown命令进行系统关机和重启的句法和用法。 ... [详细]
  • Java容器中的compareto方法排序原理解析
    本文从源码解析Java容器中的compareto方法的排序原理,讲解了在使用数组存储数据时的限制以及存储效率的问题。同时提到了Redis的五大数据结构和list、set等知识点,回忆了作者大学时代的Java学习经历。文章以作者做的思维导图作为目录,展示了整个讲解过程。 ... [详细]
  • 本文介绍了OC学习笔记中的@property和@synthesize,包括属性的定义和合成的使用方法。通过示例代码详细讲解了@property和@synthesize的作用和用法。 ... [详细]
  • importjava.util.ArrayList;publicclassPageIndex{privateintpageSize;每页要显示的行privateintpageNum ... [详细]
  • 关键词:Golang, Cookie, 跟踪位置, net/http/cookiejar, package main, golang.org/x/net/publicsuffix, io/ioutil, log, net/http, net/http/cookiejar ... [详细]
author-avatar
别被风景迷了眼
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有