K-means 為非監督式學習的演算法,將一群資料分成 k 群 (cluster),演算法上是透過計算資料間的距離來作為分群的依據,較相近的資料會成形成一群並透過加權計算或簡單平均可以找出中心點,透過多次反覆計算與更新各群中心點後,可以找出代表該群的中心點,之後便可以透過與中心點的距離來判定測試資料屬於哪一分群,或可進一步被用來資料壓縮,代表特定類別資料,以達到降低雜訊或填空值等議題。此為分割式分群法(partitional clustering)中的一種,藉由反覆運算,逐次降低誤差目標值,直到目標函式值不再變化或更低,就達到分群的最後結果。

分割式分群法目的是希望盡量減少每個分群中,每一資料點與群中心的距離平方差 (square error),假設一組包含 c 個群聚的資料,其中第 \(k\) 個群聚可用集合 \(G_k\) 表示,而 \(G_k\) 包含 \(n_k\) 筆資料 \(\{x_1, x_2, x_3, ... , x_{nk}\}\),此群聚中心為 \(y_k\),則該群聚的平方誤差 \(e_k\)\(e_k = \sum_i{|x_i-y_k|^2}\),其中 \(x_i\) 是屬於第 \(k\) 群的資料點。而這 c 個群聚的總合平方誤差 \(E\) 便是每個群聚的平方誤差總合,可稱為分群的誤差函數 (error fucntion) 或失真度 (distortion),\(E = \sum_{k=1~c^{ek}}\),故分群方法就變成一個最佳化問題,也就是說要如何選取 c 個群聚極其相關群中心,可促使 \(E\) 的值最小。

若用目標函式來說明,則假設給定一組 \(n\) 點資料 \(X = \{x_1,x_2,x_3, ... ,x_n\}\),每一資料點有 \(d\) 維,k-means 分群為找到一組 \(m\) 代表點 \(Y = \{y_1, y_2, y_3, ... ,y_m\}\),每個點亦是 \(d\) 維,促使下方目標函數越小越好:\(J(X, Y, U) = \sum_{i=1}^{n}{|x_i-y_k|^2}\)

k-means 主要演算法為: * 隨機取得 c 個資料點,分別視成 c 個分群的群中心,此即為 \(y\)。 * 由固定的 \(y\),找出最靠近的資料點 \(x\),並將之加入該群。 * 計算目標函數 \(J(x, y, u)\),若維持不變,代表分群結果已達最佳化,可以結束反覆計算。 * 若沒有收斂,則產生最佳的 \(y\),回到上述第二步驟。

K-means 在測試資料具有代表性或資料趨近於常態分布時有相當好的結果,但當訓練資料過少或不具代表性時,k-means 的分群結果相當的差,且會因訓練資料問題造成 k 值判定且易出現過度適應問題(overfitting),通常 k-means 的 k 值定義會有專業知識的判斷下來決定較容易有好的分群結果;但對於未知的資料時,則可以透過 k 的循序遞增或遞減等,查看資料間的分布差異,便可以了解 k 值可能為何為最佳。

R 實作

可以透過套件 stats 中的函式 kmeans 來實作。

套件安裝

packageName <- c("stats","graphics")
for(i in 1:length(packageName)) {
  if(!(packageName[i] %in% rownames(installed.packages()))) {
    install.packages(packageName[i])
  }
}
lapply(packageName, require, character.only = TRUE)
## [[1]]
## [1] TRUE
## 
## [[2]]
## [1] TRUE

資料準備

產生出類似常態分佈的測試資料。

set.seed(111)
rawdata <- rbind(
  matrix(rnorm(100, sd = 0.3), ncol = 2),
  matrix(rnorm(100, mean = 1, sd = 0.3), ncol = 2),
  matrix(rnorm(100, mean = 0.5, sd = 0.3), ncol = 2),
  cbind(rnorm(50, mean = 0, sd = 0.3), rnorm(50, mean = 1, sd = 0.3))
)
colnames(rawdata) <- c("x", "y")
plot(rawdata[,"x"], rawdata[,"y"], xlab="x", ylab="y")

建立 K-means 模型

# the prototype of kmeans
kmeans(x, centers, iter.max = 10, nstart = 1,
       algorithm = c("Hartigan-Wong", "Lloyd", "Forgy", "MacQueen"),
       trace=FALSE)
rawdata.kmeans <- kmeans(rawdata[,1:2], centers=4, iter.max = 20, trace = TRUE, algorithm = "Hartigan-Wong")
## KMNS(*, k=4): iter=  1, indx=2
##  QTRAN(): istep=200, icoun=47
##  QTRAN(): istep=400, icoun=95
##  QTRAN(): istep=600, icoun=156
## KMNS(*, k=4): iter=  2, indx=55
##  QTRAN(): istep=200, icoun=84
## KMNS(*, k=4): iter=  3, indx=200
summary(rawdata.kmeans)
##              Length Class  Mode   
## cluster      200    -none- numeric
## centers        8    -none- numeric
## totss          1    -none- numeric
## withinss       4    -none- numeric
## tot.withinss   1    -none- numeric
## betweenss      1    -none- numeric
## size           4    -none- numeric
## iter           1    -none- numeric
## ifault         1    -none- numeric

分群結果

  • 分群中心點特徵值
rawdata.kmeans.centers <- rawdata.kmeans$centers
rawdata.kmeans.centers
##             x          y
## 1  1.09953546 0.92811169
## 2  0.02976894 1.10087953
## 3 -0.15864434 0.06975866
## 4  0.53355186 0.42317730
  • 資料分群結果
rawdata.kmeans.res <- rawdata.kmeans$cluster
rawdata.kmeans.res
##   [1] 3 3 3 3 3 3 3 3 3 3 3 3 4 3 4 3 3 3 3 4 3 3 3 3 3 4 3 4 3 3 3 3 4 3 3
##  [36] 4 3 3 3 3 3 3 3 3 3 3 3 3 4 4 1 4 1 1 4 1 1 1 1 1 1 1 1 1 2 1 1 2 1 1
##  [71] 1 1 1 1 1 1 1 1 1 1 1 1 4 1 1 1 2 1 1 1 1 1 1 1 1 1 1 1 1 1 4 4 4 3 3
## [106] 4 4 4 4 4 4 4 4 1 4 4 4 1 4 4 4 2 4 2 4 2 4 3 2 4 4 4 2 1 1 4 4 4 4 4
## [141] 4 4 1 4 4 4 4 4 4 4 2 2 3 2 2 2 2 2 2 2 4 2 2 2 2 2 2 2 2 2 2 2 3 2 2
## [176] 3 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 4 2 2 2 2 2
  • 顯示結果
plot(rawdata[,"x"], rawdata[,"y"], xlab="x", ylab="y", col=rawdata.kmeans.res+2)
points(rawdata.kmeans.centers[,"x"], rawdata.kmeans.centers[,"y"], col="black", pch = 3, cex=1)

檢驗分群結果

  • 各群內的距離平方和 (sum of squares)
# the sum of squares
ss <- function(x) sum(scale(x, scale = FALSE)^2)
# rawdata.kmeans[[4]]
rawdata.kmeans$withinss
## [1] 6.712847 7.419503 6.955472 5.695341
  • 各群內的距離平方和總量 = “各群內的距離平方和” 的 總合
rawdata.kmeans$tot.withinss
## [1] 26.78316
  • 分群間的聚類平方和總量
rawdata.kmeans$betweenss
## [1] 79.70939