In this brief tutorial, we will cover how to implement the K-Means Clustering algorithm using US arrests data to see if we can identify any clusters within that data set.
K-Means Clustering is an unsupervised machine learning algorithm that attempts to group observations into different clusters. Specifically, the goal of the algorithm is to minimize the variation within clusters and maximize the variation between clusters. Below is a graphical representation of the clusters that could be formed using the algorithm.
To perform K-means clustering, the first step is to specify the number of clusters \(K\). Then, the K-means algorithm will assign each observation to exactly one of the \(K\) clusters. It is a fairly simple and intuitive mathematical problem.
We begin by defining some notation. Let \(C_{1}, \ldots, C_{K}\) represent sets containing the indices of the observations in each cluster. These sets satisfy two properties: every observation belongs to at least one of the \(K\) clusters, so \(C_{1} \cup C_{2} \cup \cdots \cup C_{K} = \{1, \ldots, n\}\), and the clusters are non-overlapping, so \(C_{k} \cap C_{k'} = \emptyset\) for all \(k \neq k'\).
For example, if the \(i\)th observation is in the \(k\)th cluster then \(i \in C_{k}\).
The within-cluster variation for cluster \(C_{k}\) is a measure \(W(C_{k})\) of the amount by which the observations within a cluster differ from each other. Hence, we would like to solve the following
\[ \begin{aligned} \min_{C_{1}, \ldots, C_{K}}\Bigg\{\sum_{k=1}^{K}{W(C_{k})}\Bigg\} &&&& (1) \end{aligned} \]
Equation (1) states that we would like to partition the observations into \(K\) clusters such that the total within-cluster variation, summed over all \(K\) clusters, is as small as possible. Solving equation (1) may seem like a reasonable idea, but in order to make it actionable, we first need to define the within-cluster variation. There are many possible ways to define this concept, but the most common choice involves the squared Euclidean distance, which is defined as
\[ \begin{aligned} W(C_{k}) = \frac{1}{\vert C_{k}\vert} \sum_{i, i' \in C_{k}} \sum_{j=1}^{p} {(x_{ij} - x_{i'j})^{2}} &&&& (2) \end{aligned} \]
where \(\vert C_{k} \vert\) denotes the number of observations in the \(k\)th cluster. That is, the within-cluster variation for the \(k\)th cluster is the sum of all of the pairwise squared Euclidean distances between the observations in the cluster, divided by the total number of observations in the \(k\)th cluster.
If we combine equations (1) and (2), we get the optimization problem that defines K-means clustering
\[ \begin{aligned} \min_{C_{1}, \ldots, C_{K}} \Bigg\{\sum_{k=1}^{K}{\frac{1}{\vert C_{k}\vert}} \sum_{i,i' \in C_{k}} \sum_{j=1}^{p}{(x_{ij} - x_{i'j})^{2}}\Bigg\} &&&& (3) \end{aligned} \]
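To make equations (2) and (3) concrete, here is a small R sketch that evaluates the objective for a given assignment of observations to clusters. The helper names within.var and kmeans.objective are just illustrative.

within.var <- function(Xk) {
  # equation (2): sum of all pairwise squared Euclidean distances, divided by the cluster size
  sum(as.matrix(dist(Xk))^2) / nrow(Xk)
}
kmeans.objective <- function(X, assignment) {
  # equation (3): total within-cluster variation, summed over the K clusters
  sum(sapply(split(as.data.frame(X), assignment), within.var))
}
# e.g. the objective for a random partition of the 50 states into 3 clusters:
# kmeans.objective(USArrests, sample(1:3, nrow(USArrests), replace = TRUE))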
Now, we would like to solve (3). In other words, we need a method to partition the observations into \(K\) clusters such that the objective of (3) is minimized. This is actually quite a difficult problem to solve exactly, since there are almost \(K^{n}\) ways to partition \(n\) observations into \(K\) clusters. This is a huge number unless \(K\) and \(n\) are extremely small! Fortunately, there exists a fairly simple algorithm that can be shown to provide a local optimum (a decent solution) to the K-means optimization problem in (3). The approach is as follows (an R sketch of these steps appears after the list):

1. Randomly assign a number, from 1 to \(K\), to each observation. These serve as the initial cluster assignments.
2. Iterate until the cluster assignments stop changing:
    a. For each of the \(K\) clusters, compute the cluster centroid. The \(k\)th cluster centroid is the vector of the \(p\) feature means for the observations in the \(k\)th cluster.
    b. Assign each observation to the cluster whose centroid is closest, where "closest" is defined using Euclidean distance.
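Here is a minimal R sketch of those steps. The function name naive_kmeans is hypothetical, the sketch ignores edge cases such as a cluster becoming empty, and it is not the stats::kmeans implementation used later.

naive_kmeans <- function(X, K, max.iter = 100) {
  X <- as.matrix(X)
  assignment <- sample(1:K, nrow(X), replace = TRUE)  # step 1: random initial assignments
  for (iter in 1:max.iter) {
    # step 2a: each centroid is the vector of feature means for its cluster
    centroids <- sapply(1:K, function(k) colMeans(X[assignment == k, , drop = FALSE]))
    # step 2b: reassign each observation to the closest centroid (squared Euclidean distance)
    dists <- sapply(1:K, function(k) colSums((t(X) - centroids[, k])^2))
    new.assignment <- apply(dists, 1, which.min)
    if (all(new.assignment == assignment)) break      # stop when assignments no longer change
    assignment <- new.assignment
  }
  assignment
}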
The above algorithm is guaranteed to decrease the value of the objective (3) at each step. To see why, consider the following identity
\[ \begin{aligned} \frac{1}{\vert C_{k} \vert} \sum_{i,i' \in C_{k}} \sum_{j=1}^{p}{(x_{ij} - x_{i'j})^{2}} = 2 \sum_{i \in C_{k}}\sum_{j=1}^{p}{(x_{ij} - \bar{x}_{kj})^{2}} &&&& (4) \end{aligned} \]
where \(\bar{x}_{kj} = \frac{1}{\vert C_{k} \vert} \sum_{i \in C_{k}}{x_{ij}}\) is the mean for feature \(j\) in cluster \(C_{k}\).
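A quick numerical check of identity (4), using a small made-up cluster of 5 observations on 2 features (the values are purely illustrative):

Xk <- matrix(c(1.2, 3.4, 2.0, 4.1, 0.5,
               2.2, 1.0, 3.3, 2.8, 1.9), ncol = 2)
lhs <- sum(as.matrix(dist(Xk))^2) / nrow(Xk)   # left-hand side: pairwise form
rhs <- 2 * sum(sweep(Xk, 2, colMeans(Xk))^2)   # right-hand side: deviations from the centroid
all.equal(lhs, rhs)                            # TRUE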
In step 2a, the cluster means for each feature are the constants that minimize the sum of squared deviations in (4), and in step 2b, reallocating the observations can only improve (4). This means that as the algorithm runs, the clustering obtained will continually improve until the result no longer changes; the objective of (3) will never increase. When the result no longer changes, a local optimum has been reached.
Because K-means finds a local rather than a global optimum, the results obtained will depend on the initial random cluster assignment of each observation in step 1 of the algorithm. As a result, it is crucial to run the algorithm multiple times from different random initial configurations and then select the best solution, that is, the one for which the objective (3) is smallest.
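For instance, the nstart argument of R's kmeans() runs several random initializations and keeps the best one. A hypothetical illustration on toy data (the seed and \(K = 4\) are arbitrary choices):

set.seed(123)                                   # arbitrary seed for reproducibility
X <- matrix(rnorm(200), ncol = 2)               # toy data for illustration
km.one  <- kmeans(X, centers = 4, nstart = 1)   # a single random initialization
km.many <- kmeans(X, centers = 4, nstart = 50)  # best of 50 random initializations
km.one$tot.withinss   # objective (3) from one random start
km.many$tot.withinss  # best objective across 50 starts (usually at least as small)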
The data set we will be working with for our example is USArrests, which records, for each US state, (1) murder arrests per 100,000 residents, (2) assault arrests per 100,000 residents, (3) the percentage of the population living in urban areas (UrbanPop), and (4) rape arrests per 100,000 residents.
First, let’s conduct some exploratory data analysis.
| | Murder | Assault | UrbanPop | Rape |
|---|---|---|---|---|
Alabama | 13.2 | 236 | 58 | 21.2 |
Alaska | 10.0 | 263 | 48 | 44.5 |
Arizona | 8.1 | 294 | 80 | 31.0 |
Arkansas | 8.8 | 190 | 50 | 19.5 |
California | 9.0 | 276 | 91 | 40.6 |
Colorado | 7.9 | 204 | 78 | 38.7 |
Connecticut | 3.3 | 110 | 77 | 11.1 |
Delaware | 5.9 | 238 | 72 | 15.8 |
Florida | 15.4 | 335 | 80 | 31.9 |
Georgia | 17.4 | 211 | 60 | 25.8 |
Hawaii | 5.3 | 46 | 83 | 20.2 |
Idaho | 2.6 | 120 | 54 | 14.2 |
Illinois | 10.4 | 249 | 83 | 24.0 |
Indiana | 7.2 | 113 | 65 | 21.0 |
Iowa | 2.2 | 56 | 57 | 11.3 |
## Murder Assault UrbanPop Rape
## Min. : 0.800 Min. : 45.0 Min. :32.00 Min. : 7.30
## 1st Qu.: 4.075 1st Qu.:109.0 1st Qu.:54.50 1st Qu.:15.07
## Median : 7.250 Median :159.0 Median :66.00 Median :20.10
## Mean : 7.788 Mean :170.8 Mean :65.54 Mean :21.23
## 3rd Qu.:11.250 3rd Qu.:249.0 3rd Qu.:77.75 3rd Qu.:26.18
## Max. :17.400 Max. :337.0 Max. :91.00 Max. :46.00
library(ggplot2)  # needed for the histograms below

# Histograms of the three crime-rate features
murder.hist <- ggplot(USArrests, aes(x = Murder)) + geom_histogram(fill = 'tomato2', col = 'black', alpha = 0.8)
assault.hist <- ggplot(USArrests, aes(x = Assault)) + geom_histogram(fill = 'yellow', col = 'black', alpha = 0.9)
rape.hist <- ggplot(USArrests, aes(x = Rape)) + geom_histogram(fill = 'purple', col = 'black', alpha = 0.78)
murder.hist
assault.hist
rape.hist
Choosing the value of \(K\) is more of an art than a science, although there are data-driven methods such as the elbow method. Ideally, one should use intuition based on what you are trying to achieve. For example, if you wanted to cluster customers by income class, you might choose \(K = 3\), with the clusters representing high, middle, and low income earners. If you want a more systematic way of choosing \(K\), the elbow method compares the total within-cluster variation across several values of \(K\) and picks the point beyond which adding another cluster stops producing a meaningful improvement. In practice, it is also worth trying a few different values of \(K\) and comparing the results, as sketched below.
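For reference, here is a minimal sketch of the elbow method in R (the object name wss is illustrative, and the data are scaled directly here): compute the total within-cluster sum of squares for a range of \(K\) values and look for the "elbow" where the curve flattens out.

wss <- sapply(1:10, function(k) {
  kmeans(scale(USArrests), centers = k, nstart = 25)$tot.withinss
})
plot(1:10, wss, type = 'b',
     xlab = 'Number of clusters K',
     ylab = 'Total within-cluster sum of squares')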
For now, let’s set \(K = 2\), where we will label the two clusters “safe” and “not-so-safe”. We will visualize the assigned clusters by plotting each crime feature on the y-axis against the urban population on the x-axis. One thing to note when implementing K-means is that the cluster labels are arbitrary numbers with no inherent meaning; it is up to the user to interpret them. However, you can compare the means of the features across clusters, as we will soon see.
clusters <- kmeans(USArrests, centers = 2, iter.max = 1000, nstart = 50)
clusters
## K-means clustering with 2 clusters of sizes 29, 21
##
## Cluster means:
## Murder Assault UrbanPop Rape
## 1 4.841379 109.7586 64.03448 16.24828
## 2 11.857143 255.0000 67.61905 28.11429
##
## Clustering vector:
## Alabama Alaska Arizona Arkansas California
## 2 2 2 2 2
## Colorado Connecticut Delaware Florida Georgia
## 2 1 2 2 2
## Hawaii Idaho Illinois Indiana Iowa
## 1 1 2 1 1
## Kansas Kentucky Louisiana Maine Maryland
## 1 1 2 1 2
## Massachusetts Michigan Minnesota Mississippi Missouri
## 1 2 1 2 1
## Montana Nebraska Nevada New Hampshire New Jersey
## 1 1 2 1 1
## New Mexico New York North Carolina North Dakota Ohio
## 2 2 2 1 1
## Oklahoma Oregon Pennsylvania Rhode Island South Carolina
## 1 1 1 1 2
## South Dakota Tennessee Texas Utah Vermont
## 1 2 2 1 1
## Virginia Washington West Virginia Wisconsin Wyoming
## 1 1 1 1 1
##
## Within cluster sum of squares by cluster:
## [1] 54762.30 41636.73
## (between_SS / total_SS = 72.9 %)
##
## Available components:
##
## [1] "cluster" "centers" "totss" "withinss"
## [5] "tot.withinss" "betweenss" "size" "iter"
## [9] "ifault"
Based on the data set, the output is actually very telling. If you look at the cluster means, states belonging to cluster 1 have, on average, significantly lower murder, assault, and rape rates. You can also see from the clustering vector which state belongs to which cluster.
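As a quick follow-up (not part of the output above), the clustering vector is named by state, so it can be split to list the members of each cluster:

split(names(clusters$cluster), clusters$cluster)  # states grouped by assigned cluster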
When dealing with different units, especially when some features take large values while others take small ones, the K-means algorithm will generally perform better if we scale the feature variables. This can be done in a couple of ways, one of which is converting each feature to its z-score using R’s \(\text{scale}()\) function.
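As a small check (using a hypothetical object name USArrests.z), scale() centers each feature at 0 and gives it a standard deviation of 1:

USArrests.z <- as.data.frame(scale(USArrests))
round(colMeans(USArrests.z), 10)  # approximately 0 for every feature
apply(USArrests.z, 2, sd)         # 1 for every feature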
Now, let’s visualize our resulting output.
library(ggplot2)
theme_set(theme_bw())
USArrests.scaled <- as.data.frame(apply(USArrests, MARGIN = 2, FUN = scale))
USArrests.scaled.clusters <- kmeans(USArrests.scaled, centers = 2, nstart = 50, iter.max = 5000)
USArrests.scaled$Cluster <- USArrests.scaled.clusters$cluster
USArrests$Cluster <- USArrests.scaled.clusters$cluster
# Factor Clusters
# Note: kmeans numbers its clusters arbitrarily, so check the cluster means
# before deciding which numeric label corresponds to the 'safe' group.
USArrests.scaled$Cluster <- factor(USArrests.scaled$Cluster, levels = c(1, 2), labels = c('Not-So-Safe State', 'Safe State'))
# Murders Against Population
murder <- ggplot(USArrests.scaled, aes(y = Murder, x = UrbanPop, shape = Cluster)) +
geom_point() +
stat_ellipse(aes(y = Murder, x = UrbanPop, fill = Cluster), geom = 'polygon', alpha = 0.21, level = 0.95)
assault <- ggplot(USArrests.scaled, aes(y = Assault, x = UrbanPop, shape = Cluster)) +
geom_point() +
stat_ellipse(aes(y = Assault, x = UrbanPop, fill = Cluster), geom = 'polygon', alpha = 0.21, level = 0.95)
rape <- ggplot(USArrests.scaled, aes(y = Rape, x = UrbanPop, col = Cluster, shape = Cluster)) +
geom_point() +
stat_ellipse(aes(y = Rape, x = UrbanPop, fill = Cluster), geom = 'polygon', alpha = 0.21, level = 0.95)
murder
assault
rape
Let’s see which states are “safe” and “not-so-safe” based on the number of murders, assaults, and rapes.
USArrests$Cluster <- factor(USArrests$Cluster, levels = c(1, 2), labels = c('Not-So-Safe', 'Safe'))
library(knitr)       # kable()
library(kableExtra)  # kable_styling() and the %>% pipe

m <- kable(USArrests[,c(1, 5)], format = 'html') %>%
kable_styling(bootstrap_options = c('hover', 'striped'))
a <- kable(USArrests[, c(2, 5)], format = 'html') %>%
kable_styling(bootstrap_options = c('striped', 'hover'))
r <- kable(USArrests[, c(4, 5)], format = 'html') %>%
kable_styling(bootstrap_options = c('striped', 'hover'))
m
| | Murder | Cluster |
|---|---|---|
Alabama | 13.2 | Safe |
Alaska | 10.0 | Safe |
Arizona | 8.1 | Safe |
Arkansas | 8.8 | Not-So-Safe |
California | 9.0 | Safe |
Colorado | 7.9 | Safe |
Connecticut | 3.3 | Not-So-Safe |
Delaware | 5.9 | Not-So-Safe |
Florida | 15.4 | Safe |
Georgia | 17.4 | Safe |
Hawaii | 5.3 | Not-So-Safe |
Idaho | 2.6 | Not-So-Safe |
Illinois | 10.4 | Safe |
Indiana | 7.2 | Not-So-Safe |
Iowa | 2.2 | Not-So-Safe |
Kansas | 6.0 | Not-So-Safe |
Kentucky | 9.7 | Not-So-Safe |
Louisiana | 15.4 | Safe |
Maine | 2.1 | Not-So-Safe |
Maryland | 11.3 | Safe |
Massachusetts | 4.4 | Not-So-Safe |
Michigan | 12.1 | Safe |
Minnesota | 2.7 | Not-So-Safe |
Mississippi | 16.1 | Safe |
Missouri | 9.0 | Safe |
Montana | 6.0 | Not-So-Safe |
Nebraska | 4.3 | Not-So-Safe |
Nevada | 12.2 | Safe |
New Hampshire | 2.1 | Not-So-Safe |
New Jersey | 7.4 | Not-So-Safe |
New Mexico | 11.4 | Safe |
New York | 11.1 | Safe |
North Carolina | 13.0 | Safe |
North Dakota | 0.8 | Not-So-Safe |
Ohio | 7.3 | Not-So-Safe |
Oklahoma | 6.6 | Not-So-Safe |
Oregon | 4.9 | Not-So-Safe |
Pennsylvania | 6.3 | Not-So-Safe |
Rhode Island | 3.4 | Not-So-Safe |
South Carolina | 14.4 | Safe |
South Dakota | 3.8 | Not-So-Safe |
Tennessee | 13.2 | Safe |
Texas | 12.7 | Safe |
Utah | 3.2 | Not-So-Safe |
Vermont | 2.2 | Not-So-Safe |
Virginia | 8.5 | Not-So-Safe |
Washington | 4.0 | Not-So-Safe |
West Virginia | 5.7 | Not-So-Safe |
Wisconsin | 2.6 | Not-So-Safe |
Wyoming | 6.8 | Not-So-Safe |
a
| | Assault | Cluster |
|---|---|---|
Alabama | 236 | Safe |
Alaska | 263 | Safe |
Arizona | 294 | Safe |
Arkansas | 190 | Not-So-Safe |
California | 276 | Safe |
Colorado | 204 | Safe |
Connecticut | 110 | Not-So-Safe |
Delaware | 238 | Not-So-Safe |
Florida | 335 | Safe |
Georgia | 211 | Safe |
Hawaii | 46 | Not-So-Safe |
Idaho | 120 | Not-So-Safe |
Illinois | 249 | Safe |
Indiana | 113 | Not-So-Safe |
Iowa | 56 | Not-So-Safe |
Kansas | 115 | Not-So-Safe |
Kentucky | 109 | Not-So-Safe |
Louisiana | 249 | Safe |
Maine | 83 | Not-So-Safe |
Maryland | 300 | Safe |
Massachusetts | 149 | Not-So-Safe |
Michigan | 255 | Safe |
Minnesota | 72 | Not-So-Safe |
Mississippi | 259 | Safe |
Missouri | 178 | Safe |
Montana | 109 | Not-So-Safe |
Nebraska | 102 | Not-So-Safe |
Nevada | 252 | Safe |
New Hampshire | 57 | Not-So-Safe |
New Jersey | 159 | Not-So-Safe |
New Mexico | 285 | Safe |
New York | 254 | Safe |
North Carolina | 337 | Safe |
North Dakota | 45 | Not-So-Safe |
Ohio | 120 | Not-So-Safe |
Oklahoma | 151 | Not-So-Safe |
Oregon | 159 | Not-So-Safe |
Pennsylvania | 106 | Not-So-Safe |
Rhode Island | 174 | Not-So-Safe |
South Carolina | 279 | Safe |
South Dakota | 86 | Not-So-Safe |
Tennessee | 188 | Safe |
Texas | 201 | Safe |
Utah | 120 | Not-So-Safe |
Vermont | 48 | Not-So-Safe |
Virginia | 156 | Not-So-Safe |
Washington | 145 | Not-So-Safe |
West Virginia | 81 | Not-So-Safe |
Wisconsin | 53 | Not-So-Safe |
Wyoming | 161 | Not-So-Safe |
r
| | Rape | Cluster |
|---|---|---|
Alabama | 21.2 | Safe |
Alaska | 44.5 | Safe |
Arizona | 31.0 | Safe |
Arkansas | 19.5 | Not-So-Safe |
California | 40.6 | Safe |
Colorado | 38.7 | Safe |
Connecticut | 11.1 | Not-So-Safe |
Delaware | 15.8 | Not-So-Safe |
Florida | 31.9 | Safe |
Georgia | 25.8 | Safe |
Hawaii | 20.2 | Not-So-Safe |
Idaho | 14.2 | Not-So-Safe |
Illinois | 24.0 | Safe |
Indiana | 21.0 | Not-So-Safe |
Iowa | 11.3 | Not-So-Safe |
Kansas | 18.0 | Not-So-Safe |
Kentucky | 16.3 | Not-So-Safe |
Louisiana | 22.2 | Safe |
Maine | 7.8 | Not-So-Safe |
Maryland | 27.8 | Safe |
Massachusetts | 16.3 | Not-So-Safe |
Michigan | 35.1 | Safe |
Minnesota | 14.9 | Not-So-Safe |
Mississippi | 17.1 | Safe |
Missouri | 28.2 | Safe |
Montana | 16.4 | Not-So-Safe |
Nebraska | 16.5 | Not-So-Safe |
Nevada | 46.0 | Safe |
New Hampshire | 9.5 | Not-So-Safe |
New Jersey | 18.8 | Not-So-Safe |
New Mexico | 32.1 | Safe |
New York | 26.1 | Safe |
North Carolina | 16.1 | Safe |
North Dakota | 7.3 | Not-So-Safe |
Ohio | 21.4 | Not-So-Safe |
Oklahoma | 20.0 | Not-So-Safe |
Oregon | 29.3 | Not-So-Safe |
Pennsylvania | 14.9 | Not-So-Safe |
Rhode Island | 8.3 | Not-So-Safe |
South Carolina | 22.5 | Safe |
South Dakota | 12.8 | Not-So-Safe |
Tennessee | 26.9 | Safe |
Texas | 25.5 | Safe |
Utah | 22.9 | Not-So-Safe |
Vermont | 11.2 | Not-So-Safe |
Virginia | 20.7 | Not-So-Safe |
Washington | 26.2 | Not-So-Safe |
West Virginia | 9.3 | Not-So-Safe |
Wisconsin | 10.8 | Not-So-Safe |
Wyoming | 15.6 | Not-So-Safe |