决策树的原理及简单应用

zoukeh

发布日期: 2021-03-17 08:16:52 浏览量: 179
评分:
star star star star star star star star star star
*转载请注明来自write-bug.com

决策树原理

今天放假,我是一个生活较为规律的人,想要选择今天的活动。今天是2021年x月x日,首先我起床,拉开窗帘打开天气app考察外界环境。假如温度高于10℃就选择出门,低于10℃就宅家。如果温度高于10℃,观察风力:如果风力低于3级,选择打羽毛球;高于3级则打篮球。如果温度低于10℃,观察爸妈是否在家:如果爸妈在家,选择学习;不在家则选择玩游戏。通过上述决策条件,给出某一天的数据,就可以预测出我在做什么。比如1月20日,温度为15℃,风力5级,爸妈在家,则可以预测该日我在打篮球。

这是一个简单的决策树例子。决策树方法是应用最广的归纳推理算法之一。其起源是概念学习系统,然后发展到ID3(处理离散属性)方法而为高潮,最后又演化为能处理连续属性的C4.5。有名的决策树方法还有CART和Assistant(处理缺失属性的数据,比如互联网数据)。决策树通过把实例(样本)从根节点排列到某个叶子节点来分类实例,叶子节点即为实例所属的分类。树上的每一个节点说明了对实例的某个属性的测试,并且该节点的每一个后继分支对应于该属性的一个可能值。前面提到的日常活动决策树的示意图如下:

对于一组n维特征数据,我们如何选取具有区分特性的特征来构建决策树呢?基本的决策树学习方法(ID3思想)是自顶向下,选取当前特征中最有信息性的特征来分割树,进而将样本最好地进行划分。为了找到这一最优划分,关键是需要找到最好的逻辑判断。一般来说,树越小则树的预测能力越强,若某一个子集上样本很少,还有分支,这说明属性不是很好,只对某些很少的样本有作用。我们要尽可能构建较小的决策树。

信息性的概念有一点抽象,我们可以理解为一个特征对于整个样本能够体现的信息量的多少。比如上面的例子中的温度条件,显然比风力条件的信息性更强。我们需要用一个具体的计算公式或函数来衡量它。

其中较常用的几个度量规则包括信息增益、不纯度等。我们在此重点介绍信息增益。

首先我们知道,熵是测量一个信号或者分布的不确定度的度量。若所有样本属于同一类,熵为0;如果各属于一半,等于1,熵取得最大值。在构建决策树时我们通过某个属性,应该要尽可能快的达到树的叶子节点,我们需要选择熵较小的特征。

熵的计算公式如下:

假如正负样本个数分别为9和5,熵值为:

  1. -(9/14)log2(9/14)- (5/14)log2(5/14)=0.94

当我们在评估特征的分类能力时,我们可以计算其使样本熵降低的幅度,也就是信息增益。

Sv是样本S中特征A的值为v的子集。第二项描述的是期望熵,就是每个子集的熵的加权和,权值为属于Sv的样例占原始样本S的比例。Gain(S,A)是由于给定特征A的值而得到的关于目标函数值的信息。如果选取某个特征对样本进行划分后,信息增益最大,那么我们就选择它。信息增益公式中第一项Entropy(S)又被称为样本的信息熵,第二项被称为条件熵。

基于这个选择依据,我们就可以开始构建一棵决策树啦。

决策树的简单应用

首先我们有如下的数据:

其中包括了20为病人的一些信息,医生会给病人开A、B、C、D四种药物。我们需要构建决策树来找出病人的身体信息和医生所开的药物之间的关系,从而可以用决策树为之后的病人自动开药。

我们使用python的sklearn包中tree模块构建决策树。话不多说上代码:

  1. from sklearn import tree
  2. import sklearn.model_selection as ms
  3. from sklearn.feature_extraction import DictVectorizer
  4. import numpy as np
  5. from sklearn import metrics
  6. import matplotlib.pyplot as plt
  7. plt.style.use('ggplot')
  8. data = [
  9. {'age': 33, 'sex': 'F', 'BP': 'high', 'cholesterol': 'high', 'Na': 0.66, 'K': 0.06, 'drug': 'A'},
  10. {'age': 77, 'sex': 'F', 'BP': 'high', 'cholesterol': 'normal', 'Na': 0.19, 'K': 0.03, 'drug': 'D'},
  11. {'age': 88, 'sex': 'M', 'BP': 'normal', 'cholesterol': 'normal', 'Na': 0.80, 'K': 0.05, 'drug': 'B'},
  12. {'age': 39, 'sex': 'F', 'BP': 'low', 'cholesterol': 'normal', 'Na': 0.19, 'K': 0.02, 'drug': 'C'},
  13. {'age': 43, 'sex': 'M', 'BP': 'normal', 'cholesterol': 'high', 'Na': 0.36, 'K': 0.03, 'drug': 'D'},
  14. {'age': 82, 'sex': 'F', 'BP': 'normal', 'cholesterol': 'normal', 'Na': 0.09, 'K': 0.09, 'drug': 'C'},
  15. {'age': 40, 'sex': 'M', 'BP': 'high', 'cholesterol': 'normal', 'Na': 0.89, 'K': 0.02, 'drug': 'A'},
  16. {'age': 88, 'sex': 'M', 'BP': 'normal', 'cholesterol': 'normal', 'Na': 0.80, 'K': 0.05, 'drug': 'B'},
  17. {'age': 29, 'sex': 'F', 'BP': 'high', 'cholesterol': 'normal', 'Na': 0.35, 'K': 0.04, 'drug': 'D'},
  18. {'age': 53, 'sex': 'F', 'BP': 'normal', 'cholesterol': 'normal', 'Na': 0.54, 'K': 0.06, 'drug': 'C'},
  19. {'age': 36, 'sex': 'F', 'BP': 'high', 'cholesterol': 'high', 'Na': 0.53, 'K': 0.05, 'drug': 'A'},
  20. {'age': 63, 'sex': 'M', 'BP': 'low', 'cholesterol': 'high', 'Na': 0.86, 'K': 0.09, 'drug': 'B'},
  21. {'age': 60, 'sex': 'M', 'BP': 'low', 'cholesterol': 'normal', 'Na': 0.66, 'K': 0.04, 'drug': 'C'},
  22. {'age': 55, 'sex': 'M', 'BP': 'high', 'cholesterol': 'high', 'Na': 0.82, 'K': 0.04, 'drug': 'B'},
  23. {'age': 35, 'sex': 'F', 'BP': 'normal', 'cholesterol': 'high', 'Na': 0.27, 'K': 0.03, 'drug': 'D'},
  24. {'age': 23, 'sex': 'F', 'BP': 'high', 'cholesterol': 'high', 'Na': 0.55, 'K': 0.08, 'drug': 'A'},
  25. {'age': 49, 'sex': 'F', 'BP': 'low', 'cholesterol': 'normal', 'Na': 0.27, 'K': 0.05, 'drug': 'C'},
  26. {'age': 27, 'sex': 'M', 'BP': 'normal', 'cholesterol': 'normal', 'Na': 0.77, 'K': 0.02, 'drug': 'B'},
  27. {'age': 51, 'sex': 'F', 'BP': 'low', 'cholesterol': 'high', 'Na': 0.20, 'K': 0.02, 'drug': 'D'},
  28. {'age': 38, 'sex': 'M', 'BP': 'high', 'cholesterol': 'normal', 'Na': 0.78, 'K': 0.05, 'drug': 'A'}
  29. ]
  30. sodium = [d['Na'] for d in data]
  31. potassium = [d['K'] for d in data]
  32. target=[d['drug'] for d in data]
  33. target=[ord(t)-65 for t in target] #ord('A')=65
  34. vec = DictVectorizer(sparse=False)
  35. data_pre = vec.fit_transform(data) #数据预处理,将所有特征转换为数值特征
  36. data_pre=np.array(data_pre,dtype=np.float32)
  37. target=np.array(target,dtype=np.float32)
  38. X_train, X_test, y_train, y_test = ms.train_test_split(data_pre, target, test_size=5, random_state=42)
  39. #sklearn
  40. dtc=tree.DecisionTreeClassifier()
  41. dtc.fit(X_train,y_train)
  42. print(dtc.score(X_train,y_train))
  43. print(dtc.score(X_test, y_test))
  44. with open("tree.dot", 'w') as f:
  45. f = tree.export_graphviz(dtc, out_file=f,
  46. feature_names=vec.get_feature_names(),
  47. class_names=['A', 'B', 'C', 'D'])

我们可以得到决策树在训练集和测试集上的正确率均为100%。在最后我们使用graphviz库,将决策树的具体结构进行了可视化。

首先我们在命令行中(本文使用conda命令行)输入:

  1. conda install graphviz

此时我们自动安装graphviz库,建议在翻墙后下载。

回到python中执行最后一段语句,将决策树导出为tree.dot文件。再到命令行中使用graphviz把它转成png图片:

  1. dot -Tpng tree.dot -o tree.png

转换后如下所示:

一个简单的决策树案例就介绍到这里。

引用

[1]M. Beyeler(2017) Machine Learning for OpenCV:Intelligent image processing with Python Packt Publishing Ltd, ISBN 978-178398028-4.

[2]《模式识别》西安电子科技大学出版社.

上传的附件

发送私信

2
文章数
0
评论数
最近文章
eject