
Commit ef4337a

add Python implementation of ID3, C4.5

1 parent c85c35a

File tree

4 files changed: +384 −0

DecisionTree/id3_c45.py (+247)
# -*- coding: utf-8 -*-
"""
Created on Fri Jul 10 22:04:33 2015

@author: wepon
"""
import numpy as np


class DecisionTree:
    """How to use this decision tree:

    - Create an instance: clf = DecisionTree(). The optional `mode`
      parameter selects the algorithm, 'ID3' or 'C4.5' (default: 'C4.5').
    - Train by calling the fit method: clf.fit(X, y). Both X and y are
      of type np.ndarray.
    - Predict by calling the predict method: clf.predict(X). X is of
      type np.ndarray.
    - Visualize the fitted tree by calling the show method.
    """
    def __init__(self, mode='C4.5'):
        self._tree = None

        if mode == 'C4.5' or mode == 'ID3':
            self._mode = mode
        else:
            raise Exception('mode should be C4.5 or ID3')

    def _calcEntropy(self, y):
        """Compute the entropy of the label vector y."""
        num = y.shape[0]
        # Count how often each label occurs in y, stored in the dict labelCounts.
        labelCounts = {}
        for label in y:
            if label not in labelCounts:
                labelCounts[label] = 0
            labelCounts[label] += 1
        # Compute the entropy.
        entropy = 0.0
        for key in labelCounts:
            prob = float(labelCounts[key]) / num
            entropy -= prob * np.log2(prob)
        return entropy
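    # Worked example (illustrative values, not part of the original commit):
    #   y = np.array(['yes', 'yes', 'no'])  ->  p(yes) = 2/3, p(no) = 1/3
    #   H(y) = -(2/3)*log2(2/3) - (1/3)*log2(1/3) ≈ 0.918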
    def _splitDataSet(self, X, y, index, value):
        """Return the sub-dataset whose feature at column `index` equals
        `value`; the used feature column is removed from X."""
        ret = []
        featVec = X[:, index]
        X = X[:, [i for i in range(X.shape[1]) if i != index]]
        for i in range(len(featVec)):
            if featVec[i] == value:
                ret.append(i)
        return X[ret, :], y[ret]

    def _chooseBestFeatureToSplit_ID3(self, X, y):
        """ID3: choose the best split feature for the input dataset.

        Main variables:
            numFeatures: number of features
            oldEntropy: entropy of the original dataset
            newEntropy: entropy after splitting on a given feature
            infoGain: information gain
            bestInfoGain: largest information gain seen so far
            bestFeatureIndex: index of the split feature with the largest gain
        """
        numFeatures = X.shape[1]
        oldEntropy = self._calcEntropy(y)
        bestInfoGain = 0.0
        bestFeatureIndex = -1
        # Compute infoGain for every feature; bestInfoGain tracks the largest.
        for i in range(numFeatures):
            featList = X[:, i]
            uniqueVals = set(featList)
            newEntropy = 0.0
            # For each value of feature i, build the sub-dataset and compute
            # its entropy; the weighted sum of these entropies is newEntropy,
            # the entropy after splitting the dataset on feature i.
            for value in uniqueVals:
                sub_X, sub_y = self._splitDataSet(X, y, i, value)
                prob = len(sub_y) / float(len(y))
                newEntropy += prob * self._calcEntropy(sub_y)
            # The information gain picks the best split feature.
            infoGain = oldEntropy - newEntropy
            if infoGain > bestInfoGain:
                bestInfoGain = infoGain
                bestFeatureIndex = i
        return bestFeatureIndex

    def _chooseBestFeatureToSplit_C45(self, X, y):
        """C4.5: ID3 selects on information gain, C4.5 on the gain ratio;
        this is the ID3 function above with small changes."""
        numFeatures = X.shape[1]
        oldEntropy = self._calcEntropy(y)
        bestGainRatio = 0.0
        bestFeatureIndex = -1
        # Compute gainRatio = infoGain / splitInformation for every feature.
        for i in range(numFeatures):
            featList = X[:, i]
            uniqueVals = set(featList)
            newEntropy = 0.0
            splitInformation = 0.0
            # For each value of feature i, build the sub-dataset and compute
            # its entropy; the weighted sum of these entropies is newEntropy,
            # the entropy after splitting the dataset on feature i.
            for value in uniqueVals:
                sub_X, sub_y = self._splitDataSet(X, y, i, value)
                prob = len(sub_y) / float(len(y))
                newEntropy += prob * self._calcEntropy(sub_y)
                splitInformation -= prob * np.log2(prob)
            # Pick the best split feature by gain ratio. If splitInformation
            # is 0, every sample takes the same value for this feature, so it
            # clearly cannot serve as a split feature.
            if splitInformation == 0.0:
                pass
            else:
                infoGain = oldEntropy - newEntropy
                gainRatio = infoGain / splitInformation
                if gainRatio > bestGainRatio:
                    bestGainRatio = gainRatio
                    bestFeatureIndex = i
        return bestFeatureIndex
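    # For reference (not part of the original commit), with D the dataset and
    # D_v the subset where feature A takes value v, the two criteria are:
    #   infoGain(D, A)  = H(D) - sum_v |D_v|/|D| * H(D_v)          (ID3)
    #   splitInfo(D, A) = -sum_v |D_v|/|D| * log2(|D_v|/|D|)
    #   gainRatio(D, A) = infoGain(D, A) / splitInfo(D, A)         (C4.5)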
    def _majorityCnt(self, labelList):
        """Return the most frequent label in labelList."""
        labelCount = {}
        for vote in labelList:
            if vote not in labelCount:
                labelCount[vote] = 0
            labelCount[vote] += 1
        sortedClassCount = sorted(labelCount.items(), key=lambda x: x[1], reverse=True)
        return sortedClassCount[0][0]

    def _createTree(self, X, y, featureIndex):
        """Build the decision tree.

        featureIndex is a tuple that records, for each column of X, the
        index of that feature in the original data.
        """
        labelList = list(y)
        # If all labels are identical, stop splitting and return that label.
        if labelList.count(labelList[0]) == len(labelList):
            return labelList[0]
        # If no features are left to split on, stop and return the most
        # frequent label.
        if len(featureIndex) == 0:
            return self._majorityCnt(labelList)

        # Otherwise determine the best split feature.
        if self._mode == 'C4.5':
            bestFeatIndex = self._chooseBestFeatureToSplit_C45(X, y)
        elif self._mode == 'ID3':
            bestFeatIndex = self._chooseBestFeatureToSplit_ID3(X, y)

        bestFeatStr = featureIndex[bestFeatIndex]
        featureIndex = list(featureIndex)
        featureIndex.remove(bestFeatStr)
        featureIndex = tuple(featureIndex)
        # Store the tree in a dict: the best split feature is the key, and
        # its value is again a tree (again stored as a dict).
        myTree = {bestFeatStr: {}}
        featValues = X[:, bestFeatIndex]
        uniqueVals = set(featValues)
        for value in uniqueVals:
            # Recursively create the subtree for each value.
            sub_X, sub_y = self._splitDataSet(X, y, bestFeatIndex, value)
            myTree[bestFeatStr][value] = self._createTree(sub_X, sub_y, featureIndex)
        return myTree
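    # Illustrative shape of the returned tree (made-up values, not from the
    # original commit): a split feature maps each of its values either to a
    # label or to another subtree, e.g.
    #   {'x0': {0: 'no', 1: {'x2': {0: 'no', 1: 'yes'}}}}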
    def fit(self, X, y):
        # Type check.
        if isinstance(X, np.ndarray) and isinstance(y, np.ndarray):
            pass
        else:
            try:
                X = np.array(X)
                y = np.array(y)
            except Exception:
                raise TypeError("numpy.ndarray required for X,y")

        featureIndex = tuple(['x' + str(i) for i in range(X.shape[1])])
        self._tree = self._createTree(X, y, featureIndex)
        return self  # allow chaining: clf.fit().predict()

    def predict(self, X):
        if self._tree is None:
            raise NotFittedError("Estimator not fitted, call `fit` first")

        # Type check.
        if isinstance(X, np.ndarray):
            pass
        else:
            try:
                X = np.array(X)
            except Exception:
                raise TypeError("numpy.ndarray required for X")

        def _classify(tree, sample):
            """Classify input data with the trained decision tree.

            Building the tree is a recursive process, and so is classifying
            with it. _classify() classifies a single sample at a time.
            To Do: how can prediction over multiple samples be parallelized?
            """
            featIndex = next(iter(tree))      # e.g. 'x2'
            secondDict = tree[featIndex]
            key = sample[int(featIndex[1:])]  # the sample's value of that feature
            valueOfKey = secondDict[key]
            if isinstance(valueOfKey, dict):
                label = _classify(valueOfKey, sample)
            else:
                label = valueOfKey
            return label

        if len(X.shape) == 1:
            return _classify(self._tree, X)
        else:
            results = []
            for i in range(X.shape[0]):
                results.append(_classify(self._tree, X[i]))
            return np.array(results)

    def show(self):
        if self._tree is None:
            raise NotFittedError("Estimator not fitted, call `fit` first")
        # Plot the tree using matplotlib.
        import treePlotter
        treePlotter.createPlot(self._tree)


class NotFittedError(Exception):
    """Exception class to raise if the estimator is used before fitting."""
    pass
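The To Do in predict() asks how prediction over many samples could be parallelized. Since the fitted tree is a read-only nested dict, one possible sketch (not part of this commit; `predict_parallel` and `n_workers` are names introduced here) splits the rows into chunks and runs the ordinary predict on each chunk in a separate process:

    import numpy as np
    from concurrent.futures import ProcessPoolExecutor

    def predict_parallel(clf, X, n_workers=4):
        # One chunk per worker; each worker runs the sequential predict.
        chunks = np.array_split(np.asarray(X), n_workers)
        with ProcessPoolExecutor(max_workers=n_workers) as pool:
            parts = pool.map(clf.predict, chunks)
        return np.concatenate(list(parts))

On platforms that spawn rather than fork worker processes (e.g. Windows), the call has to sit under an `if __name__ == '__main__':` guard.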

DecisionTree/readme.md (+55)
- A Python implementation of ID3 and C4.5; the C4.5 version still needs polish, and CART will be added later.

- Dependencies
  - NumPy
  - Matplotlib

- Test

        from id3_c45 import DecisionTree
        if __name__ == '__main__':
            # Toy data
            X = [[1, 2, 0, 1, 0],
                 [0, 1, 1, 0, 1],
                 [1, 0, 0, 0, 1],
                 [2, 1, 1, 0, 1],
                 [1, 1, 0, 1, 1]]
            y = ['yes', 'yes', 'no', 'no', 'no']

            clf = DecisionTree(mode='ID3')
            clf.fit(X, y)
            clf.show()
            print(clf.predict(X))  # ['yes' 'yes' 'no' 'no' 'no']

            clf_ = DecisionTree(mode='C4.5')
            clf_.fit(X, y).show()
            print(clf_.predict(X))  # ['yes' 'yes' 'no' 'no' 'no']

**ID3:**

![](http://i.imgur.com/kqA3eHT.png)

**C4.5:**

![](http://i.imgur.com/ronxb97.png)

- Known issues

  (1) If some feature of a test sample takes a value that never appeared in the training set, the corresponding branch of the trained tree cannot classify that sample and raises a KeyError (a possible fallback is sketched after this list):

        from sklearn.datasets import load_digits
        dataset = load_digits()
        X = dataset['data']
        y = dataset['target']
        clf.fit(X[0:1000], y[0:1000])
        for i in range(1000, 1500):
            try:
                print(clf.predict(X[i]) == y[i])
            except KeyError:
                print("KeyError")

  (2) Prediction over multiple samples cannot be parallelized yet (one possible approach is sketched after the id3_c45.py listing above).
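For issue (1), one possible fallback (a sketch, not part of this commit; `classify_safe` and `_majority_leaf` are hypothetical helper names) is to return the most common leaf label of the current subtree whenever a sample reaches a branch with an unseen feature value, instead of raising KeyError:

    def _majority_leaf(tree):
        """Most common label among all leaves of a (sub)tree."""
        counts = {}
        def walk(node):
            if isinstance(node, dict):
                featIndex = next(iter(node))
                for child in node[featIndex].values():
                    walk(child)
            else:
                counts[node] = counts.get(node, 0) + 1
        walk(tree)
        return max(counts, key=counts.get)

    def classify_safe(tree, sample):
        featIndex = next(iter(tree))
        secondDict = tree[featIndex]
        key = sample[int(featIndex[1:])]
        if key not in secondDict:  # unseen value: back off to the majority label
            return _majority_leaf(tree)
        value = secondDict[key]
        return classify_safe(value, sample) if isinstance(value, dict) else value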

DecisionTree/treePlotter.py

+78
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
'''
Created on Oct 14, 2010

@author: Peter Harrington

From the book "Machine Learning in Action"
'''
import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):  # dict nodes are internal nodes, everything else is a leaf
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):  # dict nodes are internal nodes, everything else is a leaf
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)  # determines the x width of this subtree
    # depth = getTreeDepth(myTree)
    firstStr = next(iter(myTree))   # the text label for this node
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):   # internal node: recurse
            plotTree(secondDict[key], cntrPt, str(key))
        else:                                   # leaf node: plot it
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # no ticks
    # createPlot.ax1 = plt.subplot(111, frameon=False)  # ticks, for demo purposes
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
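createPlot works on any tree in the nested-dict format that DecisionTree produces, so it can also be tried standalone. A minimal sketch with a made-up toy tree:

    from treePlotter import createPlot
    toyTree = {'x0': {0: 'no', 1: {'x2': {0: 'no', 1: 'yes'}}}}
    createPlot(toyTree)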

README.md (+4)

CSDN: [wepon's column](http://blog.csdn.net/u012162613)

GMM and k-means, as applications of the EM algorithm, are similar in some respects, but GMM explicitly learns probability density functions. A Python version written from this understanding, with a detailed introduction: [article link](http://blog.csdn.net/gugugujiawei/article/details/45583051)

- **DecisionTree**

  ID3 and C4.5 implemented with Python, NumPy, and Matplotlib; the C4.5 version still needs polish, and CART will be added later. A write-up is to follow.

## Contributing

You are welcome to join this project: any machine learning / deep learning demo can be pushed, preferably with a blog post introducing the code.
