Lines 186 to 228 in 6d29eec
```python
ID3Tree = {bestFeatLabel: {}}
del (labels[bestFeat])
# Get a list of all attribute values at this node
featValues = [example[bestFeat] for example in dataset]
uniqueVals = set(featValues)
if pre_pruning:
    ans = []
    for index in range(len(test_dataset)):
        ans.append(test_dataset[index][-1])
    result_counter = Counter()
    for vec in dataset:
        result_counter[vec[-1]] += 1
    leaf_output = result_counter.most_common(1)[0][0]
    root_acc = cal_acc(test_output=[leaf_output] * len(test_dataset), label=ans)
    outputs = []
    ans = []
    for value in uniqueVals:
        cut_testset = splitdataset(test_dataset, bestFeat, value)
        cut_dataset = splitdataset(dataset, bestFeat, value)
        for vec in cut_testset:
            ans.append(vec[-1])
        result_counter = Counter()
        for vec in cut_dataset:
            result_counter[vec[-1]] += 1
        leaf_output = result_counter.most_common(1)[0][0]
        outputs += [leaf_output] * len(cut_testset)
    cut_acc = cal_acc(test_output=outputs, label=ans)
    if cut_acc <= root_acc:
        return leaf_output
for value in uniqueVals:
    subLabels = labels[:]
    ID3Tree[bestFeatLabel][value] = ID3_createTree(
        splitdataset(dataset, bestFeat, value),
        subLabels,
        splitdataset(test_dataset, bestFeat, value))
if post_pruning:
    tree_output = classifytest(ID3Tree,
                               featLabels=['年龄段', '有工作', '有自己的房子', '信贷情况'],
                               testDataSet=test_dataset)
```
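Label-passing problem:
At line 227, the labels passed to `classifytest` should not be the full list of feature names (`['年龄段', '有工作', '有自己的房子', '信贷情况']`, i.e. age group, has a job, owns a house, credit status). The `test_dataset` passed in at this point is a test subset from which the columns of the features already split on have been removed, so the label list must have those features removed as well; otherwise the feature names no longer line up with the columns of the test vectors.
My fix is to insert `featLabels = labels[:]` right after line 186 (`ID3Tree = {bestFeatLabel: {}}`), before `del (labels[bestFeat])` runs, and then pass `featLabels` as the argument at line 227. `labels[:]` is a shallow copy rather than a deep copy, which is enough here because the list only holds strings; what matters is that the copy is taken before the current feature is deleted, so it names exactly the feature columns `test_dataset` still has at this node.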
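To illustrate why the label list has to match the columns of the test vectors, here is a minimal sketch. `classify_one`, the toy subtree, and the test vector are hypothetical; the sketch assumes that `classifytest` finds each feature's column by looking its name up in `featLabels` (e.g. via `featLabels.index(...)`), which is the usual approach for dict-based ID3 trees but is not confirmed by the snippet.

```python
# Hypothetical sketch: classify_one and the toy data are illustrative, not the
# repository's code. Assumes a feature's column is found via featLabels.index().


def classify_one(tree, feat_labels, vec):
    """Walk a nested-dict decision tree and return the predicted class."""
    while isinstance(tree, dict):
        feat_name = next(iter(tree))        # feature tested at this node
        col = feat_labels.index(feat_name)  # assumed column of that feature in vec
        tree = tree[feat_name][vec[col]]    # follow the branch for its value
    return tree


# A subtree built after splitting on '年龄段' (age group): that column has
# already been removed from the test vectors that reach this subtree.
subtree = {'有自己的房子': {0: 'no', 1: 'yes'}}
# Remaining columns: 有工作, 有自己的房子, 信贷情况, class label
test_vec = [1, 0, 1, 'no']

full_labels = ['年龄段', '有工作', '有自己的房子', '信贷情况']
reduced_labels = ['有工作', '有自己的房子', '信贷情况']

# Full list: '有自己的房子' maps to column 2, which now holds 信贷情况.
print(classify_one(subtree, full_labels, test_vec))     # yes (wrong)
# Reduced list: '有自己的房子' maps to column 1, the correct column.
print(classify_one(subtree, reduced_labels, test_vec))  # no (matches the label)
```

With the full list the walk silently reads the wrong column and mispredicts without raising any error; with a list that names only the columns still present at this node, the prediction matches the true label. Copying `labels` before `del (labels[bestFeat])` yields exactly such a per-node list.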
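Empty test subset problem:
Post-pruning decides whether to prune by comparing accuracy on the test set before and after pruning, so when a test subset is empty (no test samples reach this branch), computing that accuracy during post-pruning raises a division-by-zero error.
My fix is to change `if post_pruning:` at line 225 to `if post_pruning and len(test_dataset) != 0:`.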
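A minimal sketch of where the division by zero comes from and how an emptiness guard avoids it. Only the keyword arguments `test_output` and `label` of `cal_acc` come from the snippet; its body here (correct predictions divided by `len(label)`) is an assumption. `should_post_prune` is a hypothetical helper, and its rule (prune when a single majority-class leaf is at least as accurate as the subtree) is assumed by analogy with the pre-pruning check `cut_acc <= root_acc`.

```python
# Hypothetical sketch: cal_acc's body and should_post_prune are assumptions;
# only the keyword names test_output/label come from the repository snippet.


def cal_acc(test_output, label):
    """Fraction of predictions in test_output that equal the true labels."""
    correct = sum(1 for out, ans in zip(test_output, label) if out == ans)
    return correct / len(label)  # raises ZeroDivisionError if label is empty


def should_post_prune(tree_output, leaf_output, label):
    """Return True if a single leaf is at least as accurate as the subtree.

    The emptiness check mirrors the proposed guard
    `if post_pruning and len(test_dataset) != 0:`: with no test samples at
    this node there is nothing to compare, so the subtree is kept.
    """
    if len(label) == 0:
        return False
    tree_acc = cal_acc(test_output=tree_output, label=label)
    leaf_acc = cal_acc(test_output=[leaf_output] * len(label), label=label)
    return leaf_acc >= tree_acc


# Without the guard, an empty test subset crashes the accuracy computation.
try:
    cal_acc(test_output=[], label=[])
except ZeroDivisionError:
    print('unguarded: ZeroDivisionError')

print(should_post_prune([], 'yes', []))                          # False
print(should_post_prune(['yes', 'no'], 'yes', ['yes', 'yes']))  # True
```

Skipping the comparison for an empty subset keeps the subtree as built, which seems the safest default: with no test samples at that node, the test set gives no evidence either way.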
Comments and discussion are welcome 🙂