Coding Diary.

(Machine Learning) 결정 트리, if-else문으로 구현하기 본문

Coding/Machine Learning

(Machine Learning) 결정 트리, if-else문으로 구현하기

life-of-nomad 2023. 10. 4. 09:03
728x90
반응형

1. 결정 트리

결정 트리는 예/아니오로 답할 수 있는 어떤 질문들이 있고, 그 질문들의 답을 따라가면서 데이터를 분류하는 알고리즘입니다.

교통사고가 났을 때, 운전자의 생존 여부를 예측하고 싶다고 합시다. 결정 트리는 질문들과 답으로 이루어졌습니다. 여기서 가장 위에 '안전벨트를 했나요?' 이 질문에서 안전벨트를 했으면 여기 왼쪽으로 내려와서 생존, 안했으면 오늘쪽으로 내려와서 사망, 이런 식으로 분류하는 것입니다. 질문에 해당하는 내용이 초록색, 그리고 분류에 해당하는 내용을 보라색이라고 합시다. 

 

지금은 안전벨트의 여부를 물었는데 데이터가 주행 속도, 즉 특정 숫자 값이라면 주행 속오가 시속 100km를 넘었나요? 와 같은 질문을 할 수 있습니다. 이때도 똑같이 예측하려는 데이터에 대한 질문의 답에 따라 왼쪽 또는 오른쪽으로 가서 데이터를 분류할 수 있습니다. 안전벨트를 했는지, 고속도로인지, 시속이 100km가 넘었는지, 사고자 나이가 50을 넘었는지 등을 이용해서 결정 트리를 만들 수 있습니다.

질문들에 답을 해가면서 한 단계씩 내려갈 수 있습니다. 보라색 박스들은 분류를 갖고 있습니다. 위에서부터 질문들에 계속 답을 하다가 내려가면서 이 보라색 박스들에 도착을 하면, 해당 분류 값을 리턴하면 됩니다.

 

그리고 중요한 것은 한 속성을 딱 한 번만 사용해야 되는 건 아닙니다.

 

예를 들어서 한 번은 주행 속도가 100을 넘엇었는지 질문할 수 있고 밑에 내려가서는 60을 넘었는지 이렇게 하나의 속성으로 여러 개의 질문을 만들 수도 있습니다.

 

질문들이 있고, 그 질문에 대한 답들을 따라가면서 데이터를 분류하는 알고리즘이라 직관적입니다.

 

컴퓨터 과학에서는 이렇게 한 지점에서 시작해서 점점 넓게 퍼져 나가는 걸 트리라고 합니다. 나무는 뿌리에서 시작해서 여러 개의 가지로 뻗어나갑니다. 

 

  • 하나의 시작 지점에서 퍼져나가는 모습이 마치 나무와 비슷
  • 한 단계 내려갈 때마다 왼쪽으로 갈지 오른쪽으로 갈지 선택

이러한 알고리즘이기 때문에 이름이 결정 트리인 것입니다.

 

이 하나하나에 있는 박스를 '노드' 라고 합니다. 가장 위에 있는 질문 노드를 나무의 뿌리라고 해서 'root 노드', 트리의 가장 끝에 있는 노드를 나무의 입, 'leaf 노드' 라고 합니다. leaf 노드는 항상 사망/생존과 같은 특정 예측값을 갖고 있고, 나머지 노드들은 예/아니오로 답할 수 있는 질문을 가지고 있습니다.

 

2. if-else 문으로 결정트리 구현하기

위 결정 트리를 if-else 문으로 구현해보겠습니다.

 

survival_classifier 함수

  • seat_belt : 안전 벨트를 했는지를 나타내는 불린형 파라미터
  • highway : 고속도로였는지를 나타내는 불린형 파라미터
  • speed : 사고 당시 주행속도를 나타내는 숫자형 파라미터
  • age : 사고자 나이를 나타내는 숫자형 파라미터
  • 생존을 예측할시 0을 리턴하고, 사망을 예측할 시 1을 리턴합니다.

code

def survival_classifier(seat_belt, highway, speed, age) : 
  # 질문 노드: 안전벨트를 했나요?
  if seat_belt: 
    return 0 # 했으면 생존 리턴
  else:
    # 질문 노드: 사고가 고속도로였나요?
    if highway:
      # 질문 노드 : 시속 100km를 넘었나요?
      if speed > 100:
        # 질문 노드: 사고자 나이가 50을 넘었나요?
        if age > 50:
          return 1 # 사고자 나이가 50을 넘었으면 사망 리턴
 
        else:
          return 0 # 사고자 나이가 50을 넘지 않았다면 생존 리턴
      else:
        return 0 # 시속 100km를 넘지 않았다면 생존 리턴
   else : 
      return 0 # 고속도로가 아니였다면 생존 리턴

#테스트코드
print(survival_classifier(False, True, 110, 55))
1

 

728x90
반응형