u012488416 于 2016.04.05 21:29 提问

from math import log
import operator
def CreateDataset():
dataset=[[0,1,1,'high'],
[0,1,1,'high'],
[0,1,1,'high'],
[0,0,1,'high'],
[0,1,1,'high'],
[0,0,1,'high'],
[0,1,0,'high'],
[1,1,1,'high'],
[1,1,0,'high'],
[1,1,1,'high'],
[1,1,1,'high'],
[1,1,1,'high'],
[1,1,1,'high'],
[0,1,1,'high'],
[1,0,1,'high'],
[1,0,1,'high'],
[1,0,1,'high'],
[1,0,1,'high'],
[1,0,0,'high'],
[0,0,0,'high'],
[0,0,1,'low'],
[0,0,1,'low'],
[0,0,1,'low'],
[0,0,0,'low'],
[0,1,0,'low'],
[1,0,1,'low'],
[1,0,1,'low'],
[0,0,0,'low'],
[0,0,0,'low'],
[1,0,0,'low'],
[0,1,0,'low'],
[1,0,1,'low'],
[1,0,0,'low'],
[1,0,0,'low']]
labels=['weather','weekend','sales','volumes']
return dataset,labels
def calcShannonEnt(dataset):
numEntries=len(dataset)
labelCounts={}
for featVec in dataset:
currentLabel=featVec[-1]
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel]=0
labelCounts[currentLabel]+=1
shannonEnt=0.0
for key in labelCounts:
prob=float(labelCounts[key])/numEntries
shannonEnt-=prob *log(prob,2)
return shannonEnt

def splitDataSet(dataset,axis,value):
retDataSet=[]
for featVec in dataset:
if featVec[axis]==value:
reduceFeatVec=featVec[:axis]
reduceFeatVec.extend(featVec[axis+1:])
retDataSet.append(reduceFeatVec)
return retDataSet

def chooseBestFeatureToSplit(dataset):
numFeatures=len(dataset[0])-1
baseEntropy=calcShannonEnt(dataset)
bestInfoGain=0.0
bestFeature=-1
for i in range(numFeatures):
featList=[example[i] for example in dataset ]
uniqueVals=set(featList)
newEntropy=0.0
for value in uniqueVals:
subDataSet=splitDataSet(dataset,i,value)
prob=len(subDataSet)/float(len(dataset))
newEntropy +=prob * calcShannonEnt(subDataSet)
infoGain=baseEntropy-newEntropy
if(infoGain>bestInfoGain):
bestInfoGain=infoGain
bestFeature=i
return bestFeature

def majorityCnt(classList):
classCount={}
for vote in classList:
if vote not in classCount.keys():
classCount[vote]=0
classCount[vote]+=1
sortedClassCount=sorted(classCount.iteritems(),key=operator.itemgetter(1),reverse=True)
return sortedClassCount[0][0]

def createTree(dataset,labels):
classList=[example[-1] for example in dataset]
if classList.count(classList[0])==len(classList):
return classList
if len(dataset[0])==1:
return majorityCnt(dataset)
bestFeat=chooseBestFeatureToSplit(dataset)
bestFeatLabel=labels[bestFeat]
myTree={bestFeatLabel:{}}
del(labels[bestFeat])
featValues=[example[bestFeat] for example in dataset]
uniqueVals=set(featValues)
for value in uniqueVals:
subLabels=labels[:]
myTree[bestFeatLabel][value]=createTree(splitDataSet(dataset,bestFeat,value),subLabels)
return myTree

myDat,labels = CreateDataset()
print(calcShannonEnt(myDat))

print(splitDataSet(myDat, 1, 1))

print(chooseBestFeatureToSplit(myDat))

print(createTree(myDat, labels))

2个回答

caozhy      2016.04.06 03:49
u012488416 这个算法我是知道的，这里的疑惑是不知道为什么执行会报错，希望大神不吝赐教

CSDNXIAOD   2016.04.05 21:32

id3决策树Python版
----------------------biu~biu~biu~~~在下问答机器人小D，这是我依靠自己的聪明才智给出的答案，如果不正确，你来咬我啊！