from math import log

# Shannon entropy of a data set; the class label is the last element of each sample.
def calcShannonEnt(dataSet):
    num = len(dataSet)
    labelCount = {}
    for featureVec in dataSet:
        label = featureVec[-1]
        if label not in labelCount:
            labelCount[label] = 1
        else:
            labelCount[label] += 1
    shannonEnt = 0.0
    for key in labelCount:
        prob = float(labelCount[key]) / num
        shannonEnt -= prob * log(prob, 2)  # H = -sum(p * log2(p))
    return shannonEnt
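
# A minimal sanity check of calcShannonEnt on made-up samples (values are illustrative
# only, not taken from feature.dat): a 3/2 class split gives roughly 0.971 bits, and a
# single-class set gives 0.0.
_toySet = [[1, 0, 'yes'], [1, 1, 'yes'], [0, 1, 'yes'], [0, 0, 'no'], [1, 0, 'no']]
assert abs(calcShannonEnt(_toySet) - 0.971) < 1e-3
assert calcShannonEnt([[0, 'no'], [1, 'no']]) == 0.0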

# Return the samples whose feature at index `feature` equals `value`,
# with that feature column removed from each returned sample.
def splitDataSet(dataSet, feature, value):
    reDataSet = []
    for featureVector in dataSet:
        if featureVector[feature] == value:
            reduceFeature = featureVector[:feature]
            reduceFeature.extend(featureVector[feature + 1:])
            reDataSet.append(reduceFeature)
    return reDataSet
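
# A quick illustration of splitDataSet on made-up rows: rows whose feature 0 equals 1
# are kept, and that column is dropped from each returned row.
assert splitDataSet([[1, 0, 'yes'], [1, 1, 'yes'], [0, 1, 'no']], 0, 1) == [[0, 'yes'], [1, 'yes']]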

# Choose the feature with the largest information gain.
def chooseBestFeatureToSplit(dataSet):
    numOfFeature = len(dataSet[0]) - 1
    baseShannon = calcShannonEnt(dataSet)  # entropy of the whole data set
    bestShannon = 0.0  # best information gain seen so far
    bestFeature = -1
    for i in range(numOfFeature):
        featureList = [featureVector[i] for featureVector in dataSet]
        featureSet = set(featureList)
        newShannon = 0.0
        for value in featureSet:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = float(len(subDataSet)) / float(len(dataSet))
            newShannon += prob * calcShannonEnt(subDataSet)  # weighted entropy after splitting on feature i
        shannon = baseShannon - newShannon  # information gain of feature i
        if shannon > bestShannon:
            bestShannon = shannon
            bestFeature = i
    return bestFeature
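
# Illustration of chooseBestFeatureToSplit on made-up rows: feature 0 separates the two
# classes perfectly here, so its information gain is the largest and index 0 is returned.
assert chooseBestFeatureToSplit([[1, 0, 'yes'], [1, 1, 'yes'], [0, 1, 'no'], [0, 0, 'no']]) == 0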

# Majority voting determines the class of a leaf node.
import operator

def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 1
        else:
            classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
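
# Tiny check of majorityCnt (made-up votes): three 'no' against two 'yes' gives 'no'.
assert majorityCnt(['no', 'yes', 'no', 'yes', 'no']) == 'no'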

# Build the decision tree recursively.
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    # First stop condition: all class labels are identical, so return that label directly.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Second stop condition: every feature has been used and the data set is still not pure;
    # fall back to majority voting for the leaf's class.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    # Build the current node.
    bestFeature = chooseBestFeatureToSplit(dataSet)  # index of the feature that best splits the data set
    bestFeatureLabel = labels[bestFeature]  # name of that feature
    decisionTree = {bestFeatureLabel: {}}  # use this feature as the current node
    del labels[bestFeature]  # remove the used feature from the label list
    featureValues = [example[bestFeature] for example in dataSet]  # every sample's value of this feature
    uniqueVals = set(featureValues)  # distinct values of this feature
    for value in uniqueVals:  # one branch per distinct value
        subLabels = labels[:]
        decisionTree[bestFeatureLabel][value] = createTree(splitDataSet(dataSet, bestFeature, value), subLabels)
    return decisionTree
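
# Shape of the structure createTree returns: a nested dict keyed first by a feature name,
# then by that feature's values, with class labels at the leaves. The feature names 'f0'
# and 'f1' below are made up for illustration.
_demoTree = createTree([[1, 0, 'yes'], [1, 1, 'yes'], [0, 1, 'no'], [0, 0, 'no']], ['f0', 'f1'])
assert _demoTree == {'f0': {0: 'no', 1: 'yes'}}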

# Classify a single sample by walking the tree (assumes the sample's value for each
# tested feature matches one of the branches).
def classify(inputTree, featureLabels, testVector):
    firstNode, = inputTree.keys()  # feature name tested at this node
    secondDict = inputTree[firstNode]
    featureIndex = featureLabels.index(firstNode)
    for key in secondDict.keys():
        if testVector[featureIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featureLabels, testVector)
            else:
                classLabel = secondDict[key]
    return classLabel
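
# Quick illustration of classify against a hand-written tree with made-up feature names;
# note that the code further down uses predict rather than classify.
assert classify({'f0': {0: 'no', 1: 'yes'}}, ['f0', 'f1'], [1, 0]) == 'yes'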

# Map the feature tested at the root of the (sub)tree to its index in the label list.
def mapFeatureToLabelIndex(tree, labels):
    for key in tree.keys():
        for i in range(len(labels)):
            if key == labels[i]:
                return key, i

# Load the data: the first 100 lines of feature.dat form the training set, the next 100 the test set.
dataSet = []
for line in open("feature.dat").readlines()[:100]:
    dataSet.append([int(x) for x in line.strip().split(',')])
labels = [line.strip() for line in open("name.dat").readlines()]

testData = []
for line in open("feature.dat").readlines()[100:200]:
    testData.append([int(x) for x in line.strip().split(',')])
featureLabels = [line.strip() for line in open("name.dat").readlines()]  # same as labels; not used below
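
# A quick consistency check on the loaded files. This assumes feature.dat holds one
# comma-separated sample per line with the class label in the last column, and that
# name.dat lists exactly one name per feature column (no entry for the class column);
# if the counts disagree, the feature indices used by createTree/predict will be off.
if len(labels) != len(dataSet[0]) - 1:
    print("warning: %d feature names for %d feature columns" % (len(labels), len(dataSet[0]) - 1))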

def predict(testData, decisionTree, labels):
    # Get the feature tested at the current node and its column index.
    feature_label, feature_index = mapFeatureToLabelIndex(decisionTree, labels)
    tree = decisionTree[feature_label][testData[feature_index]]
    # Check whether this node is a leaf or an internal node.
    if not isinstance(tree, dict):  # a leaf: return the class label directly
        return tree
    else:  # an internal node: keep recursing
        return predict(testData, tree, labels)
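
# One fragile spot in predict: the lookup decisionTree[feature_label][testData[feature_index]]
# raises a KeyError whenever a test sample carries a feature value that never appeared in the
# corresponding training branch. The variant below is only a sketch of a guard; the name
# predictSafe and the defaultLabel fallback are assumptions, not part of the original code.
def predictSafe(testData, decisionTree, labels, defaultLabel=None):
    feature_label, feature_index = mapFeatureToLabelIndex(decisionTree, labels)
    branches = decisionTree[feature_label]
    if testData[feature_index] not in branches:  # unseen feature value: fall back to a default
        return defaultLabel
    tree = branches[testData[feature_index]]
    if not isinstance(tree, dict):
        return tree
    return predictSafe(testData, tree, labels, defaultLabel)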

import copy
predict_labels = copy.copy(labels)  # keep a full copy of the feature names: createTree deletes entries from labels
myTree = createTree(dataSet, labels)

# Accuracy of the decision tree's predictions, in percent.
def calPrecision(dataSet, predictSet):
    length = len(dataSet)
    count = 0
    for i in range(length):
        if dataSet[i][-1] == predictSet[i]:  # compare the true label (last column) with the prediction
            count += 1
    return count / length * 100
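
# calPrecision compares each prediction with the true label stored in the last column of
# the matching sample; with made-up values, one correct out of two gives 50.0:
assert calPrecision([[1, 'yes'], [0, 'no']], ['yes', 'yes']) == 50.0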

# Predict on the test set.
predict_result = []
for data in testData:
    result = predict(data[0:-1], myTree, predict_labels)  # strip the class label before predicting
    predict_result.append(result)
# Accuracy on the test set.
print("decision Tree predict precision: %.2f" % calPrecision(testData, predict_result), "%")

This code throws an error when I run it.
How should I go about fixing it?