1 The origin of this note

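In this era of big packages, being able to code things by hand may look like mere romanticism. But hand-coding is the most unquestionable way to show whether you have truly understood an algorithm: you have to understand the details and know how to implement it!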

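Besides, there are actually few from-scratch R examples online. Those that exist are not necessarily in Chinese, and very few notes explain the programmer's line of thought step by step.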

R code on GitHub

2 Introduction to the kNN algorithm

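1. First, calculate the distance between every row of Xtrain and every row of Xtest.

2. Choose k, the hyperparameter of the kNN algorithm (the number of neighbours that vote).

3. Take a majority vote among the labels of the k nearest neighbours.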
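The three steps can be sketched for a single query point in a few lines of R. This is a minimal illustration, not the implementation developed below; the function name `knn_one` and the toy data are hypothetical.

```r
# Hypothetical helper: classify one query point x with kNN.
# train: numeric matrix (one row per training sample)
# train_labels: class of each training row; k: number of neighbours that vote
knn_one <- function(train, train_labels, x, k) {
  # Step 1: Euclidean distance from x to every training row
  d <- sqrt(rowSums(sweep(train, 2, x)^2))
  # Step 2: indices of the k smallest distances
  nearest <- order(d)[1:k]
  # Step 3: majority vote among their labels (ties go to the first class in table order)
  votes <- table(train_labels[nearest])
  names(which.max(votes))
}

# Made-up toy data: two small clusters
train <- matrix(c(1, 1,
                  1, 2,
                  5, 5,
                  6, 5), ncol = 2, byrow = TRUE)
train_labels <- c("a", "a", "b", "b")
knn_one(train, train_labels, c(1.5, 1.5), k = 3)  # "a"
```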

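library(magrittr)    # provides the %>% pipe used throughout this note
library(data.table)  # only used to print iris compactly below
data(iris)
class(iris$Species) # check that Species is a factor
## [1] "factor"
set.seed(1)
data <- iris[sample(nrow(iris)), ] # shuffle the rows
Xtrain <- data[1:100, 1:4]
label <- data[1:100, 5]
Xtest <- data[101:150, 1:4]
testlabel <- data[101:150, 5]
data.table(iris)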
##      Sepal.Length Sepal.Width Petal.Length Petal.Width   Species
##   1:          5.1         3.5          1.4         0.2    setosa
##   2:          4.9         3.0          1.4         0.2    setosa
##   3:          4.7         3.2          1.3         0.2    setosa
##   4:          4.6         3.1          1.5         0.2    setosa
##   5:          5.0         3.6          1.4         0.2    setosa
##  ---                                                            
## 146:          6.7         3.0          5.2         2.3 virginica
## 147:          6.3         2.5          5.0         1.9 virginica
## 148:          6.5         3.0          5.2         2.0 virginica
## 149:          6.2         3.4          5.4         2.3 virginica
## 150:          5.9         3.0          5.1         1.8 virginica

3 Implementing kNN without using a package

3.1 Calculate the distance matrix

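My approach is this: to compute the distance between every training sample (M of them) and every test sample (N of them), we need a distance matrix with dim(distance matrix) = M × N.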
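In symbols (the notation x_i for the i-th row of Xtrain and z_j for the j-th row of Xtest is introduced here for illustration), each entry of the matrix is the Euclidean distance over the P = 4 iris features. For example, D_11 = sqrt(8.37) ≈ 2.893, which is the hand computation carried out next.

```latex
$$
D \in \mathbb{R}^{M \times N}, \qquad
D_{ij} = \lVert x_i - z_j \rVert_2 = \sqrt{\sum_{p=1}^{P} \left(x_{ip} - z_{jp}\right)^2},
\qquad i = 1, \dots, M, \; j = 1, \dots, N
$$
```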

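That is a bit hard to think about all at once, so let's first ask: how would we compute the distance between just the first training sample and the first test sample?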

Xtrain[1,]
##    Sepal.Length Sepal.Width Petal.Length Petal.Width
## 68          5.8         2.7          4.1           1
Xtest[1,]
##    Sepal.Length Sepal.Width Petal.Length Petal.Width
## 35          4.9         3.1          1.5         0.2
Xtrain[1,]-Xtest[1,]
##    Sepal.Length Sepal.Width Petal.Length Petal.Width
## 68          0.9        -0.4          2.6         0.8
(Xtrain[1,]-Xtest[1,])^2 
##    Sepal.Length Sepal.Width Petal.Length Petal.Width
## 68         0.81        0.16         6.76        0.64
sum((Xtrain[1,]-Xtest[1,])^2)
## [1] 8.37
sum((Xtrain[1,]-Xtest[1,])^2) %>% sqrt()
## [1] 2.893095

Simple, right?

What about the distance between the 1st row of Xtrain and the 2nd row of Xtest?

Xtrain[1,]-Xtest[2,]  # signed differences for this pair; their squares are printed below
(Xtrain[1,]-Xtest[2,])^2 
##    Sepal.Length Sepal.Width Petal.Length Petal.Width
## 68         0.16        0.04         0.04        0.09
sum((Xtrain[1,]-Xtest[2,])^2)
## [1] 0.33
sum((Xtrain[1,]-Xtest[2,])^2) %>% sqrt()
## [1] 0.5744563

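So a double loop (over training rows i and test rows j) is all we need to compute every pairwise distance.

Below we set k = 3, i.e. we look for the 3 most similar neighbours, and then check what those 3 neighbours' labels are (so that we can take a majority vote)!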

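k <- 3
M <- nrow(Xtrain)  # number of training samples
N <- nrow(Xtest)   # number of test samples

distmatrix <- matrix(0, nrow = M, ncol = N)

# fill distmatrix[i, j] with the distance between training row i and test row j
for (i in 1:M) {
  for (j in 1:N) {
    distmatrix[i, j] <- sum((Xtrain[i, ] - Xtest[j, ])^2) %>% sqrt()
  }
}
distmatrix[1, 1] # should equal the 2.893095 computed by hand above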
## [1] 2.893095
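The double loop evaluates the distance formula M × N = 5,000 times on data-frame rows. A common alternative, sketched here as an assumption rather than the method used in this note, computes the whole matrix at once from the expansion ||x - z||^2 = ||x||^2 + ||z||^2 - 2 x·z. The helper name `pairwise_dist` is hypothetical.

```r
# Hypothetical helper: Euclidean distances between every row of A and every row of B.
# Returns an nrow(A) x nrow(B) matrix, the same layout as distmatrix above.
pairwise_dist <- function(A, B) {
  A <- as.matrix(A)
  B <- as.matrix(B)
  # squared distance ||a_i||^2 + ||b_j||^2 - 2 a_i . b_j for every pair (i, j)
  sq <- outer(rowSums(A^2), rowSums(B^2), "+") - 2 * A %*% t(B)
  # floating-point rounding can leave tiny negatives; clamp them to 0 before sqrt
  sqrt(pmax(sq, 0))
}

A <- matrix(c(0, 0,
              3, 4), ncol = 2, byrow = TRUE)
B <- matrix(c(0, 0), ncol = 2)
pairwise_dist(A, B)  # a 2 x 1 matrix holding 0 and 5
```

With this note's objects, `all.equal(distmatrix, pairwise_dist(Xtrain, Xtest), check.attributes = FALSE)` should report agreement up to floating-point error (`check.attributes = FALSE` ignores the row names that `as.matrix` carries over from the data frames).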

A quick refresher on sort() and order(); we will need order() shortly!

x=c(2.5,7,0,2.2,5.7)
sort(x)
## [1] 0.0 2.2 2.5 5.7 7.0
order(x)
## [1] 3 4 1 5 2

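order() is easy to understand too: sort() returns the values from smallest to largest, so the first element of order() tells you the position of the smallest value in the original x, the second element the position of the second smallest, and so on!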

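0 is the 3rd element of the original vector.

2.2 is the 4th element of the original vector.

2.5 is the 1st element of the original vector.

Super simple, right?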
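Two consequences of this definition are what kNN relies on: indexing `x` by `order(x)` reproduces `sort(x)`, and the first k entries of `order(x)` are the positions of the k smallest values.

```r
x <- c(2.5, 7, 0, 2.2, 5.7)
# reordering x by order(x) gives the sorted vector
x[order(x)]    # 0.0 2.2 2.5 5.7 7.0
# positions of the 3 smallest values: the "3 nearest neighbours" if x were distances
order(x)[1:3]  # 3 4 1
```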

3.2 Return the order of each column of the distance matrix

sortedDistIndexes <- apply(distmatrix,2,order)
dim( sortedDistIndexes   )
## [1] 100  50

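sortedDistIndexes is a 100 × 50 matrix: column j holds the training-row indices sorted by their distance to test sample j, nearest first.

This matrix is a bit hard to describe, but the explanation below should make it clear.

For the 1st test sample, the first k entries of column 1 tell us which rows of the training data are its 3 nearest training samples (make sure you understand this sentence!), so we can then look up their labels.

Make sure you understand the paragraph above; apart from this part, everything else is straightforward.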
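A tiny example may make this concrete. Below, `d` is a made-up 4 × 2 distance matrix (4 training rows, 2 test rows) and `lab` holds made-up training labels. `apply(d, 2, order)` sorts each column independently, so column j lists training rows from nearest to farthest for test sample j.

```r
# Made-up distances: rows = 4 training samples, columns = 2 test samples
d <- matrix(c(0.5, 0.1, 0.9, 0.3,   # distances to test sample 1
              0.2, 0.8, 0.4, 0.6),  # distances to test sample 2
            nrow = 4)
lab <- c("a", "b", "a", "b")        # made-up training labels

idx <- apply(d, 2, order)  # 4 x 2: each column is order() of that column of d
idx[, 1]                   # 2 4 1 3: training row 2 is nearest to test sample 1
lab[idx[1:2, 1]]           # labels of its 2 nearest neighbours: "b" "b"
lab[idx[1:2, 2]]           # for test sample 2: "a" "a"
```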

sortedDistIndexes[1:k,1]
## [1] 79 44 71

Remember to go back and look up the answers in the label vector, my friend!

label[sortedDistIndexes[1:k,1]]
## [1] setosa setosa setosa
## Levels: setosa versicolor virginica
label[sortedDistIndexes[1:k,2]]
## [1] versicolor versicolor versicolor
## Levels: setosa versicolor virginica
label[sortedDistIndexes[1:k,3]]
## [1] virginica  versicolor virginica 
## Levels: setosa versicolor virginica
# print(label) to verify these by hand

Now do what we just did for every test sample:


ans <- matrix(label[sortedDistIndexes[1:k, 1:nrow(Xtest)]], ncol = k, byrow = TRUE)  # one row per test sample

dim(ans)
## [1] 50  3
ans[1:3,]
##      [,1]         [,2]         [,3]        
## [1,] "setosa"     "setosa"     "setosa"    
## [2,] "versicolor" "versicolor" "versicolor"
## [3,] "virginica"  "versicolor" "virginica"

ans is a 50 × 3 matrix because we have 50 test samples and k is 3 here: row i holds the labels of the 3 nearest neighbours of test sample i.

We will use the majority rule to find the answer.

For example, the 1st test sample should be “setosa” and the 3rd test sample should be “virginica”.
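Why `byrow = TRUE`? `sortedDistIndexes[1:k, ]` is a k × N matrix and R reads it column by column, so the label vector lists the k neighbours of test sample 1, then those of test sample 2, and so on. Filling an N × k matrix by row therefore puts each test sample's neighbours in one row. The sketch below checks this on made-up indices and labels (`idx`, `lab` and `k_demo` are hypothetical names) and shows that transposing a k × N matrix gives the same result.

```r
lab <- factor(c("a", "b", "a", "b"))  # made-up training labels
idx <- matrix(c(2L, 4L,               # 2 nearest training rows for test sample 1
                1L, 3L,               # ... for test sample 2
                4L, 2L),              # ... for test sample 3
              nrow = 2)               # k = 2 rows, N = 3 columns
k_demo <- nrow(idx)

# The construction used above: N x k, one test sample per row
# (matrix() turns the factor into character values)
by_row <- matrix(lab[idx], ncol = k_demo, byrow = TRUE)
# Equivalent: keep the k x N layout, then transpose
transposed <- t(matrix(as.character(lab[idx]), nrow = k_demo))

by_row                          # rows: "b" "b" / "a" "a" / "b" "b"
identical(by_row, transposed)   # TRUE
```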

3.3 Showing in detail how the answer is found

ans[3,]
## [1] "virginica"  "versicolor" "virginica"
table(ans[3,])
## 
## versicolor  virginica 
##          1          2
which.max(table(ans[3,]))
## virginica 
##         2
names( which.max(table(ans[3,]))     )
## [1] "virginica"
findmajority<-function(x) {
  names( which.max(table(x))     )
} 

findmajority(ans[3,]  )
## [1] "virginica"

The table function returns a table counting how many times each species appears.

The which.max function returns the position of the most frequent species (the first one, if several are tied).

The names function extracts that species' name from the table.
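One detail worth knowing: when two classes get the same number of votes (possible with an even k such as the k = 4 used later), `which.max` returns the first maximum, and `table` orders character values alphabetically, so `findmajority` silently favours the alphabetically first class. `class::knn` documents that it breaks vote ties at random instead. The sketch below shows this behaviour and one alternative tie-break, choosing the tied class that contains the nearest neighbour; `findmajority_nearest` is a hypothetical name and this rule is an assumption, not what this note or `class::knn` does.

```r
findmajority <- function(x) {
  names(which.max(table(x)))
}

# Hypothetical alternative: among tied classes, pick the one whose member is nearest.
# Assumes x is ordered from nearest to farthest neighbour, like a row of ans.
findmajority_nearest <- function(x) {
  counts <- table(x)
  tied <- names(counts)[counts == max(counts)]
  x[x %in% tied][1]
}

votes <- c("virginica", "versicolor", "versicolor", "virginica")  # a 2-2 tie with k = 4
findmajority(votes)          # "versicolor": first in alphabetical order
findmajority_nearest(votes)  # "virginica": the nearest neighbour's class
```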

# apply(ans, 1, function(i) names(which.max(table(i)))) also works; the original attempt failed only because of the "fucntion" typo and an unclosed names(

ans %>% apply( . ,1,findmajority   )
##  [1] "setosa"     "versicolor" "virginica"  "virginica"  "virginica" 
##  [6] "versicolor" "virginica"  "versicolor" "setosa"     "virginica" 
## [11] "setosa"     "versicolor" "setosa"     "virginica"  "virginica" 
## [16] "setosa"     "versicolor" "virginica"  "virginica"  "versicolor"
## [21] "versicolor" "versicolor" "setosa"     "setosa"     "setosa"    
## [26] "versicolor" "setosa"     "versicolor" "versicolor" "virginica" 
## [31] "setosa"     "versicolor" "virginica"  "virginica"  "versicolor"
## [36] "versicolor" "virginica"  "setosa"     "versicolor" "versicolor"
## [41] "setosa"     "setosa"     "versicolor" "setosa"     "virginica" 
## [46] "setosa"     "virginica"  "setosa"     "versicolor" "versicolor"
finalanswer<- ans %>% apply( . ,1,findmajority   )

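finalanswer <- factor(finalanswer, levels = levels(label))  # same levels as the true labels, so == works even if a class is never predicted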

Let's check the answers:

finalanswer ==testlabel
##  [1]  TRUE  TRUE FALSE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE
## [13]  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE
## [25]  TRUE FALSE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE
## [37]  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE
## [49]  TRUE  TRUE
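Reading 50 TRUE/FALSE values by eye is tedious. The mean of a logical vector is the share of TRUE values, i.e. the accuracy, and `table` gives a confusion matrix. The sketch below uses a hypothetical helper `accuracy` on made-up vectors; converting both sides to character avoids the error R raises when comparing two factors whose level sets differ.

```r
# Hypothetical helper: share of predictions equal to the truth
accuracy <- function(pred, truth) {
  mean(as.character(pred) == as.character(truth))
}

pred  <- factor(c("setosa", "versicolor", "virginica", "virginica"))
truth <- factor(c("setosa", "versicolor", "versicolor", "virginica"))
accuracy(pred, truth)                    # 0.75
table(predicted = pred, actual = truth)  # rows: predicted class, columns: true class
```

Applied to this note's objects, `accuracy(finalanswer, testlabel)` would be 48/50 = 0.96, since exactly two of the 50 comparisons printed above are FALSE.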

4 Summary

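The way of thinking behind the code is simple: if you can't write a function, do it step by step; if you can't get the whole matrix right at once, get one row right; and if you can't get a row right, get one element right!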

rm(list=ls()) # cleaning the environment
require(magrittr)
data(iris)
class(iris$Species)
## [1] "factor"
set.seed(1)
data <- iris[sample(nrow(iris))   ,]
Xtrain <- data[1:100,1:4]
Xtest <- data[101:150,1:4]
label <- data[1:100,5]
testlabel <- data[101:150,5]
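my_knn <- function(Xtrain, Xtest, label, k) {
  M <- nrow(Xtrain)
  N <- nrow(Xtest)
  # distmatrix[i, j] = Euclidean distance between training row i and test row j
  distmatrix <- matrix(0, nrow = M, ncol = N)
  for (i in 1:M) {
    for (j in 1:N) {
      distmatrix[i, j] <- sum((Xtrain[i, ] - Xtest[j, ])^2) %>% sqrt()
    }
  }
  # column j: training-row indices sorted from nearest to farthest from test row j
  sortedDistIndexes <- apply(distmatrix, 2, order)

  # N x k matrix: row j holds the labels of the k nearest neighbours of test row j
  ans <- matrix(label[sortedDistIndexes[1:k, 1:N]], ncol = k, byrow = TRUE)

  # majority vote: the most frequent label (first in table order if tied)
  findmajority <- function(x) {
    names(which.max(table(x)))
  }
  finalanswer <- ans %>% apply(., 1, findmajority)
  # keep the training labels' levels so comparisons with testlabel always work
  finalanswer <- factor(finalanswer, levels = levels(label))
  return(finalanswer)
}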
predicted<- my_knn( Xtrain,Xtest,label,k=4)
predicted==testlabel
##  [1]  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE
## [13]  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE
## [25]  TRUE FALSE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE
## [37]  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE  TRUE
## [49]  TRUE  TRUE
require(class)
## Loading required package: class
packageanswer<- knn(Xtrain,Xtest,label,k=4    )
predicted==packageanswer
##  [1] TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE
## [16] TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE
## [31] TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE
## [46] TRUE TRUE TRUE TRUE TRUE

Let's compare the speed of my own function with the package's:

system.time(my_knn( Xtrain,Xtest,label,k=4))
##    user  system elapsed 
##    4.64    0.00    4.66
system.time( knn(Xtrain,Xtest,label,k=4    )     )
##    user  system elapsed 
##       0       0       0

The first timing is my function and the second is the package; as expected, the gap is huge (sob).
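Much of my_knn's time likely goes into the 5,000 `Xtrain[i,] - Xtest[j,]` operations on data frames, which are far slower than arithmetic on plain matrices. A vectorised variant is sketched below as an assumption, not a benchmarked result: it converts the inputs to matrices, computes every distance at once with the expansion ||x - z||^2 = ||x||^2 + ||z||^2 - 2 x·z, and keeps the same order-then-vote logic. The name `my_knn_fast` is hypothetical, and vote ties still go to the first class in table order, unlike the random tie-breaking of class::knn.

```r
# Hypothetical vectorised kNN; same inputs and output type as my_knn
my_knn_fast <- function(Xtrain, Xtest, label, k) {
  A <- as.matrix(Xtrain)
  B <- as.matrix(Xtest)
  # all squared distances at once: rows = training samples, columns = test samples
  sq <- outer(rowSums(A^2), rowSums(B^2), "+") - 2 * A %*% t(B)
  distmatrix <- sqrt(pmax(sq, 0))  # clamp rounding noise below 0
  # first k entries of each column's order() = indices of the k nearest training rows
  nearest <- apply(distmatrix, 2, function(col) order(col)[1:k])
  nearest <- matrix(nearest, nrow = k)  # keep a k x N shape even when k = 1
  votes <- matrix(as.character(label)[nearest], nrow = k)
  # majority vote per test sample (one column of votes each)
  pred <- apply(votes, 2, function(v) names(which.max(table(v))))
  lev <- if (is.factor(label)) levels(label) else sort(unique(as.character(label)))
  factor(pred, levels = lev)
}

# Usage on made-up, well-separated data
Xtr <- data.frame(x = c(0, 0.2, 5, 5.2), y = c(0, 0.1, 5, 5.1))
ytr <- factor(c("low", "low", "high", "high"))
Xte <- data.frame(x = c(0.1, 4.9), y = c(0.05, 5))
my_knn_fast(Xtr, Xte, ytr, k = 3)  # low, high
```

With this note's objects, `my_knn_fast(Xtrain, Xtest, label, k = 4)` should normally return the same predictions as my_knn, since the distances agree up to floating-point rounding (exactly equal distances could in principle be ordered differently). Timing it with system.time is left to the reader.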