3 분 소요

결정 트리


1. 용어 정리

결정트리 : 예 / 아니오에 대한 질문을 이어나가면서 정답을 찾아 학습하는 알고리즘입니다.

불순도 : 결정 트리가 최적의 질문을 찾기 위한 기준입니다. 사이킷런은 지니 불순도와 엔트로피 불순도를 제공합니다.

정보 이득 : 부모 노드와 자식 노드의 불순도 차이입니다. 결정 트리 알고리즘은 정보 이득이 최대화되도록 학습합니다.

특성 중요도 : 결정 트리에 사용된 특성이 불순도를 감소하는데 기여한 정도를 나타내는 값입니다.


2. 로지스틱 회귀로 와인 분류하기

이번 미션은 알코올 도수, 당도, pH 값에 로지스틱 회귀 모델을 적용해 화이트 와인인지, 레드 와인인지 판별해야 하는게 이번 챕터의 목표이다.
wine 데이터셋을 이용해 데이터프레임으로 제대로 읽었는지 처음 5개의 샘플을 확인해보겠습니다.

import pandas as pd

wine = pd.read_csv('https://bit.ly/wine_csv_data')

wine.head()
alcohol sugar pH class
0 9.4 1.9 3.51 0.0
1 9.8 2.6 3.20 0.0
2 9.8 2.3 3.26 0.0
3 9.8 1.9 3.16 0.0
4 9.4 1.9 3.51 0.0

info() 메서드를 통해 데이터가 누락이 되있는지 확인을 해 봅니다.

wine.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 6497 entries, 0 to 6496
Data columns (total 4 columns):
 #   Column   Non-Null Count  Dtype
---  ------   --------------  -----
 0   alcohol  6497 non-null   float64
 1   sugar    6497 non-null   float64
 2   pH       6497 non-null   float64
 3   class    6497 non-null   float64
dtypes: float64(4)
memory usage: 203.2 KB

이번에는 describe() 메서드를 통해 평균, 표준편차, 최소, 최대, 사분위수 등의 값을 확인해봅니다.

wine.describe()
alcohol sugar pH class
count 6497.000000 6497.000000 6497.000000 6497.000000
mean 10.491801 5.443235 3.218501 0.753886
std 1.192712 4.757804 0.160787 0.430779
min 8.000000 0.600000 2.720000 0.000000
25% 9.500000 1.800000 3.110000 1.000000
50% 10.300000 3.000000 3.210000 1.000000
75% 11.300000 8.100000 3.320000 1.000000
max 14.900000 65.800000 4.010000 1.000000

판다스 데이터 프레임을 넘파이 배열로 바꾼 후 이제 와인 데이터를 타깃과 데이터로 분류 후, 테스트 세트와 훈련 세트로 또 분류합니다. 각 개수를 6번 코드를 통해 확인 해보니 훈련 세트 5197개, 테스트 세트 1300개로 나뉘어집니다.

data = wine[['alcohol', 'sugar', 'pH']].to_numpy()
target = wine['class'].to_numpy()

from sklearn.model_selection import train_test_split

train_input, test_input, train_target, test_target = train_test_split(data, target, test_size=0.2, random_state=42)

print(train_input.shape, test_input.shape)
(5197, 3) (1300, 3)
from sklearn.linear_model import LogisticRegression

lr = LogisticRegression()
lr.fit(train_scaled, train_target)

print(lr.score(train_scaled, train_target))
print(lr.score(test_scaled, test_target))
0.7808350971714451
0.7776923076923077

이러한 계수와 절편은 일반인에게 설명하기는 어렵습니다. 우리도 이 모델이 이러한 값을 학습했는지 정확히 이해하기 어렵다. 이 숫자가 어떤 의미인지 설명하기 어렵기 때문에, 쉽게 설명하기 방법으로 결정 트리(Decision Tree) 모델을 사용합니다.

결정 트리 모델은 스무고개와 같이 질문을 하나씩 던져서 정답과 맞춰가는 겁니다.

print(lr.coef_, lr.intercept_)
[[ 0.51270274  1.6733911  -0.68767781]] [1.81777902]

3. 결정트리

사이킷런의 DecisionTreeClassfier 클래스를 사용해 결정 트리 모델을 훈련해 봅니다.
훈련 세트에 대한 점수는 매우 높지만, 그에 비해 테스트 세트의 성능은 낮은 것을 볼 수 있습니다. 즉 과대적합 모델이라고 볼 수 있습니다.

from sklearn.tree import DecisionTreeClassifier

dt = DecisionTreeClassifier(random_state=42)
dt.fit(train_scaled, train_target)

print(dt.score(train_scaled, train_target))
print(dt.score(test_scaled, test_target))
0.996921300750433
0.8592307692307692

결정 트리는 위에서 아래로 거꾸로 자랍니다. 맨 위의 노드를 루트 노드라고 하고 맨 아래 끝에 달린 노드를 리프 노드라고 한다. filled 매개변수에서 클래스에 맞게 노드의 색을 칠할 수 있습니다.feature_names 매개변수는 특성의 이름을 전달합니다.

plt.figure(figsize=(10,7))
plot_tree(dt, max_depth=1, filled=True, feature_names=['alcohol', 'sugar', 'pH'])
plt.show

png


png


png

하나씩 살펴보면 처음 루트 노드는 당도가 -0.239이하인지를 질문합니다. 어떤 샘플의 당도가 -0.239와 같거나 작으면 왼쪽으로 가고 그렇지 않으면 오른쪽 가지로 이동합니다. 총 샘플 5197개에서 왼쪽으로 2922개, 오른쪽으로 2275개가 이동한 것입니다. value는 [음성클래스 개수, 양성클래스 개수]입니다.

3 - 1. 불순도

png

빨간색 동그라미 친 gini는 지니 불순도를 의미합니다. 지니 불순도는 클래스의 비율을 제곱해서 더한 다음 1에서 빼면 됩니다. \(지니 불순도 = 1 - (음성 클래스 비율^2 + 양성 클래스 비율^2)\)


결정 트리 모델은 부모 노드와 자식 노드의 불순도 차이가 가능 크도록 트리를 성장시킵니다. 부모와 자식 노드 사이의 불순도 차이를 정보 이득(information gain)이라 합니다.


결정 트리도 ‘가지 치기’라는 작업이 필요합니다. 무작정 끝까지 자라면 훈련 세트에는 아주 잘 맞지만 테스트 세트에는 그에 못미치게 됩니다. DecisionTreeClassifier 클래스의 max_depth 매개변수를 3으로 지정해 모델을 만들어 봅니다.


훈련세트의 성능은 낮아졌지만 테스트 세트의 성능은 거의 그대로입니다.

dt = DecisionTreeClassifier(max_depth=3, random_state=42)
dt.fit(train_scaled, train_target)

print(dt.score(train_scaled, train_target))
print(dt.score(test_scaled, test_target))
0.8454877814123533
0.8415384615384616

plot_tree()함수로 그려봅니다.

plt.figure(figsize=(20,15))
plot_tree(dt, filled=True, feature_names=['alcohol', 'sugar', 'pH'])
plt.show()

png

두 번째 특성인 당도가 0.87로 제일 중요하고 그다음 알코올 도수, pH순으로 중요합니다.

print(dt.feature_importances_)
[0.12345626 0.86862934 0.0079144 ]

공부한 전체 코드는 깃허브에 올렸습니다. https://github.com/mgskko/Data_science_Study-hongongmachine/blob/main/%ED%98%BC%EA%B3%B5%EB%A8%B8%EC%8B%A0_5%EA%B0%95_1.ipynb

댓글남기기