def getNumLeafs(tree):
    """Return the number of leaf nodes in a nested-dict decision tree.

    A tree has the shape ``{feature: {branch_value: subtree_or_leaf, ...}}``.
    Any branch value that is itself a dict (including dict subclasses) is a
    subtree and is counted recursively; anything else is a leaf. Only the
    first top-level key is examined. An empty tree has 0 leaves.
    """
    if not tree:
        return 0
    root = next(iter(tree))
    # A node whose value is not a dict is a leaf node.
    return sum(
        getNumLeafs(child) if isinstance(child, dict) else 1
        for child in tree[root].values()
    )
def getTreeDepth(tree):
    """Return the depth of a nested-dict decision tree.

    Depth is the number of decision levels on the longest root-to-leaf path:
    a root whose branches are all leaves has depth 1. An empty tree, or a root
    with no branches, has depth 0. The plotting code uses this value to space
    the tree's levels vertically (steps of 1/depth on the y axis).
    """
    if not tree:
        return 0
    root = next(iter(tree))
    return max(
        (getTreeDepth(child) + 1 if isinstance(child, dict) else 1
         for child in tree[root].values()),
        default=0,
    )
# 用matplotlib绘制决策树
import matplotlib.pyplot as plt
# Text-box styles passed as ``bbox`` to annotate(). ``fc`` (face colour) is a
# grey level given as a string, from '0' (black) to '1' (white).
# Decision (internal) nodes: grey sawtooth box.
decisionNode = {'boxstyle': 'sawtooth', 'fc': '0.8'}
# Leaf nodes: white rounded box.
leafNode = {'boxstyle': 'round4', 'fc': '1'}
# Arrow style for the parent-child edges; '<-' puts the head at the text end,
# so the arrow points from the parent point toward the child node.
arrow_args = {'arrowstyle': '<-'}
# Draw one tree node with plt.annotate(text, xy, xycoords, xytext, textcoords,
# va, ha, bbox, arrowprops):
#   xy / xycoords       -> the point the arrow connects to (the parent node)
#   xytext / textcoords -> the centre of this node's text box
#   va / ha             -> vertical / horizontal alignment of the text
# Both positions are given in 'axes fraction' coordinates, where (0, 0) is the
# lower-left and (1, 1) the upper-right corner of the axes.
def plotNode(nodeTxt, nodeIndex, parentNodeIndex, nodeType):
    """Draw *nodeTxt* in a box styled by *nodeType*, centred at *nodeIndex*,
    with an arrow linking it to the parent position *parentNodeIndex*.
    """
    annotate_kwargs = dict(
        xy=parentNodeIndex, xycoords='axes fraction',
        xytext=nodeIndex, textcoords='axes fraction',
        va='center', ha='center', bbox=nodeType,
        arrowprops=arrow_args,
    )
    plt.annotate(nodeTxt, **annotate_kwargs)
# Label the edge between a parent node and a child node.
def plotMidText(thisNodeIndex, parentNodeIndex, text):
    """Write *text* (the branch value) at the midpoint of the edge between
    the child at *thisNodeIndex* and its parent at *parentNodeIndex*.

    Positions are axes fractions, the same coordinate system plotNode uses
    ('axes fraction'). plt.text would otherwise interpret them as data
    coordinates, which only coincide while the axes limits stay at 0..1.
    """
    xmid = (parentNodeIndex[0] - thisNodeIndex[0]) / 2.0 + thisNodeIndex[0]
    ymid = (parentNodeIndex[1] - thisNodeIndex[1]) / 2.0 + thisNodeIndex[1]
    # Force axes-fraction coordinates and centre the label on the midpoint
    # (by default the text's lower-left corner would sit there).
    plt.text(xmid, ymid, text, transform=plt.gca().transAxes,
             va='center', ha='center')
def plotTree(tree, parentNodeIndex, midTxt):
    """Recursively draw *tree* below its parent node.

    tree            -- {feature: {branch_value: subtree_or_leaf}} with a
                       single top-level key
    parentNodeIndex -- (x, y) axes-fraction position of the parent node
    midTxt          -- label for the edge from the parent (the branch value)

    Layout state is kept in module globals. xOff (x of the last drawn leaf)
    and yOff (current level) are read and updated here. treeWidth (total leaf
    count) and treeDepth (number of levels) are only read, so they must be set
    before the first call.
    """
    global xOff
    global yOff
    numOfLeafs = getNumLeafs(tree)
    nodeTxt, = tree.keys()
    # Centre this decision node horizontally over the leaves it spans.
    nodeIndex = (xOff + (1.0 + float(numOfLeafs)) / 2.0 / treeWidth, yOff)
    plotNode(nodeTxt, nodeIndex, parentNodeIndex, decisionNode)
    plotMidText(nodeIndex, parentNodeIndex, midTxt)
    secondDict = tree[nodeTxt]
    # Children are drawn one level lower.
    yOff = yOff - 1.0 / treeDepth
    for key, child in secondDict.items():
        if isinstance(child, dict):
            plotTree(child, nodeIndex, str(key))
        else:
            # Each leaf takes the next 1/treeWidth slot along the x axis.
            xOff = xOff + 1.0 / treeWidth
            plotNode(child, (xOff, yOff), nodeIndex, leafNode)
            plotMidText((xOff, yOff), nodeIndex, str(key))
    # Restore the level so this node's siblings are drawn at the right height.
    yOff = yOff + 1.0 / treeDepth
def createPlot(tree):
    """Draw *tree* as a diagram in a matplotlib window and show it.

    Computes the layout globals that plotTree reads from *tree* itself, so
    callers no longer have to set them first:
      treeWidth -- number of leaves
      treeDepth -- number of levels
    It then resets the cursor globals xOff/yOff and draws the tree with its
    root at the top centre (0.5, 1.0).

    Raises ValueError if the tree has no leaves; the leaf spacing 1/treeWidth
    would otherwise divide by zero.
    """
    # Layout state shared with plotTree through module globals.
    global treeWidth, treeDepth, xOff, yOff
    treeWidth = float(getNumLeafs(tree))
    treeDepth = float(getTreeDepth(tree))
    if not treeWidth:
        raise ValueError('cannot plot a tree without leaves')
    # Window named 'DecisionTree' with a white background; plt.figure reuses
    # an existing window with this name, so clear it first.
    fig = plt.figure('DecisionTree', facecolor='white')
    fig.clf()
    # One full-size axes (1 row, 1 column, first cell) without a frame.
    createPlot.ax1 = plt.subplot(111, frameon=False)
    # Start half a leaf slot left of 0: the first leaf (xOff += 1/treeWidth)
    # is then centred at 0.5/treeWidth, and the root lands at x = 0.5.
    xOff = -0.5 / treeWidth
    yOff = 1.0
    plotTree(tree, (0.5, 1.0), '')
    plt.xticks([])
    plt.yticks([])
    plt.show()
def classify(inputTree, featureLabels, testVector):
    """Return the class label that *inputTree* predicts for *testVector*.

    inputTree     -- {feature_name: {feature_value: subtree_or_label}}
    featureLabels -- feature names, in the same order as the values in
                     testVector; used to find which entry a node tests
    testVector    -- the sample's feature values. Each must equal a branch key
                     exactly, so e.g. the int 1 does not match the key '1'.

    Raises ValueError in two cases:
      - a node's feature name is not in featureLabels (from list.index);
      - the tree has no branch for the sample's value.
    """
    firstNode, = inputTree.keys()
    secondDict = inputTree[firstNode]
    featureIndex = featureLabels.index(firstNode)
    value = testVector[featureIndex]
    try:
        branch = secondDict[value]
    except KeyError:
        # Without this, an unmatched value left the result variable unbound.
        raise ValueError(
            'no branch for feature {!r} with value {!r}; known values: {!r}'
            .format(firstNode, value, list(secondDict))
        ) from None
    if isinstance(branch, dict):
        return classify(branch, featureLabels, testVector)
    return branch
def storeTree(inputTree, filename):
    """Pickle *inputTree* to *filename* (binary mode, overwriting the file)."""
    import pickle
    # 'with' closes the file even if pickling fails part-way.
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)
def loadTree(filename):
    """Load and return a pickled tree from *filename* (e.g. one saved by
    storeTree).

    Security: unpickling can execute arbitrary code. Only load files you
    created yourself, never untrusted input.
    """
    import pickle
    # 'with' guarantees the file is closed even if unpickling fails.
    with open(filename, 'rb') as fr:
        return pickle.load(fr)
# Driver: build a decision tree from the data files, round-trip it through
# pickle, plot it, and classify one sample.
with open("feature.dat") as fr:
    # One sample per non-blank line, whitespace-separated; values stay strings.
    dataSet = [line.split() for line in fr if line.strip()]
with open("name.dat") as fr:
    # Feature names must be a FLAT list of strings. They become the tree's
    # node keys, which classify() looks up with featureLabels.index(...).
    # A list of per-line lists is unhashable and cannot be a dict key.
    # Assumes names are whitespace-separated with no embedded spaces --
    # TODO confirm against name.dat.
    labels = [name for line in fr for name in line.split()]
# classify() needs the full, ordered list of names. Pass createTree a copy in
# case it removes names from the list it is given.
featureLabels = list(labels)
# NOTE(review): createTree is not defined in this file; it must be defined or
# imported above this point, otherwise this line raises NameError.
decisionTree = createTree(dataSet, list(labels))
storeTree(decisionTree, 'decisionTree')
myTree = loadTree('decisionTree')
# Layout globals read by plotTree.
treeWidth = float(getNumLeafs(myTree))
treeDepth = float(getTreeDepth(myTree))
createPlot(myTree)
# Branch keys come from feature.dat and are therefore strings, so the sample
# must use strings too: the int 1 never equals '1'. Assumes feature.dat
# encodes these features as 0/1.
print(classify(myTree, featureLabels, ['1', '0']))
# 用matplotlib绘制构造好的决策树时报错,我是新手,不知道是什么原因,希望能帮忙修改一下代码,谢谢