热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

实用:sklearn提取决策树规则代码(附python代码)

《老饼讲解机器学习》http:ml.bbbdata.comteach#107目录一.问题二.主要思路三.代码实例1.数据提取2.预测函数3.准确性测试一.问题在决策树

《老饼讲解机器学习》http://ml.bbbdata.com/teach#107


目录

一.问题

二.主要思路

三.代码实例

1.数据提取

2.预测函数

3.准确性测试



一.问题

在决策树模型建好之后,要提取规则布署到生产。

二.主要思路

只提取数据,在生产环境写出通用预测代码。新的模型只需替换数据即可。

备注:一般不弄成一系列的if else,写死代码不便于更换模型。

三.代码实例

1.数据提取

使用如下get_tree函数,将树数据提取成字典:

from sklearn import tree
import numpy as np
def get_tree(sk_tree):#--------------拷贝sklearn树模型关键信息--------------------children_left = sk_tree.tree_.children_left.copy() # 左节点编号children_right = sk_tree.tree_.children_right.copy() # 右节点编号feature = sk_tree.tree_.feature.copy() # 分割的变量threshold = sk_tree.tree_.threshold.copy() # 分割阈值impurity = sk_tree.tree_.impurity.copy() # 不纯度(gini)n_node_samples = sk_tree.tree_.n_node_samples.copy() # 样本个数value = sk_tree.tree_.value.copy() # 样本分布n_sample = value[0].sum() # 总样本个数node_num = len(children_left) # 节点个数depth = sk_tree.get_depth()# ------------补充节点父节点信息---------------------------parent = np.zeros(node_num).astype(int)parent[0] = -1branch_idx = np.where(children_left!=-1)[0]for i in branch_idx:parent[children_left[i]] = i parent[children_right[i]]= i #-------------存成字典----------------------------------------- tree = {'children_left':children_left,'children_right':children_right,'feature':feature,'threshold':threshold,'impurity':impurity,'n_node_samples':n_node_samples,'value':value,'depth':depth,'n_sample':n_sample,'node_num':node_num,'parent':parent}return tree

将训练好的模型sk_tree传入以上函数,转化成字典,保存成文件。

2.预测函数

在生产时使用如下tree_predict 函数预测(其它语言类似以下逻辑)。

import numpy as np
def tree_predict(tree,x):node_idx = 0t = 0while(t

3.准确性测试

from sklearn.datasets import load_iris
from sklearn import tree
import numpy as np
from get_tree import get_tree
from tree_pred import tree_predict#----------------数据准备----------------------------
iris = load_iris() # 加载数据
X = iris.data
y = iris.target
#---------------模型训练----------------------------------
clf = tree.DecisionTreeClassifier() # sk-learn的决策树模型
clf = clf.fit(X, y) # 用数据训练树模型构建()
#--------------将树提取成简单的字典--------------------------------
tree = get_tree(clf)
#-------------------------
#将tree持久化到服务器,服务器中用tree_predict进行预测即可
#-------------------------#------------测试函数的准确性-----------------------------
self_pred_y = np.zeros(len(y))
self_pred_prob = np.zeros((len(y),len(tree['value'][0][0])))
for i in range(X.shape[0]):pred_class,pred_prob = tree_predict(tree,X[i])self_pred_y[i] = pred_classself_pred_prob[i] = pred_prob
pred_y = clf.predict(X)
pred_prob = clf.predict_proba(X)
print("与sklearn预测结果差异个数(类别):",np.sum(pred_y != self_pred_y))
print("与sklearn预测结果差异个数(概率):",np.sum(pred_prob != self_pred_prob))

 测试结果:

与sklearn预测结果差异个数(类别): 0
与sklearn预测结果差异个数(概率): 0

相关文章

《深入浅出:决策树入门简介》

《一个简单的决策树分类例子》

《sklearn决策树结果可视化》

《sklearn决策树参数详解》


推荐阅读
  • 欢乐的票圈重构之旅——RecyclerView的头尾布局增加
    项目重构的Git地址:https:github.comrazerdpFriendCircletreemain-dev项目同步更新的文集:http:www.jianshu.comno ... [详细]
  • 如何使用Java获取服务器硬件信息和磁盘负载率
    本文介绍了使用Java编程语言获取服务器硬件信息和磁盘负载率的方法。首先在远程服务器上搭建一个支持服务端语言的HTTP服务,并获取服务器的磁盘信息,并将结果输出。然后在本地使用JS编写一个AJAX脚本,远程请求服务端的程序,得到结果并展示给用户。其中还介绍了如何提取硬盘序列号的方法。 ... [详细]
  • http:my.oschina.netleejun2005blog136820刚看到群里又有同学在说HTTP协议下的Get请求参数长度是有大小限制的,最大不能超过XX ... [详细]
  • 在重复造轮子的情况下用ProxyServlet反向代理来减少工作量
    像不少公司内部不同团队都会自己研发自己工具产品,当各个产品逐渐成熟,到达了一定的发展瓶颈,同时每个产品都有着自己的入口,用户 ... [详细]
  • Python瓦片图下载、合并、绘图、标记的代码示例
    本文提供了Python瓦片图下载、合并、绘图、标记的代码示例,包括下载代码、多线程下载、图像处理等功能。通过参考geoserver,使用PIL、cv2、numpy、gdal、osr等库实现了瓦片图的下载、合并、绘图和标记功能。代码示例详细介绍了各个功能的实现方法,供读者参考使用。 ... [详细]
  • 怎么在PHP项目中实现一个HTTP断点续传功能发布时间:2021-01-1916:26:06来源:亿速云阅读:96作者:Le ... [详细]
  • 本文介绍了Python爬虫技术基础篇面向对象高级编程(中)中的多重继承概念。通过继承,子类可以扩展父类的功能。文章以动物类层次的设计为例,讨论了按照不同分类方式设计类层次的复杂性和多重继承的优势。最后给出了哺乳动物和鸟类的设计示例,以及能跑、能飞、宠物类和非宠物类的增加对类数量的影响。 ... [详细]
  • Whatsthedifferencebetweento_aandto_ary?to_a和to_ary有什么区别? ... [详细]
  • 本文介绍了在rhel5.5操作系统下搭建网关+LAMP+postfix+dhcp的步骤和配置方法。通过配置dhcp自动分配ip、实现外网访问公司网站、内网收发邮件、内网上网以及SNAT转换等功能。详细介绍了安装dhcp和配置相关文件的步骤,并提供了相关的命令和配置示例。 ... [详细]
  • 本文主要解析了Open judge C16H问题中涉及到的Magical Balls的快速幂和逆元算法,并给出了问题的解析和解决方法。详细介绍了问题的背景和规则,并给出了相应的算法解析和实现步骤。通过本文的解析,读者可以更好地理解和解决Open judge C16H问题中的Magical Balls部分。 ... [详细]
  • [译]技术公司十年经验的职场生涯回顾
    本文是一位在技术公司工作十年的职场人士对自己职业生涯的总结回顾。她的职业规划与众不同,令人深思又有趣。其中涉及到的内容有机器学习、创新创业以及引用了女性主义者在TED演讲中的部分讲义。文章表达了对职业生涯的愿望和希望,认为人类有能力不断改善自己。 ... [详细]
  • 知识图谱——机器大脑中的知识库
    本文介绍了知识图谱在机器大脑中的应用,以及搜索引擎在知识图谱方面的发展。以谷歌知识图谱为例,说明了知识图谱的智能化特点。通过搜索引擎用户可以获取更加智能化的答案,如搜索关键词"Marie Curie",会得到居里夫人的详细信息以及与之相关的历史人物。知识图谱的出现引起了搜索引擎行业的变革,不仅美国的微软必应,中国的百度、搜狗等搜索引擎公司也纷纷推出了自己的知识图谱。 ... [详细]
  • 本文介绍了作者在开发过程中遇到的问题,即播放框架内容安全策略设置不起作用的错误。作者通过使用编译时依赖注入的方式解决了这个问题,并分享了解决方案。文章详细描述了问题的出现情况、错误输出内容以及解决方案的具体步骤。如果你也遇到了类似的问题,本文可能对你有一定的参考价值。 ... [详细]
  • 自动轮播,反转播放的ViewPagerAdapter的使用方法和效果展示
    本文介绍了如何使用自动轮播、反转播放的ViewPagerAdapter,并展示了其效果。该ViewPagerAdapter支持无限循环、触摸暂停、切换缩放等功能。同时提供了使用GIF.gif的示例和github地址。通过LoopFragmentPagerAdapter类的getActualCount、getActualItem和getActualPagerTitle方法可以实现自定义的循环效果和标题展示。 ... [详细]
  • 预备知识可参考我整理的博客Windows编程之线程:https:www.cnblogs.comZhuSenlinp16662075.htmlWindows编程之线程同步:https ... [详细]
author-avatar
丢失的面包树
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有