[Python] 의사결정트리

O`Reily 의 Programming Collective Intelligence책에서 참조

 

생각보다 이 책의 소스는 매우 강력하다. 바로 활용이 가능할 정도로.

의사결정트리를 만드는 과정은 여러 행과 열로 이루어진 데이터를 통해 분류하는 작업이라고 생각하면 쉽다.

핵심은 다음과 같음.

 

재귀함수를 이용해서 모든 조건을 쪼개서 데이터를 분리한다.

예) Column A가 ‘Male’ 가 참일때, 거짓일 때 > 참일때 분리되는 데이터를 다시 재귀함수로,

Column B가 숫자나 소숫점 형식이면 5.0 이상일 때, 5.2 이상일 때 등등으로 나누어 다시 재귀함수로.

여기에서 재귀함수로 넘기기 전에 엔트로피, 또는 지니 불순도를 이용하여, gain을 구한다.

gain은 현재 분류된 데이터(여러 결과가 혼재되어있는)의  score에서 어떠한 기준점으로 나뉘어진 두 set의 score의 점수를 제외함으로서

p=float(len(set1))/len(rows)
gain=current_score-p*scoref(set1)-(1-p)*scoref(set2)
if gain>best_gain and len(set1)>0 and len(set2)>0:

위의 같은 소스를 이용하더라. 여기서 scoref가 지니불순도 또는 엔트로피 함수.

이런식으로 가지를 계속 만들어가면서 트리를 구하게 된다.

 

말로 이해한 것을 적으려니 쉽지가 않다.

아무튼 이 부분의 이해를 위해 다음과 같은 예제를 만들었고,

XYGroup
0.61.91
1.21.51
2.70.11
2.51.41
2.31.11
0.21.21
0.11.11
0.41.21
0.81.21
2.60.11
0.31.21
2.50.91
2.321
0.50.21
1.22.31
2.92.21
0.631
2.10.71
2.52.91
0.72.51
1.31.81
1.91.21
1.90.41
1.90.21
1.71.31
0.32.91
1.511
1.92.11
1.81.71
1.621
2.20.82
5.82.32
2.60.22
51.22
3.31.22
2.42.42
4.62.52
4.90.42
4.70.22
4.61.32
3.71.12
2.90.82
4.21.32
31.32
5.60.62
5.20.82
3.41.62
3.512
2.52.42
3.72.42
32.12
2.32.72
4.22.32
5.41.72
5.61.92
4.21.72
5.522
2.30.22
4.61.32
4.32.12
0.24.13
0.63.63
0.95.53
0.24.83
0.653
0.623
0.24.63
0.53.23
0.36.33
0.71.53
0.15.13
0.26.83
0.23.13
0.35.83
0.66.83
0.43.43
0.45.33
0.83.33
13.43
03.13
0.81.43
0.92.93
11.63
0.75.53
0.74.63
0.61.43
0.51.93
0.34.23
0.95.93
0.633
4.95.84
4.66.44
4.274
4.17.84
4.17.94
4.55.14
3.66.84
3.45.84
57.44
4.66.34
4.284
4.67.44
3.45.64
3.27.44
4.17.44
3.57.24
47.94
3.85.64
4.45.54
4.184
3.27.94
45.24
4.684
4.564
4.37.74
4.85.74
3.16.64
3.45.74
3.35.24
3.17.54

X,Y좌표로만 이루어져 있다.

이를 평면에 표시하면

figure_1

위의 데이터는 의사결정트리에서 아래 그림처럼 분리된다.

treeview2

소스에 준비되어있는 함수를 이용해서 표시한 결과, 그래프에서 보면 알수 있듯이,

0번 컬럼(X좌표)가 3.0이상일 때 첫 분류를 시작했으며,

참인 경우, 1번 컬럼(Y좌표)가 5.1 이상일 때 2번 그룹과 4번 그룹으로 우선적으로 분류가 이루어지게 된다.

 

내가 이해한 것이 맞는지 모르겠지만, 이런 트리를 만들었을 때의 장점은 정보가 missing된 상태에서 그 missing된 값이 어떤 것일지 예측이 가능하다는 점이다. 내 예제는 x,y뿐이었지만 만약 x,y,z,q,p 등등 여러 변수로 구성된 트리라면, z,q값이 없을 때 x,y,p로만으로 각각의 결과가 어떤 확률로 분포될 수 있는지를 알 수 있다.

단점은 아마 내 예제처럼 그 분류 수준이 너무 낮다면, 두 그룹이 혼재되어있는 좌표, 예를 들자면 2.5, 1.3 이러한 좌표를 tree에 넣었을 때, 하나의 가지로 결국 이어져 그룹이 1번이라고 나오게 되어버린다. 두 그룹 중 어디에 포함될 것인지에 대한 확률로 구할 수가 없다는 것.

어느 그룹이 될지 확률적으로 구하기 위해서는 애초에 신경망 같은 다른 알고리즘을 사용하던지, 아니면 의사결정tree에서 적절한 pruning(가지치기) = 즉 일정 threshold까지는 엔트로피를 늘림으로서(=어느 정도의 엔트로피면 그 데이터 셋을 나누지 말고 놔둠) 분류를 단순화 하는 방법이다.

treeview3figure_1

이 경우에도 문제는 존재한다. 위의 그림은 가지치기의 한 예시다. 기준을 어떻게 잡느냐에 따라 더 단순해질수도, 조금 복잡해질수도 있다.

바로 첫번째 분기에서 왼쪽(false), 왼쪽(false)으로 가는 경우. 즉 위의 좌표 그림에서 x가 처음 3.0이하, 그리고 그 다음에 1.2이하로 간 경우인데, 분명 이 경우 Y좌표를 3정도로 기준을 두면 분명 더 세부적인 분류가 가능할 것으로 생각할 수 있다. 그러나 이 경우 그룹 1이 10개, 그룹 3이 30개라는 뭉뚱그려진 결과가 나와 더 불확실한 결과를 초래할 수 있게 되어버린다.

 

Leave a Reply

Your email address will not be published.