Machine Learning in Action ch.9.3 & 9.4 Regression Tree
9.3 회귀를 위해 CART 알고리즘 사용하기
9.3.1 트리 구축하기
회귀 트리를 구축하기 위해서 createTree 함수를 이용한다.
그리고 chooseBestSplit 함수를 이용해 분할하기에 가장 좋은 것으로 데이터 집합을 분할하고, 데이터 집합을 위해 단말 노드를 생성하는 두가지 작업을 한다.
def regLeaf(dataSet):#returns the value used for each leaf
return mean(dataSet[:,-1])
def regErr(dataSet):
return var(dataSet[:,-1]) * shape(dataSet)[0]
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
tolS = ops[0]; tolN = ops[1]
#if all the target variables are the same value: quit and return value
if len(set(dataSet[:,-1].T.tolist()[0])) == 1: #exit cond 1
return None, leafType(dataSet)
m,n = shape(dataSet)
#the choice of the best feature is driven by Reduction in RSS error from mean
S = errType(dataSet)
bestS = inf; bestIndex = 0; bestValue = 0
for featIndex in range(n-1):
for splitVal in set(dataSet[:,featIndex]):
mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue
newS = errType(mat0) + errType(mat1)
if newS < bestS:
bestIndex = featIndex
bestValue = splitVal
bestS = newS
#if the decrease (S-bestS) is less than a threshold don't do the split
if (S - bestS) < tolS:
return None, leafType(dataSet) #exit cond 2
mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): #exit cond 3
return None, leafType(dataSet)
return bestIndex,bestValue#returns the best feature to split on
#and the value used for that split
첫번째 함수는 평균을 구하는 함수이다.
두번째 함수는 오류를 평가하는 함수이다.
세번째 함수는 분류 트리를 구축하는 함수이다. 즉, 데이터를 이진 분할하는 가장 좋은 방법을 찾는 것이다.
9.4 트리 가지치기
많은 노드를 가지고 있는 트리의 경우 , 과적합인 모델인 경우가 많다.
이번에는 데이터 과적합을 방지하는 방법과 그와 관련된 주변 지식에 관해서 공부한다.
과적합을 피하기 위한 의사결정 트리의 복잡성을 줄이는 방법으로 가지치기라는 것이 있다.
가지치기는 사전 가지치기와 사후 가지치기가 있다.
9.4.1 사전 가지치기
분할 전 평균제곱 오류 - 분할 후 평균제곱 오류 >= 최소값
9.4.2 사후 가지치기
여기서 사용하게 될 방법은 데이터 집합을 훈련과 검사 데이터 집합으로 나뉜 뒤.
트리를 구축한다.
트리는 단말 노드가 하나가 될 때까지 가지를 뻗어 내려간다.
그런 다음, 검사 집합에 있는 데이터를 가지고 단말노드를 검사한다.
그리고 검사 집합상에서 단말 노드를 병합하여 오류가 더 줄어드는 지를 측정한다.
def isTree(obj):
return (type(obj).__name__=='dict')
def getMean(tree):
if isTree(tree['right']): tree['right'] = getMean(tree['right'])
if isTree(tree['left']): tree['left'] = getMean(tree['left'])
return (tree['left']+tree['right'])/2.0
def prune(tree, testData):
if shape(testData)[0] == 0: return getMean(tree) #if we have no test data collapse the tree
if (isTree(tree['right']) or isTree(tree['left'])):#if the branches are not trees try to prune them
lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet)
if isTree(tree['right']): tree['right'] = prune(tree['right'], rSet)
#if they are now both leafs, see if we can merge them
if not isTree(tree['left']) and not isTree(tree['right']):
lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
errorNoMerge = sum(power(lSet[:,-1] - tree['left'],2)) +\
sum(power(rSet[:,-1] - tree['right'],2))
treeMean = (tree['left']+tree['right'])/2.0
errorMerge = sum(power(testData[:,-1] - treeMean,2))
if errorMerge < errorNoMerge:
print ("merging")
return treeMean
else: return tree
else: return tree
댓글
댓글 쓰기