# -*- coding: utf-8 -*-
"""
Created on Fri Jul 10 22:04:33 2015

@author: wepon
"""

import numpy as np


class DecisionTree:
    """How to use this decision tree:

    - Create an instance: clf = DecisionTree(). The optional parameter
      `mode` selects ID3 or C4.5; the default is C4.5.

    - Train by calling the fit method: clf.fit(X, y). Both X and y must be
      of type np.ndarray.

    - Predict by calling the predict method: clf.predict(X). X must be of
      type np.ndarray.

    - Visualize the tree by calling the show method.

    A small usage demo appears at the bottom of this file.
    """
    def __init__(self, mode='C4.5'):
        self._tree = None

        if mode == 'C4.5' or mode == 'ID3':
            self._mode = mode
        else:
            raise Exception('mode should be C4.5 or ID3')

    def _calcEntropy(self, y):
        """
        Compute the entropy of the label vector y:
        H(y) = -sum_k p_k * log2(p_k)
        """
        num = y.shape[0]
        # Count the occurrences of each distinct label in y
        labelCounts = {}
        for label in y:
            if label not in labelCounts:
                labelCounts[label] = 0
            labelCounts[label] += 1
        # Accumulate the entropy
        entropy = 0.0
        for key in labelCounts:
            prob = float(labelCounts[key]) / num
            entropy -= prob * np.log2(prob)
        return entropy

    def _splitDataSet(self, X, y, index, value):
        """
        Return the subset of the data whose feature at column `index`
        equals `value`, with that column removed from X.
        """
        ret = []
        featVec = X[:, index]
        X = X[:, [i for i in range(X.shape[1]) if i != index]]
        for i in range(len(featVec)):
            if featVec[i] == value:
                ret.append(i)
        return X[ret, :], y[ret]

    def _chooseBestFeatureToSplit_ID3(self, X, y):
        """ID3
        Select the best splitting feature for the input dataset.
        Parameters X, y: feature matrix and label vector.
        Main variables:
            numFeatures: number of features
            oldEntropy: entropy of the original dataset
            newEntropy: entropy after splitting on a given feature
            infoGain: information gain
            bestInfoGain: largest information gain seen so far
            bestFeatureIndex: index of the feature with the largest gain
        """
        numFeatures = X.shape[1]
        oldEntropy = self._calcEntropy(y)
        bestInfoGain = 0.0
        bestFeatureIndex = -1
        # Compute infoGain for every feature and track the maximum
        for i in range(numFeatures):
            featList = X[:, i]
            uniqueVals = set(featList)
            newEntropy = 0.0
            # For each value of feature i, take the corresponding subset and
            # compute its entropy; accumulating these gives newEntropy, the
            # entropy of the dataset after splitting on feature i
            for value in uniqueVals:
                sub_X, sub_y = self._splitDataSet(X, y, i, value)
                prob = len(sub_y) / float(len(y))
                newEntropy += prob * self._calcEntropy(sub_y)
            # The information gain decides the best splitting feature
            infoGain = oldEntropy - newEntropy
            if infoGain > bestInfoGain:
                bestInfoGain = infoGain
                bestFeatureIndex = i
        return bestFeatureIndex

    def _chooseBestFeatureToSplit_C45(self, X, y):
        """C4.5
        ID3 uses the information gain; C4.5 uses the gain ratio
        gainRatio = infoGain / splitInformation, which only requires a
        small change to the ID3 version above.
        """
        numFeatures = X.shape[1]
        oldEntropy = self._calcEntropy(y)
        bestGainRatio = 0.0
        bestFeatureIndex = -1
        # Compute gainRatio = infoGain / splitInformation for every feature
        for i in range(numFeatures):
            featList = X[:, i]
            uniqueVals = set(featList)
            newEntropy = 0.0
            splitInformation = 0.0
            # For each value of feature i, take the corresponding subset and
            # compute its entropy; accumulating these gives newEntropy, the
            # entropy after splitting on feature i, along with feature i's
            # split information
            for value in uniqueVals:
                sub_X, sub_y = self._splitDataSet(X, y, i, value)
                prob = len(sub_y) / float(len(y))
                newEntropy += prob * self._calcEntropy(sub_y)
                splitInformation -= prob * np.log2(prob)
            # Select the best splitting feature by gain ratio.
            # If splitInformation is 0, every sample takes the same value for
            # this feature, so it obviously cannot be used to split the data.
            if splitInformation == 0.0:
                pass
            else:
                infoGain = oldEntropy - newEntropy
                gainRatio = infoGain / splitInformation
                if gainRatio > bestGainRatio:
                    bestGainRatio = gainRatio
                    bestFeatureIndex = i
        return bestFeatureIndex

    def _majorityCnt(self, labelList):
        """
        Return the label that occurs most often in labelList.
        """
        labelCount = {}
        for vote in labelList:
            if vote not in labelCount:
                labelCount[vote] = 0
            labelCount[vote] += 1
        # items() replaces the Python 2-only iteritems()
        sortedClassCount = sorted(labelCount.items(), key=lambda x: x[1], reverse=True)
        return sortedClassCount[0][0]

    def _createTree(self, X, y, featureIndex):
        """Build the decision tree recursively.
        featureIndex is a tuple recording, for each column of X, the index
        that feature had in the original data (as strings such as 'x0').
        """
        labelList = list(y)
        # If all labels are identical, stop splitting and return that label
        if labelList.count(labelList[0]) == len(labelList):
            return labelList[0]
        # If no features are left to split on, return the majority label
        if len(featureIndex) == 0:
            return self._majorityCnt(labelList)

        # Otherwise determine the best splitting feature
        if self._mode == 'C4.5':
            bestFeatIndex = self._chooseBestFeatureToSplit_C45(X, y)
        elif self._mode == 'ID3':
            bestFeatIndex = self._chooseBestFeatureToSplit_ID3(X, y)
        # If no feature yields any gain, the choosers return -1; stop
        # splitting and return the majority label instead of indexing with -1
        if bestFeatIndex == -1:
            return self._majorityCnt(labelList)

        bestFeatStr = featureIndex[bestFeatIndex]
        featureIndex = list(featureIndex)
        featureIndex.remove(bestFeatStr)
        featureIndex = tuple(featureIndex)
        # Store the tree as a dict: the best splitting feature is the key,
        # and each of its values maps to a subtree (again stored as a dict)
        myTree = {bestFeatStr: {}}
        featValues = X[:, bestFeatIndex]
        uniqueVals = set(featValues)
        for value in uniqueVals:
            # Recursively build a subtree for each feature value
            sub_X, sub_y = self._splitDataSet(X, y, bestFeatIndex, value)
            myTree[bestFeatStr][value] = self._createTree(sub_X, sub_y, featureIndex)
        return myTree

    def fit(self, X, y):
        # Type check: coerce X and y to numpy arrays if necessary
        if isinstance(X, np.ndarray) and isinstance(y, np.ndarray):
            pass
        else:
            try:
                X = np.array(X)
                y = np.array(y)
            except Exception:
                raise TypeError("numpy.ndarray required for X,y")

        featureIndex = tuple('x' + str(i) for i in range(X.shape[1]))
        self._tree = self._createTree(X, y, featureIndex)
        return self  # allow chaining: clf.fit().predict()

    def predict(self, X):
        if self._tree is None:
            raise NotFittedError("Estimator not fitted, call `fit` first")

        # Type check: coerce X to a numpy array if necessary
        if isinstance(X, np.ndarray):
            pass
        else:
            try:
                X = np.array(X)
            except Exception:
                raise TypeError("numpy.ndarray required for X")

        def _classify(tree, sample):
            """
            Classify one input sample with the trained decision tree.
            Building the tree is a recursive process, and so is classifying
            with it; _classify() handles a single sample at a time.
            To do: how could prediction over many samples be parallelized?
            """
            featIndex = list(tree.keys())[0]  # list() is needed in Python 3
            secondDict = tree[featIndex]
            key = sample[int(featIndex[1:])]
            valueOfKey = secondDict[key]
            if isinstance(valueOfKey, dict):
                label = _classify(valueOfKey, sample)
            else:
                label = valueOfKey
            return label

        if len(X.shape) == 1:
            return _classify(self._tree, X)
        else:
            results = []
            for i in range(X.shape[0]):
                results.append(_classify(self._tree, X[i]))
            return np.array(results)

    def show(self):
        if self._tree is None:
            raise NotFittedError("Estimator not fitted, call `fit` first")

        # Plot the tree using matplotlib; treePlotter is assumed to be an
        # importable companion module alongside this file
        import treePlotter
        treePlotter.createPlot(self._tree)


class NotFittedError(Exception):
    """
    Exception class to raise if the estimator is used before fitting.
    """
    pass
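

# A minimal usage sketch: a tiny hand-made binary dataset (values chosen
# purely for illustration), fit with the default C4.5 mode, then predicted
# on the training samples.
if __name__ == '__main__':
    X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
    y = np.array([0, 1, 1, 1])  # labels behave like a logical OR of the features
    clf = DecisionTree()
    clf.fit(X, y)
    print(clf.predict(X))  # expected output: [0 1 1 1]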