티스토리 뷰

k-nearest neighbors(k-최근접 이웃, kNN) 알고리즘은 이해하기 쉽고, 자주 사용되는 알고리즘이다.

 

그 이름에서 알 수 있듯, 비교 대상이 되는 데이터 포인트 주변에 가장 가까이 존재하는 k개의 데이터와 비교하여 가장 가까운 데이터 종류로 타깃 데이터를 판별하는 원리이다.

 

즉, 주어진 데이터셋 내에서 가장 가까운 k개의 이웃을 찾아, 그 이웃들의 Label 값으로 현재 데이터의 레이블 값을 결정한다. 

 

예를 들어, 붓꽃의 꽃받침 길이와 폭을 바탕으로 붓꽃의 종을 분류하는 문제를 생각해보면, k-최근접 이웃 알고리즘을 사용하여 붓꽃이 어떤 종에 속하는지 예측할 수 있다. 

 

새로운 붓꽃과 가장 가까운 k개의 붓꽃을 찾아 그 붓꽃들이 어떤 종에 속하는 지 보고, 그 종들 중 가장 많은 종으로 새로운 붓꽃의 종을 분류하는 방식이다.

 

k에 따라 예측 값은 변할 수 있으며, 타깃이 연속형 숫자인 경우에는 kNN으로 잡히는 k개의 데이터의 평균이 예측값이 된다.

 

아래는 사이킷런 라이브러리를 사용하여 꽃 데이터를 로드하여 k-최근접 이웃 알고리즘 문제를 푼 예제이다.

 

from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report

# 데이터 로드
raw_iris = datasets.load_iris()

# 피처 데이터와 타깃 데이터를 나눔
X = raw_iris.data
y = raw_iris.target

# 트레이닝, 테스트 데이터를 나눔
# X_train, X_test, y_train, y_test = train_test_split
X_tn, X_te, y_tn, y_te = train_test_split(X, y, random_state=0)

# 데이터 표준화
std_scale = StandardScaler()
# 데이터 표준화는 X_tn 데이터를 기준으로 진행
std_scale.fit(X_tn)

# fit된 표준화 방법에 트레이닝 피처 데이터인 X_tn데이터를 적용하여 실제로 표준화를 시킴
X_tn_std = std_scale.transform(X_tn)
X_te_std = std_scale.transform(X_te)

# 가장 근접한 데이터는 2개
clf_knn = KNeighborsClassifier(n_neighbors=2)

# 표준화된 트레이닝 피처 데이터 X_tn_std와 트레이닝 타깃 데이터 y_tn을 적용하여 학습을 시킨다.
clf_knn.fit(X_tn_std, y_tn)
print("데이터 표준화")
print(y_tn)

# 데이터 예측, 표준화된 테스트 피처 데이터인 X_te_std를 적용하여 예측한다.
knn_pred = clf_knn.predict(X_te_std)
print("데이터 예측")
print(knn_pred)

# 정확도 평가, 실제 y_te와 피처 데이터인 X_te_std를 기반으로 예측한 타깃 데이터를 비교하여 정확도를 평가할 수 있다.
accuracy = accuracy_score(y_te, knn_pred)
print("정확도 평가")
print(accuracy)

# confusion matrix(일치도) 확인
conf_matrix = confusion_matrix(y_te, knn_pred)
print("일치도 확인")
print(conf_matrix)

# 분류 리포트 확인
class_report = classification_report(y_te, knn_pred)
print("분류 리포트")
print(class_report)

◎ Output

더보기

데이터 표준화
[1 1 2 0 2 0 0 1 2 2 2 2 1 2 1 1 2 2 2 2 1 2 1 0 2 1 1 1 1 2 0 0 2 1 0 0 1
 0 2 1 0 1 2 1 0 2 2 2 2 0 0 2 2 0 2 0 2 2 0 0 2 0 0 0 1 2 2 0 0 0 1 1 0 0
 1 0 2 1 2 1 0 2 0 2 0 0 2 0 2 1 1 1 2 2 1 1 0 1 2 2 0 1 1 1 1 0 0 0 2 1 2 0]
데이터 예측
[2 1 0 2 0 2 0 1 1 1 1 1 1 1 1 0 1 1 0 0 2 1 0 0 2 0 0 1 1 0 2 1 0 2 2 1 0 2]
정확도 평가
0.9473684210526315
일치도 확인
[[13  0  0]
 [ 0 15  1]
 [ 0  1  8]]
분류 리포트
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        13
           1       0.94      0.94      0.94        16
           2       0.89      0.89      0.89         9

    accuracy                           0.95        38
   macro avg       0.94      0.94      0.94        38
weighted avg       0.95      0.95      0.95        38

Comments