The Idea

\(k\)-means clustering is an unsupervised machine learning algorithm whose goal is to group similar data points together in order to discover underlying patterns in the data. The algorithm partitions \(n\) observations into \(k\) clusters, where a cluster is simply a collection of data points aggregated because of certain similarities.

The standard algorithm minimizes the total within-cluster variation (also called the within-cluster sum of squares, WSS). For a single cluster \(C_j\), the within-cluster variation is \[W(C_j) = \sum_{x_i \in C_j} \|x_i - \mu_j\|^2,\] where \(x_i\) is a data point belonging to cluster \(C_j\) and \(\mu_j\) is the mean of the points assigned to cluster \(C_j\). Each observation is assigned to a cluster so that the sum of squared distances of the observation to its assigned center \(\mu_j\) is minimized. The total within-cluster variation is \[ \text{Total Within SS} = \sum_{j=1}^k W(C_j) = \sum_{j=1}^k \sum_{x_i \in C_j} \|x_i - \mu_j\|^2\]
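
As a quick illustration (a minimal sketch in base R with a toy matrix and a toy cluster assignment, not part of the USArrests example below), the total within-cluster sum of squares can be computed directly from its definition:

# Toy data and toy cluster labels, purely for illustration
x  <- matrix(rnorm(20), ncol = 2)           #10 observations, 2 variables
cl <- sample(1:2, nrow(x), replace = TRUE)  #hypothetical cluster assignment

# Sum, over clusters, of the squared distances of each point to its cluster mean
tot_withinss <- sum(sapply(unique(cl), function(j) {
  pts <- x[cl == j, , drop = FALSE]   #points in cluster j
  mu  <- colMeans(pts)                #cluster centroid
  sum(rowSums(sweep(pts, 2, mu)^2))   #squared distances to the centroid
}))
tot_withinss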

The Algorithm

The k-means algorithm begins with an initial set of randomly selected centroids (\(\mu_j\)), one for each cluster. From there, the algorithm attempts to optimize the positions of the centroids in order to minimize the distances between the centroids and the points assigned to them (a minimal sketch of this loop appears after the list below). The algorithm stops creating and optimizing clusters when either:

  1. The centroids have stabilized and no further movement of the centroids results from additional iterations. In this case, the clustering has been successful.

  2. A predefined number of iterations has been reached.
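
To make the loop concrete, here is a minimal sketch of the standard (Lloyd's) iteration in base R. The function simple_kmeans and its arguments are purely illustrative (and, for simplicity, empty clusters and \(k=1\) are not handled); in practice one simply calls kmeans().

simple_kmeans <- function(x, k, max_iter = 100) {
  centers <- x[sample(nrow(x), k), , drop = FALSE]  #randomly selected initial centroids
  for (i in 1:max_iter) {
    #Distance from every point to every centroid
    d  <- as.matrix(dist(rbind(centers, x)))[-(1:k), 1:k]
    cl <- apply(d, 1, which.min)                    #assign each point to its nearest centroid
    #Recompute each centroid as the mean of its assigned points (assumes no empty clusters)
    new_centers <- t(sapply(1:k, function(j) colMeans(x[cl == j, , drop = FALSE])))
    if (all(abs(new_centers - centers) < 1e-8)) break  #stop once the centroids have stabilized
    centers <- new_centers
  }
  list(cluster = cl, centers = centers)
}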

Data preparation is vital before running the algorithm. Rows are observations and columns are variables. Missing values are removed or otherwise dealt with appropriately. Finally, the variables must be scaled so that the results are not driven by differences in the variables' scales.

Distance Metrics

The choice of distance metric is important, since it influences the shape of the clusters. The classic choices are the Euclidean metric \[d(\mathbf{u},\mathbf{v}) = \sqrt{\sum_{i=1}^n (u_i - v_i)^2},\] and the Manhattan metric \[d(\mathbf{u},\mathbf{v}) = \sum_{i=1}^n |u_i-v_i|, \] with \(\mathbf{u}, \mathbf{v} \in \mathbb{R}^n\), but other measures of similarity (or dissimilarity) are also used.
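
For example (a small sketch with two arbitrary vectors), base R's dist function computes both of these metrics:

u <- c(1, 2, 3)
v <- c(4, 0, 3)
dist(rbind(u, v), method = "euclidean")   #sqrt((1-4)^2 + (2-0)^2 + (3-3)^2) = sqrt(13)
dist(rbind(u, v), method = "manhattan")   #|1-4| + |2-0| + |3-3| = 5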

For categorical variables, k-means is a poor choice. Since the goal of k-means is to minimize distances, it makes no sense to do so for variables that have neither scale nor order.

An Example

Data Preprocessing

We will cluster the USArrests data set. First, we remove any missing values and scale the data.

data("USArrests")       #US Arrests Data Set
df <- na.omit(df)       #Remove any missing data
df <- scale(USArrests)  #Scale the Data Set
set.seed(123)           #Reproducible seed
head(df, n=5)
##                Murder   Assault   UrbanPop         Rape
## Alabama    1.24256408 0.7828393 -0.5209066 -0.003416473
## Alaska     0.50786248 1.1068225 -1.2117642  2.484202941
## Arizona    0.07163341 1.4788032  0.9989801  1.042878388
## Arkansas   0.23234938 0.2308680 -1.0735927 -0.184916602
## California 0.27826823 1.2628144  1.7589234  2.067820292

Determining the Number of Clusters

Elbow Chart

Since the idea behind k-means is to minimize \(\sum_{j=1}^k W(C_j)\), one sensible way of determining the number of clusters is simply to run the clustering algorithm for different values of \(k\) and compare the total WSS for each value of \(k\). Ultimately, the larger the value of \(k\), the smaller the total WSS should be. If we graph the WSS value against \(k\), we want to find the point on the chart (called an elbow chart) where the graph appears to bend, indicating only a small decrease in the WSS relative to previous values of \(k\). This can all be accomplished with the fviz_nbclust command from the factoextra package.

set.seed(12345)                #Reproducible seed 
fviz_nbclust(df,               #Data
             kmeans,           #Algorithm
             method = "wss")   #Method used for estimating the optimal number of clusters.

Here, we see a substantial decrease from \(k=1\) to \(k=2\), another substantial decrease from \(k=2\) to \(k=4\) and then there is either no decrease or only small decreases. Thus, either 2 or 4 would be good candidates to try.

The fviz_nbclust command is a single function that accomplishes the task of the following code.

library(purrr)
set.seed(12345)

# function to compute total within-cluster sum of square 
wss <- function(k) {
  kmeans(df, k, nstart = 10 )$tot.withinss
}

# Compute and plot wss for k = 1 to k = 15
k.values <- 1:15

# extract wss for 1-15 clusters
wss_values <- map_dbl(k.values, wss)

plot(k.values, wss_values,
       type="b", pch = 19, frame = FALSE, 
       xlab="Number of clusters K",
       ylab="Total within-clusters sum of squares")

Average Silhouette Method

The average silhouette method measures how well each data point fits into its particular cluster. A high average silhouette width indicates a good clustering.
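
For a single point \(i\), the silhouette width is \[ s(i) = \frac{b(i) - a(i)}{\max\{a(i),\, b(i)\}}, \] where \(a(i)\) is the average distance from \(i\) to the other points in its own cluster and \(b(i)\) is the smallest average distance from \(i\) to the points of any other cluster. The method chooses the \(k\) that maximizes the mean of \(s(i)\) over all points.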

set.seed(12345)                #Reproducible seed 
fviz_nbclust(df,               #Data
             kmeans,           #Algorithm
             method = "silhouette")   #Method used for estimating the optimal number of clusters.

We see that 2 is still a good candidate for the number of clusters. The fviz_nbclust command essentially takes the place of the following code. We again notice a small variation in the graph due to the random nature of the k-means algorithm.

library(cluster)
# function to compute average silhouette for k clusters
avg_sil <- function(k) {
  km.res <- kmeans(df, centers = k, nstart = 25)
  ss <- silhouette(km.res$cluster, dist(df))
  mean(ss[, 3])
}

# Compute and plot average silhouette for k = 2 to k = 10
k.values <- 2:10

# extract avg silhouette for 2-10 clusters
avg_sil_values <- map_dbl(k.values, avg_sil)

plot(k.values, avg_sil_values,
       type = "b", pch = 19, frame = FALSE, 
       xlab = "Number of clusters K",
       ylab = "Average Silhouettes")

The Gap Statistic

This approach, which we do not go into in any detail, compares the total intra-cluster variation for various values of \(k\) to the values expected under a reference distribution with no obvious clustering.
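
For reference, for a given \(k\) the statistic is \[ \text{Gap}(k) = E^*\left[\log W_k\right] - \log W_k, \] where \(W_k\) is the total within-cluster variation with \(k\) clusters and the expectation \(E^*\) is taken under the reference distribution (typically uniform); a larger gap indicates stronger evidence of clustering structure.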

set.seed(12345)                #Reproducible seed 
fviz_nbclust(df,               #Data
             kmeans,           #Algorithm
             method = "gap_stat")   #Method used for estimating the optimal number of clusters.

Again, it appears that either 2 or 4 clusters would work well.

Additional Methods

The NbClust package, published by Charrad et al., 2014, provides 30 indices for determining the relevant number of clusters and proposes the best clustering scheme from results obtained by varying all combinations of number of clusters, distance measures, and clustering methods.
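
A sketch of a typical call on this data is shown below (assuming the NbClust package is installed; the argument values here are illustrative choices, not package defaults). With index = "all", the function reports which number of clusters is favored by the most indices and produces diagnostic plots.

library(NbClust)
set.seed(12345)
nb <- NbClust(df,                      #Scaled data
              distance = "euclidean",  #Distance measure
              min.nc = 2, max.nc = 10, #Range of cluster numbers to consider
              method = "kmeans",       #Clustering method
              index = "all")           #Use all available indices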

Clustering

We select \(k=4\) clusters. The algorithm is run from 50 different random starting configurations, each allowed at most 25 iterations, and the result reported is the best of the 50 runs (the one with the lowest total within-cluster sum of squares).

Analysis <- kmeans(df, 
                   centers = 4,     #4 clusters,
                   nstart = 50,     #50 restarts,
                   iter.max = 25)   #25 iterations max.
print(Analysis)
## K-means clustering with 4 clusters of sizes 13, 13, 8, 16
## 
## Cluster means:
##       Murder    Assault   UrbanPop        Rape
## 1 -0.9615407 -1.1066010 -0.9301069 -0.96676331
## 2  0.6950701  1.0394414  0.7226370  1.27693964
## 3  1.4118898  0.8743346 -0.8145211  0.01927104
## 4 -0.4894375 -0.3826001  0.5758298 -0.26165379
## 
## Clustering vector:
##        Alabama         Alaska        Arizona       Arkansas     California 
##              3              2              2              3              2 
##       Colorado    Connecticut       Delaware        Florida        Georgia 
##              2              4              4              2              3 
##         Hawaii          Idaho       Illinois        Indiana           Iowa 
##              4              1              2              4              1 
##         Kansas       Kentucky      Louisiana          Maine       Maryland 
##              4              1              3              1              2 
##  Massachusetts       Michigan      Minnesota    Mississippi       Missouri 
##              4              2              1              3              2 
##        Montana       Nebraska         Nevada  New Hampshire     New Jersey 
##              1              1              2              1              4 
##     New Mexico       New York North Carolina   North Dakota           Ohio 
##              2              2              3              1              4 
##       Oklahoma         Oregon   Pennsylvania   Rhode Island South Carolina 
##              4              4              4              4              3 
##   South Dakota      Tennessee          Texas           Utah        Vermont 
##              1              3              2              4              1 
##       Virginia     Washington  West Virginia      Wisconsin        Wyoming 
##              4              4              1              1              4 
## 
## Within cluster sum of squares by cluster:
## [1] 11.952463 19.922437  8.316061 16.212213
##  (between_SS / total_SS =  71.2 %)
## 
## Available components:
## 
## [1] "cluster"      "centers"      "totss"        "withinss"     "tot.withinss"
## [6] "betweenss"    "size"         "iter"         "ifault"

The output of the kmeans function is a list that includes several important components (see https://www.rdocumentation.org/packages/stats/versions/3.6.2/topics/kmeans for the documentation); individual components can be extracted with $, as shown after the list:

  1. cluster: A vector of integers, from 1 to \(k\), indicating the cluster to which each point is allocated.

  2. centers: A matrix of cluster centers.

  3. totss: The total sum of squares.

  4. withinss: Vector of within-cluster sum of squares, one component per cluster.

  5. tot.withinss: Total within-cluster sum of squares, i.e. sum(withinss).

  6. betweenss: The between-cluster sum of squares, i.e. totss - tot.withinss.

  7. size: The number of points in each cluster.
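
For example, the pieces used above and below can be pulled out directly:

Analysis$size          #Number of states in each cluster
Analysis$centers       #Matrix of cluster centers (on the scaled variables)
Analysis$tot.withinss  #Total within-cluster sum of squares
Analysis$cluster[1:5]  #Cluster assignments for the first five states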

The Results

We can see the characteristics of each cluster by computing the mean of each original (unscaled) variable within each cluster.

aggregate(USArrests, by = list(cluster=Analysis$cluster), mean)
##   cluster   Murder   Assault UrbanPop     Rape
## 1       1  3.60000  78.53846 52.07692 12.17692
## 2       2 10.81538 257.38462 76.00000 33.19231
## 3       3 13.93750 243.62500 53.75000 21.41250
## 4       4  5.65625 138.87500 73.87500 18.78125

So, in this example, cluster 1 is characterized by low murder, assault, and rape rates together with a low urban population, while cluster 3 is characterized by high murder and assault rates combined with a low urban population.

Alternatively, the fviz_cluster command gives us a way to visualize the data. Since we have more than 2 variables, fviz_cluster performs a principal component analysis (PCA) and plots the clusters on the first 2 principal components.

fviz_cluster(Analysis, df)    #Inputs are the cluster analysis and the data;
                              #the clusters are plotted on the first 2 principal components

You can, as an alternative, use a standard scatter plot to illustrate the clusters compared to the original variables.

library(dplyr)    #For %>%, as_tibble, and mutate
library(ggplot2)  #For ggplot

df %>%
  as_tibble() %>%
  mutate(cluster = Analysis$cluster,
         state = row.names(USArrests)) %>%
  ggplot(aes(Murder, Assault, color = factor(cluster), label = state)) +
  geom_text()

A matrix of graphs is another way to display the data.

Plot1 <- df %>%
  as_tibble() %>%
  mutate(cluster = Analysis$cluster,
         state = row.names(USArrests)) %>%
  ggplot(aes(Murder, Assault, color = factor(cluster), label = state)) +
  geom_point() + theme(legend.position = "none")

Plot2 <- df %>%
  as_tibble() %>%
  mutate(cluster = Analysis$cluster,
         state = row.names(USArrests)) %>%
  ggplot(aes(Murder, Rape, color = factor(cluster), label = state)) +
  geom_point() + theme(legend.position = "none")

Plot3 <- df %>%
  as_tibble() %>%
  mutate(cluster = Analysis$cluster,
         state = row.names(USArrests)) %>%
  ggplot(aes(Murder, UrbanPop, color = factor(cluster), label = state)) +
  geom_point() + theme(legend.position = "none")

Plot4 <- df %>%
  as_tibble() %>%
  mutate(cluster = Analysis$cluster,
         state = row.names(USArrests)) %>%
  ggplot(aes(Rape, UrbanPop, color = factor(cluster), label = state)) +
  geom_point() + theme(legend.position = "none")

Plot5 <- df %>%
  as_tibble() %>%
  mutate(cluster = Analysis$cluster,
         state = row.names(USArrests)) %>%
  ggplot(aes(Assault, Rape, color = factor(cluster), label = state)) +
  geom_point() + theme(legend.position = "none")

Plot6 <- df %>%
  as_tibble() %>%
  mutate(cluster = Analysis$cluster,
         state = row.names(USArrests)) %>%
  ggplot(aes(Assault, UrbanPop, color = factor(cluster), label = state)) +
  geom_point() + theme(legend.position = "none")

library(grid)
library(gridExtra)
grid.arrange(Plot1, Plot2, Plot3, Plot4, Plot5, Plot6,
             nrow = 2, top = textGrob("Scatter Plot of Clusters"))

Advantages of k-means

  1. k-means is a simple and fast algorithm.

  2. k-means can efficiently handle very large data sets.

Disadvantages of k-means

  1. k-means assumes prior knowledge of the data and requires the analyst to choose the appropriate number of clusters in advance. This can be overcome by completing the analysis for a variety of values of \(k\).

  2. The analysis is sensitive to the initial random selection of cluster centers, so a different set of initial centers may lead to different clustering results. However, by running the algorithm multiple times and keeping the run with the lowest within-cluster sum of squares, you can often (though not always) overcome this issue (see the short comparison after this list).

  3. k-means is sensitive to outliers.

  4. Since binary variables have no natural center, and for a number of other reasons, k-means cannot be used for categorical variables.

  5. No cluster dendrogram is produced with k-means unlike with hierarchical clustering.

  6. If the data have no real cluster structure, such as data from a uniform distribution, k-means will still partition them into \(k\) clusters.

  7. If clusters overlap, k-means provides no measure of the uncertainty for points falling in the overlapping regions, nor a principled way of deciding to which cluster such a point truly belongs.
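
As a quick illustration of point 2 (a minimal sketch; the exact values depend on the seed), compare a single random start with the best of many restarts:

set.seed(1)
kmeans(df, centers = 4, nstart = 1)$tot.withinss   #One random start
set.seed(1)
kmeans(df, centers = 4, nstart = 50)$tot.withinss  #Best of 50 starts is never worse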

k-means and Shapes

In some cases, k-means does not pick up the pattern in the data. The diagram below by Dabbura illustrates two situations in which k-means will have a difficult time finding the correct clusters.

On the left, the natural grouping is an inner cluster surrounded by an outer ring, which k-means fails to find. In the diagram on the right, two interleaved parabolas are again clustered incorrectly. In these instances, one needs to be careful about the type of clustering that is used.

Citations

Dabbura, Imad. “K-means Clustering - Algorithm, Applications, Evaluation Methods, and Drawbacks”. https://imaddabbura.github.io/post/kmeans-clustering/

Forgy, E. W. (1965). Cluster analysis of multivariate data: efficiency vs interpretability of classifications. Biometrics, 21, 768–769.

Garbade, Michael J. 2018. “Understanding K-means Clustering in Machine Learning.” https://towardsdatascience.com/understanding-k-means-clustering-in-machine-learning-6a6e67336aa1

Hartigan, J. A. and Wong, M. A. (1979). Algorithm AS 136: A K-means clustering algorithm. Applied Statistics, 28, 100–108. 10.2307/2346830.

Lloyd, S. P. (1957, 1982). Least squares quantization in PCM. Technical Note, Bell Laboratories. Published in 1982 in IEEE Transactions on Information Theory, 28, 128–137.

MacQueen, J. (1967). Some methods for classification and analysis of multivariate observations. In Proceedings of the Fifth Berkeley Symposium on Mathematical Statistics and Probability, eds L. M. Le Cam & J. Neyman, 1, pp.281–297. Berkeley, CA: University of California Press.