热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

python程序实现rep后剪枝算法

背景在使用决策树模型时,如果训练集中的样本数很多,则会使得生成的决策树过于庞大,即分化出了很多的枝节。这时会产生过拟合问题,

背景

在使用决策树模型时,如果训练集中的样本数很多,则会使得生成的决策树过于庞大,即分化出了很多的枝节。这时会产生过拟合问题,也就是在模型在训练集上的表现效果良好,而在测试集的效果却很差。因此在生成一棵决策树之后,需要对它进行必要的剪枝,从而提高它的泛化能力。本文将讲述后剪枝算法——REP方法。


原理

剪枝是指将决策树的一些枝节去掉,将中间节点变成叶子节点,该叶子节点的预测值便是该分组训练样本yyy值的均值。剪枝算法分为预剪枝和后剪枝,预剪枝是在决策树生成的过程中同步进行,而后剪枝是在决策树生成完之后再剪枝。

REP方法也称为错误率降低剪枝,它是一类最基础、最简单的后剪枝算法,是其他剪枝算法的基础。主要过程是将训练集分为两个集合N1N_1N1N2N_2N2,可以称为训练集中的训练集和训练集中的验证集。N1N_1N1用来生成决策树,N2N_2N2用来验证剪枝前后的模型效果。具体是先用N1N_1N1来生成决策树,然后自底向上遍历所有中间节点,对于每个中间节点,比较剪枝前后的两棵决策树在验证集N2N_2N2上的效果,这个效果体现在N2N_2N2通过两棵决策树得到的预测值与原始实际值的误差平方和,若剪枝后的误差平方和更小,则对决策树进行剪枝,反之则不进行剪枝。


例子

假设通过对数据集N1N_1N1进行训练,得到了如下的决策树

现在要自底向上进行剪枝,对象是中间节点,对应到图中依次是节点5、2、3。对于节点5,先将它的左右枝8和9剪掉,得到剪枝前和剪枝后的两棵树

将数据集N2N_2N2的特征xxx分别代入这两棵树,得到两组预测值,然后通过比较两组数据的误差平方和来决策是否进行剪枝。之后再考虑节点2,最后考虑节点3。


程序实现


重新定义树结构

为了方便处理不同节点间的调用,CART回归树的树模型不再用字典进行存储,而改用自定义的类对象(参考leetcode中的树节点),每个节点可以通过成员变量来调用分裂出来的左右节点。

class TreeNode:def __init__(self, val, fea_name=None, fea_c=None):self.left = Noneself.right = Noneself.val = round(val,2)self.fea_name = fea_name self.fea_c = fea_c if fea_c is None else round(fea_c,2)

变量valvalval是指当前节点数据集的yyy值的均值,若当前节点是叶子节点,则该变量代表这种分支的预测值,若是中间节点,则可以表示为对该节点进行剪枝后的节点预测值。


生成剪枝后的子树

def sub_tree(tree, num): # 返回后序剪枝得到的子树stack = [(False, tree)]while stack:flag, t = stack.pop()if not t:continueif flag:if t.left or t.right:if num==0:t.left = Nonet.right = Nonereturn treeelse:num -= 1else:stack.append((True, t))stack.append((False, t.right))stack.append((False, t.left))return tree

采用后序遍历的方式来搜索中间节点,参数numnumnum是为了控制对应序号的中间节点,因为并不是剪去每个中间节点都能提高性能,通过numnumnum可以避开不想剪去的中间节点。


计算中间节点的个数

def mid_leaf_num(tree): if not tree or (not tree.left and not tree.right):return 0return 1 + mid_leaf_num(tree.left) + mid_leaf_num(tree.right)

效果比较函数

def ifmore(self, temp_tree, test_x, test_y):orig_ = []temp_ = []for i in range(len(test_x)):orig_.append(self.check(self.tree, test_x[i]))temp_.append(self.check(temp_tree, test_x[i]))orig_sum = sum(np.power(np.array(orig_)-test_y, 2))temp_sum = sum(np.power(np.array(temp_)-test_y, 2))if orig_sum>temp_sum: # and (orig_sum-temp_sum)/orig_sum>0.0001:self.tree = temp_treereturn Trueelse:return False

REP剪枝函数

def prune_tree(self, test_x, test_y):mid_num &#61; mid_leaf_num(self.tree)i &#61; 0while i<mid_num: temp_tree &#61; sub_tree(self.tree, i)if self.ifmore(temp_tree, test_x, test_y):i &#61; 0mid_num -&#61; 1else:i &#43;&#61; 1

实例化演示

X,y &#61; make_regression(n_samples&#61;1000, n_features&#61;4, noise&#61;0.1)
X_name &#61; np.array(list(&#39;abcd&#39;))
clf &#61; Tree_Regress()
train_x, test_x, train_y, test_y &#61; train_test_split(X, y, test_size&#61;0.30)
train_size &#61; len(train_x)//4
train_test_x, train_test_y &#61; train_x[:train_size], train_y[:train_size]
train_train_x, train_train_y &#61; train_x[train_size:], train_y[train_size:]
print(&#39;不剪枝&#39;)
clf.fit(train_x, train_y, X_name)
predict_y &#61; clf.predict(test_x)
pre_error &#61; sum(np.power(test_y-predict_y,2))
print(&#39;误差为&#xff1a;&#39;, pre_error,&#39; 节点数&#xff1a;&#39;, clf.node_num())print(&#39;有剪枝&#39;)
clf.fit(train_train_x, train_train_y, X_name)
clf.prune_tree(train_test_x, train_test_y)
predict_y &#61; clf.predict(test_x)
pre_error &#61; sum(np.power(test_y-predict_y,2))
print(&#39;误差为&#xff1a;&#39;, pre_error,&#39; 节点数&#xff1a;&#39;, clf.node_num())# 不剪枝
# 误差为&#xff1a; 9860.41223121167 节点数&#xff1a; 249
# 有剪枝
# 误差为&#xff1a; 6892.04066950184 节点数&#xff1a; 33

在对模型进行后剪枝之后&#xff0c;模型的泛化能力有所提升。


不足

REP方法虽然在一定程度上简化了决策树&#xff0c;提高了模型的性能&#xff0c;但在有些情况下反而会造成相反的结果&#xff0c;使得模型表现更差。我试了几个不同的数据集&#xff0c;发现效果其实不是很好&#xff0c;可能是这个方法考虑到的东西比较少&#xff0c;它单单考虑到了模型拟合的误差平方和&#xff0c;却没考虑生成的节点个数&#xff0c;粗暴地将影响模型性能的枝节都剪去&#xff0c;使得模型太过简单。另外&#xff0c;该方法的计算开销是很大的&#xff0c;需要遍历搜索两次中间节点。


----end----

推荐阅读
  • Python瓦片图下载、合并、绘图、标记的代码示例
    本文提供了Python瓦片图下载、合并、绘图、标记的代码示例,包括下载代码、多线程下载、图像处理等功能。通过参考geoserver,使用PIL、cv2、numpy、gdal、osr等库实现了瓦片图的下载、合并、绘图和标记功能。代码示例详细介绍了各个功能的实现方法,供读者参考使用。 ... [详细]
  • 本文介绍了Python爬虫技术基础篇面向对象高级编程(中)中的多重继承概念。通过继承,子类可以扩展父类的功能。文章以动物类层次的设计为例,讨论了按照不同分类方式设计类层次的复杂性和多重继承的优势。最后给出了哺乳动物和鸟类的设计示例,以及能跑、能飞、宠物类和非宠物类的增加对类数量的影响。 ... [详细]
  • 本文介绍了在处理不规则数据时如何使用Python自动提取文本中的时间日期,包括使用dateutil.parser模块统一日期字符串格式和使用datefinder模块提取日期。同时,还介绍了一段使用正则表达式的代码,可以支持中文日期和一些特殊的时间识别,例如'2012年12月12日'、'3小时前'、'在2012/12/13哈哈'等。 ... [详细]
  • Python爬虫中使用正则表达式的方法和注意事项
    本文介绍了在Python爬虫中使用正则表达式的方法和注意事项。首先解释了爬虫的四个主要步骤,并强调了正则表达式在数据处理中的重要性。然后详细介绍了正则表达式的概念和用法,包括检索、替换和过滤文本的功能。同时提到了re模块是Python内置的用于处理正则表达式的模块,并给出了使用正则表达式时需要注意的特殊字符转义和原始字符串的用法。通过本文的学习,读者可以掌握在Python爬虫中使用正则表达式的技巧和方法。 ... [详细]
  • 欢乐的票圈重构之旅——RecyclerView的头尾布局增加
    项目重构的Git地址:https:github.comrazerdpFriendCircletreemain-dev项目同步更新的文集:http:www.jianshu.comno ... [详细]
  • 十大经典排序算法动图演示+Python实现
    本文介绍了十大经典排序算法的原理、演示和Python实现。排序算法分为内部排序和外部排序,常见的内部排序算法有插入排序、希尔排序、选择排序、冒泡排序、归并排序、快速排序、堆排序、基数排序等。文章还解释了时间复杂度和稳定性的概念,并提供了相关的名词解释。 ... [详细]
  • 超级简单加解密工具的方案和功能
    本文介绍了一个超级简单的加解密工具的方案和功能。该工具可以读取文件头,并根据特定长度进行加密,加密后将加密部分写入源文件。同时,该工具也支持解密操作。加密和解密过程是可逆的。本文还提到了一些相关的功能和使用方法,并给出了Python代码示例。 ... [详细]
  • 推荐系统遇上深度学习(十七)详解推荐系统中的常用评测指标
    原创:石晓文小小挖掘机2018-06-18笔者是一个痴迷于挖掘数据中的价值的学习人,希望在平日的工作学习中,挖掘数据的价值, ... [详细]
  • 本文讨论了Kotlin中扩展函数的一些惯用用法以及其合理性。作者认为在某些情况下,定义扩展函数没有意义,但官方的编码约定支持这种方式。文章还介绍了在类之外定义扩展函数的具体用法,并讨论了避免使用扩展函数的边缘情况。作者提出了对于扩展函数的合理性的质疑,并给出了自己的反驳。最后,文章强调了在编写Kotlin代码时可以自由地使用扩展函数的重要性。 ... [详细]
  • 第四章高阶函数(参数传递、高阶函数、lambda表达式)(python进阶)的讲解和应用
    本文主要讲解了第四章高阶函数(参数传递、高阶函数、lambda表达式)的相关知识,包括函数参数传递机制和赋值机制、引用传递的概念和应用、默认参数的定义和使用等内容。同时介绍了高阶函数和lambda表达式的概念,并给出了一些实例代码进行演示。对于想要进一步提升python编程能力的读者来说,本文将是一个不错的学习资料。 ... [详细]
  • 基于dlib的人脸68特征点提取(眨眼张嘴检测)python版本
    文章目录引言开发环境和库流程设计张嘴和闭眼的检测引言(1)利用Dlib官方训练好的模型“shape_predictor_68_face_landmarks.dat”进行68个点标定 ... [详细]
  • 本文整理了Java面试中常见的问题及相关概念的解析,包括HashMap中为什么重写equals还要重写hashcode、map的分类和常见情况、final关键字的用法、Synchronized和lock的区别、volatile的介绍、Syncronized锁的作用、构造函数和构造函数重载的概念、方法覆盖和方法重载的区别、反射获取和设置对象私有字段的值的方法、通过反射创建对象的方式以及内部类的详解。 ... [详细]
  • EPPlus绘制刻度线的方法及示例代码
    本文介绍了使用EPPlus绘制刻度线的方法,并提供了示例代码。通过ExcelPackage类和List对象,可以实现在Excel中绘制刻度线的功能。具体的方法和示例代码在文章中进行了详细的介绍和演示。 ... [详细]
  • 开源Keras Faster RCNN模型介绍及代码结构解析
    本文介绍了开源Keras Faster RCNN模型的环境需求和代码结构,包括FasterRCNN源码解析、RPN与classifier定义、data_generators.py文件的功能以及损失计算。同时提供了该模型的开源地址和安装所需的库。 ... [详细]
  • HashMap的相关问题及其底层数据结构和操作流程
    本文介绍了关于HashMap的相关问题,包括其底层数据结构、JDK1.7和JDK1.8的差异、红黑树的使用、扩容和树化的条件、退化为链表的情况、索引的计算方法、hashcode和hash()方法的作用、数组容量的选择、Put方法的流程以及并发问题下的操作。文章还提到了扩容死链和数据错乱的问题,并探讨了key的设计要求。对于对Java面试中的HashMap问题感兴趣的读者,本文将为您提供一些有用的技术和经验。 ... [详细]
author-avatar
ociVyouzhangzh063_1fd2bf_633
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有