k最近傍法


概念

  • k個の最近傍を用いて分類する

    • kを大きくするとノイズを抑えられるが, 小さくても重要なパターンを無視してしまう

    • kを小さくしすぎるとノイズだらけのデータになる

    • 一般的に訓練データ数の平方根で設定するとよい(16種類ならk=4)

    • データ間の距離が重要なので, 特徴量を揃えることが求められる

      • min-max正規化\[X_{new}=\frac{X-min(X)}{max(X)-min(X)}\]

      • Zスコア標準化\[X_{new}=\frac{X-u}{\sigma}\]

      • one-hotエンコーディング

演習

  • 診断結果(diagnosis)が悪性(malignant)か良性(benign)かを予測する
# 前処理

library(here)
## here() starts at /Users/mac-user/Desktop/Statistics_Study/Machine_Learning_withR/MLwithR
library(tidyverse)
## ─ Attaching packages ──────────────────── tidyverse 1.3.0 ─
## ✓ ggplot2 3.3.3     ✓ purrr   0.3.4
## ✓ tibble  3.0.6     ✓ dplyr   1.0.4
## ✓ tidyr   1.1.2     ✓ stringr 1.4.0
## ✓ readr   1.4.0     ✓ forcats 0.5.1
## ─ Conflicts ───────────────────── tidyverse_conflicts() ─
## x dplyr::filter() masks stats::filter()
## x dplyr::lag()    masks stats::lag()
## データの読み込み
wdcd <- read.csv(here("Chapter03", "wisc_bc_data.csv"), stringsAsFactors = F)
## id列の削除
wdcd <- wdcd %>% 
  select(!id) %>% 
  ## 因子に変換
  mutate(diagnosis = if_else(diagnosis == "B", "Benign", "Malignant")) %>% 
  mutate(diagnosis = as.factor(diagnosis))

## 陽性と陰性の割合
table(wdcd$diagnosis)
## 
##    Benign Malignant 
##       357       212
prop.table(table(wdcd$diagnosis))*100
## 
##    Benign Malignant 
##  62.74165  37.25835
## 正規化

### 関数の作成
normalize <- function(x){
  return((x - min(x)) / (max(x) - min(x)))
}

## 0~1に納める
wdcd_n <- sapply(wdcd[2:31], normalize) %>% 
  as.data.frame()
## 訓練データとテストデータに分ける
### 訓練データが469, テストデータが100
wdcd_train <- wdcd_n[1:469, ]
wdcd_test <- wdcd_n[470:569, ]

## 目的変数diagnosisを取り出してベクトルを作成
wdcd_train_labels <- wdcd[1:469, 1]
wdcd_test_labels <- wdcd[470:569, 1]

k最近傍法では指定したkの値の中で多数決が行われるので, 奇数にすると必ずどちらかになる

# knn
library(class)
## 予測結果がベクトルで戻ってくる
wdcd_test_pred <- knn(train = wdcd_train, test = wdcd_test,
                      cl = wdcd_train_labels, k = 21)
## どれくらい一致しているか評価する
library(gmodels)

CrossTable(x = wdcd_test_labels, y = wdcd_test_pred, prop.chisq = F)
## 
##  
##    Cell Contents
## |-------------------------|
## |                       N |
## |           N / Row Total |
## |           N / Col Total |
## |         N / Table Total |
## |-------------------------|
## 
##  
## Total Observations in Table:  100 
## 
##  
##                  | wdcd_test_pred 
## wdcd_test_labels |    Benign | Malignant | Row Total | 
## -----------------|-----------|-----------|-----------|
##           Benign |        61 |         0 |        61 | 
##                  |     1.000 |     0.000 |     0.610 | 
##                  |     0.968 |     0.000 |           | 
##                  |     0.610 |     0.000 |           | 
## -----------------|-----------|-----------|-----------|
##        Malignant |         2 |        37 |        39 | 
##                  |     0.051 |     0.949 |     0.390 | 
##                  |     0.032 |     1.000 |           | 
##                  |     0.020 |     0.370 |           | 
## -----------------|-----------|-----------|-----------|
##     Column Total |        63 |        37 |       100 | 
##                  |     0.630 |     0.370 |           | 
## -----------------|-----------|-----------|-----------|
## 
## 

98%正確に予測できているが, 実際は悪性なのに良性と判断しているのが2%ある. 適合率は100%だが, 再現率は98%である. ここでは取りこぼしをしたくないので再現率を高めるほうが好ましい

## Zスコア
wdcd_z <- scale(wdcd[-1]) %>% 
  as.data.frame()

wdcd_train_z <- wdcd_z[1:469, ]
wdcd_test_z <- wdcd_z[470:569, ]

wdcd_z_train_labels <- wdcd[1:469, 1]
wdcd_z_test_labels <- wdcd[470:569, 1]

wdcd_z_test_pred <- knn(train = wdcd_train_z, test = wdcd_test_z,
                      cl = wdcd_z_train_labels, k = 21)

CrossTable(x = wdcd_z_test_labels, y = wdcd_z_test_pred, prop.chisq = F)
## 
##  
##    Cell Contents
## |-------------------------|
## |                       N |
## |           N / Row Total |
## |           N / Col Total |
## |         N / Table Total |
## |-------------------------|
## 
##  
## Total Observations in Table:  100 
## 
##  
##                    | wdcd_z_test_pred 
## wdcd_z_test_labels |    Benign | Malignant | Row Total | 
## -------------------|-----------|-----------|-----------|
##             Benign |        61 |         0 |        61 | 
##                    |     1.000 |     0.000 |     0.610 | 
##                    |     0.924 |     0.000 |           | 
##                    |     0.610 |     0.000 |           | 
## -------------------|-----------|-----------|-----------|
##          Malignant |         5 |        34 |        39 | 
##                    |     0.128 |     0.872 |     0.390 | 
##                    |     0.076 |     1.000 |           | 
##                    |     0.050 |     0.340 |           | 
## -------------------|-----------|-----------|-----------|
##       Column Total |        66 |        34 |       100 | 
##                    |     0.660 |     0.340 |           | 
## -------------------|-----------|-----------|-----------|
## 
## 
## kの値を変えてみる

### k=1
wdcd_z_test_pred <- knn(train = wdcd_train_z, test = wdcd_test_z,
                      cl = wdcd_z_train_labels, k = 1)

CrossTable(x = wdcd_z_test_labels, y = wdcd_z_test_pred, prop.chisq = F)
## 
##  
##    Cell Contents
## |-------------------------|
## |                       N |
## |           N / Row Total |
## |           N / Col Total |
## |         N / Table Total |
## |-------------------------|
## 
##  
## Total Observations in Table:  100 
## 
##  
##                    | wdcd_z_test_pred 
## wdcd_z_test_labels |    Benign | Malignant | Row Total | 
## -------------------|-----------|-----------|-----------|
##             Benign |        59 |         2 |        61 | 
##                    |     0.967 |     0.033 |     0.610 | 
##                    |     0.952 |     0.053 |           | 
##                    |     0.590 |     0.020 |           | 
## -------------------|-----------|-----------|-----------|
##          Malignant |         3 |        36 |        39 | 
##                    |     0.077 |     0.923 |     0.390 | 
##                    |     0.048 |     0.947 |           | 
##                    |     0.030 |     0.360 |           | 
## -------------------|-----------|-----------|-----------|
##       Column Total |        62 |        38 |       100 | 
##                    |     0.620 |     0.380 |           | 
## -------------------|-----------|-----------|-----------|
## 
##