热门标签 | HotTags
当前位置:  开发笔记 > 运维 > 正文

随机森林和GBDT的学习

前言提到森林,就不得不联想到树,因为正是一棵棵的树构成了庞大的森林,而在本篇文章中的树,指的就是DecisionTree-----决策树。随机森林就是一棵棵决策树的组合,也就是说随机森林boosting+决策树,这样就好理解多了吧,再来说说GBDT,GBDT全称是Gradie

前言 提到森林,就不得不联想到树,因为正是一棵棵的树构成了庞大的森林,而在本篇文章中的树,指的就是Decision Tree-----决策树。随机森林就是一棵棵决策树的组合,也就是说随机森林=boosting+决策树,这样就好理解多了吧,再来说说GBDT,GBDT全称是Gradie

前言

提到森林,就不得不联想到树,因为正是一棵棵的树构成了庞大的森林,而在本篇文章中的”树“,指的就是Decision Tree-----决策树。随机森林就是一棵棵决策树的组合,也就是说随机森林=boosting+决策树,这样就好理解多了吧,再来说说GBDT,GBDT全称是Gradient Boosting Decision Tree,就是梯度提升决策树,与随机森林的思想很像,但是比随机森林稍稍的难一点,当然效果相对于前者而言,也会好许多。由于本人才疏学浅,本文只会详细讲述Random Forest算法的部分,至于GBDT我会给出一小段篇幅做介绍引导,读者能够如果有兴趣的话,可以自行学习。

随机森林算法

决策树

要想理解随机森林算法,就不得不提决策树,什么是决策树,如何构造决策树,简单的回答就是数据的分类以树形结构的方式所展现,每个子分支都代表着不同的分类情况,比如下面的这个图所示:

\

当然决策树的每个节点分支不一定是三元的,可以有2个或者更多。分类的终止条件为,没有可以再拿来分类的属性条件或者说分到的数据的分类已经完全一致的情况。决策树分类的标准和依据是什么呢,下面介绍主要的2种划分标准。

1、信息增益。这是ID3算法系列所用的方法,C4.5算法在这上面做了少许的改进,用信息增益率来作为划分的标准,可以稍稍减小数据过于拟合的缺点。

2、基尼指数。这是CART分类回归树所用的方法。也是类似于信息增益的一个定义,最终都是根据数据划分后的纯度来做比较,这个纯度,你也可以理解为熵的变化,当然我们所希望的情况就是分类后数据的纯度更纯,也就是说,前后划分分类之后的熵的差越大越好。不过CART算法比较好的一点是树构造好后,还有剪枝的操作,剪枝操作的种类就比较多了,我之前在实现CART算法时用的是代价复杂度的剪枝方法。

这2种决策算法在我之前的博文中已经有所提及,不理解的可以点击我的ID3系列算法介绍和我的CART分类回归树算法

Boosting

原本不打算将Boosting单独拉出来讲的,后来想想还是有很多内容可谈的。Boosting本身不是一种算法,他更应该说是一种思想,首先对数据构造n个弱分类器,最后通过组合n个弱分类器对于某个数据的判断结果作为最终的分类结果,就变成了一个强分类器,效果自然要好过单一分类器的分类效果。他可以理解为是一种提升算法,举一个比较常见的Boosting思想的算法AdaBoost,他在训练每个弱分类器的时候,提高了对于之前分错数据的权重值,最终能够组成一批相互互补的分类器集合。详细可以查看我的AdaBoost算法学习

OK,2个重要的概念都已经介绍完毕,终于可以介绍主角Random Forest的出现了,正如前言中所说Random Forest=Decision Trees + Boosting,这里的每个弱分类器就是一个决策树了,不过这里的决策树都是二叉树,就是只有2个孩子分支,自然我立刻想到的做法就是用CART算法来构建,因为人家算法就是二元分支的。随机算法,随机算法,当然重在随机2个字上面,下面是2个方面体现了随机性。对于数据样本的采集量,比如我数据由100条,我可以每次随机取出其中的20条,作为我构造决策树的源数据,采取又放回的方式,并不是第一次抽到的数据,第二次不能重复,第二随机性体现在对于数据属性的随机采集,比如一行数据总共有10个特征属性,我每次随机采用其中的4个。正是由于对于数据的行压缩和列压缩,使得数据的随机性得以保证,就很难出现之前的数据过拟合的问题了,也就不需要在决策树最后进行剪枝操作了,这个是与一般的CART算法所不同的,尤其需要注意。

下面是随机森林算法的构造过程:

1、通过给定的原始数据,选出其中部分数据进行决策树的构造,数据选取是”有放回“的过程,我在这里用的是CART分类回归树。

2、随机森林构造完成之后,给定一组测试数据,使得每个分类器对其结果分类进行评估,最后取评估结果的众数最为最终结果。

算法非常的好理解,在Boosting算法和决策树之上做了一个集成,下面给出算法的实现,很多资料上只有大篇幅的理论,我还是希望能带给大家一点实在的东西。

随机算法的实现

输入数据(之前决策树算法时用过的)input.txt:

Rid Age Income Student CreditRating BuysComputer
1 Youth High No Fair No
2 Youth High No Excellent No
3 MiddleAged High No Fair Yes
4 Senior Medium No Fair Yes
5 Senior Low Yes Fair Yes
6 Senior Low Yes Excellent No
7 MiddleAged Low Yes Excellent Yes
8 Youth Medium No Fair No
9 Youth Low Yes Fair Yes
10 Senior Medium Yes Fair Yes
11 Youth Medium Yes Excellent Yes
12 MiddleAged Medium No Excellent Yes
13 MiddleAged High Yes Fair Yes
14 Senior Medium No Excellent No

树节点类TreeNode.java:

package DataMining_RandomForest;

import java.util.ArrayList;

/**
 * 回归分类树节点
 * 
 * @author lyq
 * 
 */
public class TreeNode {
	// 节点属性名字
	private String attrName;
	// 节点索引标号
	private int nodeIndex;
	//包含的叶子节点数
	private int leafNum;
	// 节点误差率
	private double alpha;
	// 父亲分类属性值
	private String parentAttrValue;
	// 孩子节点
	private TreeNode[] childAttrNode;
	// 数据记录索引
	private ArrayList dataIndex;

	public String getAttrName() {
		return attrName;
	}

	public void setAttrName(String attrName) {
		this.attrName = attrName;
	}

	public int getNodeIndex() {
		return nodeIndex;
	}

	public void setNodeIndex(int nodeIndex) {
		this.nodeIndex = nodeIndex;
	}

	public double getAlpha() {
		return alpha;
	}

	public void setAlpha(double alpha) {
		this.alpha = alpha;
	}

	public String getParentAttrValue() {
		return parentAttrValue;
	}

	public void setParentAttrValue(String parentAttrValue) {
		this.parentAttrValue = parentAttrValue;
	}

	public TreeNode[] getChildAttrNode() {
		return childAttrNode;
	}

	public void setChildAttrNode(TreeNode[] childAttrNode) {
		this.childAttrNode = childAttrNode;
	}

	public ArrayList getDataIndex() {
		return dataIndex;
	}

	public void setDataIndex(ArrayList dataIndex) {
		this.dataIndex = dataIndex;
	}

	public int getLeafNum() {
		return leafNum;
	}

	public void setLeafNum(int leafNum) {
		this.leafNum = leafNum;
	}
	
	
	
}
决策树类DecisionTree.java:

package DataMining_RandomForest;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

/**
 * 决策树
 * 
 * @author lyq
 * 
 */
public class DecisionTree {
	// 树的根节点
	TreeNode rootNode;
	// 数据的属性列名称
	String[] featureNames;
	// 这棵树所包含的数据
	ArrayList datas;
	// 决策树构造的的工具类
	CARTTool tool;

	public DecisionTree(ArrayList datas) {
		this.datas = datas;
		this.featureNames = datas.get(0);

		tool = new CARTTool(datas);
		// 通过CART工具类进行决策树的构建,并返回树的根节点
		rootNode = tool.startBuildingTree();
	}

	/**
	 * 根据给定的数据特征描述进行类别的判断
	 * 
	 * @param features
	 * @return
	 */
	public String decideClassType(String features) {
		String classType = "";
		// 查询属性组
		String[] queryFeatures;
		// 在本决策树中对应的查询的属性值描述
		ArrayList featureStrs;

		featureStrs = new ArrayList<>();
		queryFeatures = features.split(",");

		String[] array;
		for (String name : featureNames) {
			for (String featureValue : queryFeatures) {
				array = featureValue.split("=");
				// 将对应的属性值加入到列表中
				if (array[0].equals(name)) {
					featureStrs.add(array);
				}
			}
		}

		// 开始从根据节点往下递归搜索
		classType = recusiveSearchClassType(rootNode, featureStrs);

		return classType;
	}

	/**
	 * 递归搜索树,查询属性的分类类别
	 * 
	 * @param node
	 *            当前搜索到的节点
	 * @param remainFeatures
	 *            剩余未判断的属性
	 * @return
	 */
	private String recusiveSearchClassType(TreeNode node,
			ArrayList remainFeatures) {
		String classType = null;

		// 如果节点包含了数据的id索引,说明已经分类到底了
		if (node.getDataIndex() != null && node.getDataIndex().size() > 0) {
			classType = judgeClassType(node.getDataIndex());

			return classType;
		}

		// 取出剩余属性中的一个匹配属性作为当前的判断属性名称
		String[] currentFeature = null;
		for (String[] featureValue : remainFeatures) {
			if (node.getAttrName().equals(featureValue[0])) {
				currentFeature = featureValue;
				break;
			}
		}

		for (TreeNode childNode : node.getChildAttrNode()) {
			// 寻找子节点中属于此属性值的分支
			if (childNode.getParentAttrValue().equals(currentFeature[1])) {
				remainFeatures.remove(currentFeature);
				classType = recusiveSearchClassType(childNode, remainFeatures);

				// 如果找到了分类结果,则直接挑出循环
				break;
			}else{
				//进行第二种情况的判断加上!符号的情况
				String value = childNode.getParentAttrValue();
				
				if(value.charAt(0) == &#39;!&#39;){
					//去掉第一个!字符
					value = value.substring(1, value.length());
					
					if(!value.equals(currentFeature[1])){
						remainFeatures.remove(currentFeature);
						classType = recusiveSearchClassType(childNode, remainFeatures);

						break;
					}
				}
			}
		}

		return classType;
	}

	/**
	 * 根据得到的数据行分类进行类别的决策
	 * 
	 * @param dataIndex
	 *            根据分类的数据索引号
	 * @return
	 */
	public String judgeClassType(ArrayList dataIndex) {
		// 结果类型值
		String resultClassType = "";
		String classType = "";
		int count = 0;
		int temp = 0;
		Map type2Num = new HashMap();

		for (String index : dataIndex) {
			temp = Integer.parseInt(index);
			// 取最后一列的决策类别数据
			classType = datas.get(temp)[featureNames.length - 1];

			if (type2Num.containsKey(classType)) {
				// 如果类别已经存在,则使其计数加1
				count = type2Num.get(classType);
				count++;
			} else {
				count = 1;
			}

			type2Num.put(classType, count);
		}

		// 选出其中类别支持计数最多的一个类别值
		count = -1;
		for (Map.Entry entry : type2Num.entrySet()) {
			if ((int) entry.getValue() > count) {
				count = (int) entry.getValue();
				resultClassType = (String) entry.getKey();
			}
		}

		return resultClassType;
	}
}
随机森林算法工具类RandomForestTool.java:

package DataMining_RandomForest;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;

/**
 * 随机森林算法工具类
 * 
 * @author lyq
 * 
 */
public class RandomForestTool {
	// 测试数据文件地址
	private String filePath;
	// 决策树的样本占总数的占比率
	private double sampleNumRatio;
	// 样本数据的采集特征数量占总特征的比例
	private double featureNumRatio;
	// 决策树的采样样本数
	private int sampleNum;
	// 样本数据的采集采样特征数
	private int featureNum;
	// 随机森林中的决策树的数目,等于总的数据数/用于构造每棵树的数据的数量
	private int treeNum;
	// 随机数产生器
	private Random random;
	// 样本数据列属性名称行
	private String[] featureNames;
	// 原始的总的数据
	private ArrayList totalDatas;
	// 决策树森林
	private ArrayList decisionForest;

	public RandomForestTool(String filePath, double sampleNumRatio,
			double featureNumRatio) {
		this.filePath = filePath;
		this.sampleNumRatio = sampleNumRatio;
		this.featureNumRatio = featureNumRatio;

		readDataFile();
	}

	/**
	 * 从文件中读取数据
	 */
	private void readDataFile() {
		File file = new File(filePath);
		ArrayList dataArray = new ArrayList();

		try {
			BufferedReader in = new BufferedReader(new FileReader(file));
			String str;
			String[] tempArray;
			while ((str = in.readLine()) != null) {
				tempArray = str.split(" ");
				dataArray.add(tempArray);
			}
			in.close();
		} catch (IOException e) {
			e.getStackTrace();
		}

		totalDatas = dataArray;
		featureNames = totalDatas.get(0);
		sampleNum = (int) ((totalDatas.size() - 1) * sampleNumRatio);
		//算属性数量的时候需要去掉id属性和决策属性,用条件属性计算
		featureNum = (int) ((featureNames.length -2) * featureNumRatio);
		// 算数量的时候需要去掉首行属性名称行
		treeNum = (totalDatas.size() - 1) / sampleNum;
	}

	/**
	 * 产生决策树
	 */
	private DecisionTree produceDecisionTree() {
		int temp = 0;
		DecisionTree tree;
		String[] tempData;
		//采样数据的随机行号组
		ArrayList sampleRandomNum;
		//采样属性特征的随机列号组
		ArrayList featureRandomNum;
		ArrayList datas;
		
		sampleRandomNum = new ArrayList<>();
		featureRandomNum = new ArrayList<>();
		datas = new ArrayList<>();
		
		for(int i=0; i 0){
				array[0] = temp + "";
			}
			
			temp++;
		}
		
		tree = new DecisionTree(datas);
		
		return tree;
	}

	/**
	 * 构造随机森林
	 */
	public void constructRandomTree() {
		DecisionTree tree;
		random = new Random();
		decisiOnForest= new ArrayList<>();

		System.out.println("下面是随机森林中的决策树:");
		// 构造决策树加入森林中
		for (int i = 0; i  type2Num = new HashMap();

		for (DecisionTree tree : decisionForest) {
			classType = tree.decideClassType(features);
			if (type2Num.containsKey(classType)) {
				// 如果类别已经存在,则使其计数加1
				count = type2Num.get(classType);
				count++;
			} else {
				count = 1;
			}

			type2Num.put(classType, count);
		}

		// 选出其中类别支持计数最多的一个类别值
		count = -1;
		for (Map.Entry entry : type2Num.entrySet()) {
			if ((int) entry.getValue() > count) {
				count = (int) entry.getValue();
				resultClassType = (String) entry.getKey();
			}
		}

		return resultClassType;
	}
}
CART算法工具类CARTTool.java:

package DataMining_RandomForest;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.Queue;

/**
 * CART分类回归树算法工具类
 * 
 * @author lyq
 * 
 */
public class CARTTool {
	// 类标号的值类型
	private final String YES = "Yes";
	private final String NO = "No";

	// 所有属性的类型总数,在这里就是data源数据的列数
	private int attrNum;
	private String filePath;
	// 初始源数据,用一个二维字符数组存放模仿表格数据
	private String[][] data;
	// 数据的属性行的名字
	private String[] attrNames;
	// 每个属性的值所有类型
	private HashMap> attrValue;

	public CARTTool(ArrayList dataArray) {
		attrValue = new HashMap<>();
		readData(dataArray);
	}

	/**
	 * 根据随机选取的样本数据进行初始化
	 * @param dataArray
	 * 已经读入的样本数据
	 */
	public void readData(ArrayList dataArray) {
		data = new String[dataArray.size()][];
		dataArray.toArray(data);
		attrNum = data[0].length;
		attrNames = data[0];
	}

	/**
	 * 首先初始化每种属性的值的所有类型,用于后面的子类熵的计算时用
	 */
	public void initAttrValue() {
		ArrayList tempValues;

		// 按照列的方式,从左往右找
		for (int j = 1; j ();
			for (int i = 1; i  valueTypes = attrValue.get(attrName);
		// 属于此属性值的实例数
		HashMap belOngNum= new HashMap<>();

		for (String string : valueTypes) {
			// 重新计数的时候,数字归0
			tempNum = 0;
			// 按列从左往右遍历属性
			for (int j = 1; j  remainAttr,
			boolean beLongParentValue) {
		// 属性划分值
		String valueType = "";
		// 划分属性名称
		String spiltAttrName = "";
		double minGini = Integer.MAX_VALUE;
		double tempGini = 0;
		// 基尼指数数组,保存了基尼指数和此基尼指数的划分属性值
		String[] giniArray;

		if (beLongParentValue) {
			node.setParentAttrValue(parentAttrValue);
		} else {
			node.setParentAttrValue("!" + parentAttrValue);
		}

		if (remainAttr.size() == 0) {
			if (remainData.length > 1) {
				ArrayList indexArray = new ArrayList<>();
				for (int i = 1; i  indexArray = new ArrayList<>();
			for (int k = 1; k  rAttr = new ArrayList<>();
				for (String str : remainAttr) {
					rAttr.add(str);
				}
				buildDecisionTree(childNode[i], valueType, rData, rAttr,
						bArray[i]);
			} else {
				String pAtr = (bArray[i] ? valueType : "!" + valueType);
				childNode[i].setParentAttrValue(pAtr);
				childNode[i].setDataIndex(indexArray);
			}
		}

		node.setChildAttrNode(childNode);
	}

	/**
	 * 属性划分完毕,进行数据的移除
	 * 
	 * @param srcData
	 *            源数据
	 * @param attrName
	 *            划分的属性名称
	 * @param valueType
	 *            属性的值类型
	 * @parame beLongValue 分类是否属于此值类型
	 */
	private String[][] removeData(String[][] srcData, String attrName,
			String valueType, boolean beLongValue) {
		String[][] desDataArray;
		ArrayList desData = new ArrayList<>();
		// 待删除数据
		ArrayList selectData = new ArrayList<>();
		selectData.add(attrNames);

		// 数组数据转化到列表中,方便移除
		for (int i = 0; i  remainAttr = new ArrayList<>();
		// 添加属性,除了最后一个类标号属性
		for (int i = 1; i  0) {
			System.out.print(node.getParentAttrValue());
		} else {
			System.out.print("--");
		}
		System.out.print("--");

		if (node.getDataIndex() != null && node.getDataIndex().size() > 0) {
			String i = node.getDataIndex().get(0);
			System.out.print("【" + node.getNodeIndex() + "】类别:"
					+ data[Integer.parseInt(i)][attrNames.length - 1]);
			System.out.print("[");
			for (String index : node.getDataIndex()) {
				System.out.print(index + ", ");
			}
			System.out.print("]");
		} else {
			// 递归显示子节点
			System.out.print("【" + node.getNodeIndex() + ":"
					+ node.getAttrName() + "】");
			if (node.getChildAttrNode() != null) {
				for (TreeNode childNode : node.getChildAttrNode()) {
					showDecisionTree(childNode, 2 * blankNum);
				}
			} else {
				System.out.print("【  Child Null】");
			}
		}
	}

	/**
	 * 为节点设置序列号,并计算每个节点的误差率,用于后面剪枝
	 * 
	 * @param node
	 *            开始的时候传入的是根节点
	 * @param index
	 *            开始的索引号,从1开始
	 * @param ifCutNode
	 *            是否需要剪枝
	 */
	private void setIndexAndAlpah(TreeNode node, int index, boolean ifCutNode) {
		TreeNode tempNode;
		// 最小误差代价节点,即将被剪枝的节点
		TreeNode minAlphaNode = null;
		double minAlpah = Integer.MAX_VALUE;
		Queue nodeQueue = new LinkedList();

		nodeQueue.add(node);
		while (nodeQueue.size() > 0) {
			index++;
			// 从队列头部获取首个节点
			tempNode = nodeQueue.poll();
			tempNode.setNodeIndex(index);
			if (tempNode.getChildAttrNode() != null) {
				for (TreeNode childNode : tempNode.getChildAttrNode()) {
					nodeQueue.add(childNode);
				}
				computeAlpha(tempNode);
				if (tempNode.getAlpha()  minAlphaNode.getLeafNum()) {
						minAlphaNode = tempNode;
					}
				}
			}
		}

		if (ifCutNode) {
			// 进行树的剪枝,让其左右孩子节点为null
			minAlphaNode.setChildAttrNode(null);
		}
	}

	/**
	 * 为非叶子节点计算误差代价,这里的后剪枝法用的是CCP代价复杂度剪枝
	 * 
	 * @param node
	 *            待计算的非叶子节点
	 */
	private void computeAlpha(TreeNode node) {
		double rt = 0;
		double Rt = 0;
		double alpha = 0;
		// 当前节点的数据总数
		int sumNum = 0;
		// 最少的偏差数
		int minNum = 0;

		ArrayList dataIndex;
		ArrayList leafNodes = new ArrayList<>();

		addLeafNode(node, leafNodes);
		node.setLeafNum(leafNodes.size());
		for (TreeNode attrNode : leafNodes) {
			dataIndex = attrNode.getDataIndex();

			int num = 0;
			sumNum += dataIndex.size();
			for (String s : dataIndex) {
				// 统计分类数据中的正负实例数
				if (data[Integer.parseInt(s)][attrNames.length - 1].equals(YES)) {
					num++;
				}
			}
			minNum += num;

			// 取小数量的值部分
			if (1.0 * num / dataIndex.size() > 0.5) {
				num = dataIndex.size() - num;
			}

			rt += (1.0 * num / (data.length - 1));
		}
		
		//同样取出少偏差的那部分
		if (1.0 * minNum / sumNum > 0.5) {
			minNum = sumNum - minNum;
		}

		Rt = 1.0 * minNum / (data.length - 1);
		alpha = 1.0 * (Rt - rt) / (leafNodes.size() - 1);
		node.setAlpha(alpha);
	}

	/**
	 * 筛选出节点所包含的叶子节点数
	 * 
	 * @param node
	 *            待筛选节点
	 * @param leafNode
	 *            叶子节点列表容器
	 */
	private void addLeafNode(TreeNode node, ArrayList leafNode) {
		ArrayList dataIndex;

		if (node.getChildAttrNode() != null) {
			for (TreeNode childNode : node.getChildAttrNode()) {
				dataIndex = childNode.getDataIndex();
				if (dataIndex != null && dataIndex.size() > 0) {
					// 说明此节点为叶子节点
					leafNode.add(childNode);
				} else {
					// 如果还是非叶子节点则继续递归调用
					addLeafNode(childNode, leafNode);
				}
			}
		}
	}

}
测试类Client.java:
package DataMining_RandomForest;

import java.text.MessageFormat;

/**
 * 随机森林算法测试场景
 * 
 * @author lyq
 * 
 */
public class Client {
	public static void main(String[] args) {
		String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt";
		String queryStr = "Age=Youth,Income=Low,Student=No,CreditRating=Fair";
		String resultClassType = "";
		// 决策树的样本占总数的占比率
		double sampleNumRatio = 0.4;
		// 样本数据的采集特征数量占总特征的比例
		double featureNumRatio = 0.5;

		RandomForestTool tool = new RandomForestTool(filePath, sampleNumRatio,
				featureNumRatio);
		tool.constructRandomTree();

		resultClassType = tool.judgeClassType(queryStr);

		System.out.println();
		System.out
				.println(MessageFormat.format(
						"查询属性描述{0},预测的分类结果为BuysCompute:{1}", queryStr,
						resultClassType));
	}
}

算法的输出

下面是随机森林中的决策树:

决策树1

    --!--【1:Income】
        --Medium--【2】类别:Yes[1, 2, ]
        --!Medium--【3:Student】
                --No--【4】类别:No[3, 5, ]
                --!No--【5】类别:Yes[4, ]
决策树2

    --!--【1:Student】
        --No--【2】类别:No[1, 3, ]
        --!No--【3】类别:Yes[2, 4, 5, ]
查询属性描述Age=Youth,Income=Low,Student=No,CreditRating=Fair,预测的分类结果为BuysCompute:No

输出的结果决策树建议从左往右看,从上往下,【】符号表示一个节点,---XX---表示属性值的划分,你就应该能看懂这棵树了,在console上想展示漂亮的树形效果的确很难。。。这里说一个算法的重大不足,数据太少,导致选择的样本数据不足,所选属性太少,,构造的决策树数量过少,自然分类的准确率不见得会有多准,博友只要能领会代码中所表达的算法的思想即可。

GBDT

下面来说说随机森林的兄弟算法GBDT,梯度提升决策树,他有很多的决策树,他也有组合的思想,但是他不是随机森林算法2,GBDT的关键在于Gradient Boosting,梯度提升。这个词语理解起来就不容易了。学术的描述,每一次建立模型是在之前建立模型的损失函数的梯度下降方向。GBDT的核心在于,每一棵树学的是之前所有树结论和的残差,这个残差你可以理解为与预测值的差值。举个例子:比如预测张三的年龄,张三的真实年龄18岁,第一棵树预测张的年龄12岁,此时残差为18-12=6岁,因此在第二棵树中,我们把张的年龄作为6岁去学习,如果预测成功了,则张的真实年龄就是A树和B树的结果预测值的和,但是如果B预测成了5岁,那么残差就变成了6-5=1岁,那么此时需要构建第三树对1岁做预测,后面一样的道理。每棵树都是对之前失败预测的一个补充,用公式的表达就是如下的这个样子:

\

F0在这里是初始值,Ti是一棵棵的决策树,不同的问题选择不同的损失函数和初始值。在阿里内部对于此算法的叫法为TreeLink。所以下次听到什么Treelink算法了指的就是梯度提升树算法,其实我在这里省略了很大篇幅的数学推导过程,再加上自己还不是专家,无法彻底解释清数学的部分,所以就没有提及,希望以后有时间可以深入学习此方面的知识。


推荐阅读
  • 云原生边缘计算之KubeEdge简介及功能特点
    本文介绍了云原生边缘计算中的KubeEdge系统,该系统是一个开源系统,用于将容器化应用程序编排功能扩展到Edge的主机。它基于Kubernetes构建,并为网络应用程序提供基础架构支持。同时,KubeEdge具有离线模式、基于Kubernetes的节点、群集、应用程序和设备管理、资源优化等特点。此外,KubeEdge还支持跨平台工作,在私有、公共和混合云中都可以运行。同时,KubeEdge还提供数据管理和数据分析管道引擎的支持。最后,本文还介绍了KubeEdge系统生成证书的方法。 ... [详细]
  • 本文介绍了Java的集合及其实现类,包括数据结构、抽象类和具体实现类的关系,详细介绍了List接口及其实现类ArrayList的基本操作和特点。文章通过提供相关参考文档和链接,帮助读者更好地理解和使用Java的集合类。 ... [详细]
  • 阿里Treebased Deep Match(TDM) 学习笔记及技术发展回顾
    本文介绍了阿里Treebased Deep Match(TDM)的学习笔记,同时回顾了工业界技术发展的几代演进。从基于统计的启发式规则方法到基于内积模型的向量检索方法,再到引入复杂深度学习模型的下一代匹配技术。文章详细解释了基于统计的启发式规则方法和基于内积模型的向量检索方法的原理和应用,并介绍了TDM的背景和优势。最后,文章提到了向量距离和基于向量聚类的索引结构对于加速匹配效率的作用。本文对于理解TDM的学习过程和了解匹配技术的发展具有重要意义。 ... [详细]
  • 标题: ... [详细]
  • 单点登录原理及实现方案详解
    本文详细介绍了单点登录的原理及实现方案,其中包括共享Session的方式,以及基于Redis的Session共享方案。同时,还分享了作者在应用环境中所遇到的问题和经验,希望对读者有所帮助。 ... [详细]
  • 本文介绍了在Docker容器技术中限制容器对CPU的使用的方法,包括使用-c参数设置容器的内存限额,以及通过设置工作线程数量来充分利用CPU资源。同时,还介绍了容器权重分配的情况,以及如何通过top命令查看容器在CPU资源紧张情况下的使用情况。 ... [详细]
  • 集合的遍历方式及其局限性
    本文介绍了Java中集合的遍历方式,重点介绍了for-each语句的用法和优势。同时指出了for-each语句无法引用数组或集合的索引的局限性。通过示例代码展示了for-each语句的使用方法,并提供了改写为for语句版本的方法。 ... [详细]
  • Python SQLAlchemy库的使用方法详解
    本文详细介绍了Python中使用SQLAlchemy库的方法。首先对SQLAlchemy进行了简介,包括其定义、适用的数据库类型等。然后讨论了SQLAlchemy提供的两种主要使用模式,即SQL表达式语言和ORM。针对不同的需求,给出了选择哪种模式的建议。最后,介绍了连接数据库的方法,包括创建SQLAlchemy引擎和执行SQL语句的接口。 ... [详细]
  • position属性absolute与relative的区别和用法详解
    本文详细解读了CSS中的position属性absolute和relative的区别和用法。通过解释绝对定位和相对定位的含义,以及配合TOP、RIGHT、BOTTOM、LEFT进行定位的方式,说明了它们的特性和能够实现的效果。同时指出了在网页居中时使用Absolute可能会出错的原因,即以浏览器左上角为原始点进行定位,不会随着分辨率的变化而变化位置。最后总结了一些使用这两个属性的技巧。 ... [详细]
  • 开发笔记:Docker 上安装启动 MySQL
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了Docker上安装启动MySQL相关的知识,希望对你有一定的参考价值。 ... [详细]
  • Oracle优化新常态的五大禁止及其性能隐患
    本文介绍了Oracle优化新常态中的五大禁止措施,包括禁止外键、禁止视图、禁止触发器、禁止存储过程和禁止JOB,并分析了这些禁止措施可能带来的性能隐患。文章还讨论了这些禁止措施在C/S架构和B/S架构中的不同应用情况,并提出了解决方案。 ... [详细]
  • Spring常用注解(绝对经典),全靠这份Java知识点PDF大全
    本文介绍了Spring常用注解和注入bean的注解,包括@Bean、@Autowired、@Inject等,同时提供了一个Java知识点PDF大全的资源链接。其中详细介绍了ColorFactoryBean的使用,以及@Autowired和@Inject的区别和用法。此外,还提到了@Required属性的配置和使用。 ... [详细]
  • 本文介绍了Java的公式汇总及相关知识,包括定义变量的语法格式、类型转换公式、三元表达式、定义新的实例的格式、引用类型的方法以及数组静态初始化等内容。希望对读者有一定的参考价值。 ... [详细]
  • 本文讨论了微软的STL容器类是否线程安全。根据MSDN的回答,STL容器类包括vector、deque、list、queue、stack、priority_queue、valarray、map、hash_map、multimap、hash_multimap、set、hash_set、multiset、hash_multiset、basic_string和bitset。对于单个对象来说,多个线程同时读取是安全的。但如果一个线程正在写入一个对象,那么所有的读写操作都需要进行同步。 ... [详细]
  • 本文介绍了一种图片处理应用,通过固定容器来实现缩略图的功能。该方法可以实现等比例缩略、扩容填充和裁剪等操作。详细的实现步骤和代码示例在正文中给出。 ... [详细]
author-avatar
mobiledu2502908043
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有