---
title: "my k-means clustering"
author: "bpfisher"
date: "2025-11-06"
output: html_document
---

my K-means算法

算法实现

# K-means clustering (Lloyd's algorithm) implemented without stats::kmeans().
#
# Arguments:
#   data     - numeric matrix or data frame (coerced with as.matrix());
#              rows are observations, columns are features.
#   k        - number of clusters, a single whole number in 1..nrow(data).
#   max_iter - maximum number of assignment/update iterations (>= 1).
#   tol      - stop once the Euclidean norm of the change in all centers
#              between two consecutive iterations falls below this value.
#   seed     - optional seed; if non-NULL, set.seed(seed) is called, which
#              also resets the caller's global RNG stream.
#
# Returns a list modeled on the fields of stats::kmeans():
#   clusters     - cluster label (1..k) of each row of `data`
#   centers      - k x ncol(data) matrix of final centers
#   withinss     - within-cluster sum of squares of each cluster (length k)
#   tot.withinss - sum of `withinss`
#   betweenss    - totss - tot.withinss
#   totss        - total sum of squares around the overall column means
#   size         - number of rows in each cluster (length k; 0 for a cluster
#                  that ended up empty)
#   iter         - number of iterations performed
#   converged    - TRUE if the center shift dropped below `tol`
my_kmeans <- function(data, k, max_iter = 100, tol = 1e-4, seed = NULL) {
  if (!is.null(seed)) {
    set.seed(seed)
  }

  data <- as.matrix(data)
  n <- nrow(data)

  if (!is.numeric(k) || length(k) != 1L || is.na(k) ||
      k < 1 || k > n || k != round(k)) {
    stop("`k` must be a single whole number between 1 and nrow(data).",
         call. = FALSE)
  }
  if (!is.numeric(max_iter) || length(max_iter) != 1L || is.na(max_iter) ||
      max_iter < 1) {
    stop("`max_iter` must be a single number >= 1.", call. = FALSE)
  }
  k <- as.integer(k)

  # Initial centers: k distinct rows chosen at random. drop = FALSE keeps a
  # k x p matrix even when k == 1 or the data has a single column.
  centers <- data[sample.int(n, k), , drop = FALSE]
  converged <- FALSE

  for (iter in seq_len(max_iter)) {
    # Assignment step. Squared distances suffice: sqrt() is monotone, so it
    # does not change which center is nearest.
    sq_dist <- matrix(0, nrow = n, ncol = k)
    for (j in seq_len(k)) {
      sq_dist[, j] <- rowSums(sweep(data, 2, centers[j, ], "-")^2)
    }
    # Nearest center per row; ties go to the lowest index (as which.min does).
    clusters <- max.col(-sq_dist, ties.method = "first")

    # Update step: move each center to the mean of its members. A cluster
    # that received no rows keeps its previous center.
    prev_centers <- centers
    for (j in seq_len(k)) {
      members <- clusters == j
      if (any(members)) {
        centers[j, ] <- colMeans(data[members, , drop = FALSE])
      }
    }

    # Converged when the total movement of all centers is below tol.
    if (sqrt(sum((centers - prev_centers)^2)) < tol) {
      converged <- TRUE
      break
    }
  }

  # Within-cluster sum of squares, one entry per cluster (0 for empty ones).
  within_ss <- vapply(seq_len(k), function(j) {
    members <- clusters == j
    if (!any(members)) {
      return(0)
    }
    sum(sweep(data[members, , drop = FALSE], 2, centers[j, ], "-")^2)
  }, numeric(1))
  tot_within_ss <- sum(within_ss)

  # Total sum of squares around the overall column means.
  total_ss <- sum(sweep(data, 2, colMeans(data), "-")^2)

  list(
    clusters = clusters,
    centers = centers,
    withinss = within_ss,
    tot.withinss = tot_within_ss,
    betweenss = total_ss - tot_within_ss,
    totss = total_ss,
    # tabulate() with nbins = k keeps a 0 for empty clusters, so size[j]
    # always refers to cluster j (table() would silently drop it).
    size = tabulate(clusters, nbins = k),
    iter = iter,
    converged = converged
  )
}

辅助函数

# Calinski-Harabasz (variance ratio) index of a clustering:
#   CH = (B / (k - 1)) / (W / (n - k))
# where B = kmeans_result$betweenss, W = kmeans_result$tot.withinss and
# n = nrow(data). `kmeans_result` is a list such as the one returned by
# my_kmeans(). For k == 1 the formula would divide by k - 1 = 0, so 0 is
# returned instead.
calculate_ch_index <- function(kmeans_result, data, k) {
  n_obs <- nrow(data)
  bss <- kmeans_result$betweenss
  wss <- kmeans_result$tot.withinss

  if (k == 1) {
    0
  } else {
    between_dispersion <- bss / (k - 1)
    within_dispersion <- wss / (n_obs - k)
    between_dispersion / within_dispersion
  }
}

算法测试

# Load required packages: ggplot2 for the figure, dplyr for the %>% pipe and
# group_by()/summarize() used when preparing the plot data.
library(ggplot2)
library(dplyr)
## 
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
## 
##     filter, lag
## The following objects are masked from 'package:base':
## 
##     intersect, setdiff, setequal, union
# Use the built-in USArrests dataset; every column is passed to my_kmeans()
# as a clustering feature.
# NOTE(review): the columns are clustered on their raw scales. If their
# variances differ a lot, the largest-variance column will dominate the
# Euclidean distances -- consider scale() if that is not intended.
data(USArrests)
usarrests_data <- USArrests

# Seed the global RNG: my_kmeans() is called without `seed`, so this call is
# what makes its random choice of initial centers reproducible.
set.seed(123)
k <- 4
kmeans_arrests <- my_kmeans(usarrests_data, k = k, max_iter = 100)

# Calinski-Harabasz index: ratio of between- to within-cluster dispersion,
# each scaled by its degrees of freedom.
ch_index <- calculate_ch_index(kmeans_arrests, usarrests_data, k)

# Print a short summary; the "##" line after each cat() records its output.
cat("USArrests数据集聚类结果:\n")
## USArrests数据集聚类结果:
cat("迭代次数:", kmeans_arrests$iter, "\n")
## 迭代次数: 3
cat("簇大小:", kmeans_arrests$size, "\n")
## 簇大小: 12 10 21 7
cat("总簇内平方和:", kmeans_arrests$tot.withinss, "\n")
## 总簇内平方和: 39641.35
cat("CH指数:", ch_index, "\n")
## CH指数: 122.2937

结果可视化

# Prepare plotting data. The clustering used every column of usarrests_data,
# but only Murder and Assault are plotted, so clusters can overlap in this
# 2-D view.
cluster_labels <- paste("Cluster", seq_len(k))
arrests_plot_data <- data.frame(
  Murder = usarrests_data$Murder,
  Assault = usarrests_data$Assault,
  Cluster = factor(kmeans_arrests$clusters,
                   levels = seq_len(k),
                   labels = cluster_labels),
  State = rownames(usarrests_data)
)

# Centroid markers: per-cluster means of the two plotted variables.
cluster_centers <- arrests_plot_data %>%
  group_by(Cluster) %>%
  summarize(
    Murder = mean(Murder),
    Assault = mean(Assault),
    .groups = "drop"
  )

# Colors keyed by cluster label. The hand-picked palette covers up to four
# clusters; a larger k falls back to an HCL palette instead of leaving
# clusters without a color.
base_palette <- c("#e8c559", "#99cbeb", "#db6968", "#a3d393")
cluster_colors <- if (k <= length(base_palette)) {
  base_palette[seq_len(k)]
} else {
  grDevices::hcl.colors(k)
}
names(cluster_colors) <- cluster_labels

# Build the figure
p <- ggplot(arrests_plot_data, aes(x = Murder, y = Assault, color = Cluster)) +
  geom_point(size = 4, alpha = 0.9) +
  # Larger black triangle drawn first so it shows as an outline around the
  # colored centroid triangle drawn on top of it.
  geom_point(data = cluster_centers,
             aes(x = Murder, y = Assault),
             shape = 17, size = 6, stroke = 2, color = "black") +
  geom_point(data = cluster_centers,
             aes(x = Murder, y = Assault, color = Cluster),
             shape = 17, size = 5) +
  scale_color_manual(name = "Cluster", values = cluster_colors) +
  labs(
    title = paste0("K-means (k=", k, ") Final Clustering | ",
                   "Iterations: ", kmeans_arrests$iter, " | ",
                   "WCSS: ", round(kmeans_arrests$tot.withinss, 1), " | ",
                   "CH: ", round(ch_index, 3)),
    x = "Murder Rate",
    y = "Assault Rate"
  ) +
  theme_bw() +
  theme(
    plot.title = element_text(
      hjust = 0.5,
      size = 14,
      face = "bold",
      margin = margin(b = 10)
    ),
    # `linewidth` replaces element_line()'s `size`, deprecated in ggplot2 3.4.0
    panel.grid.major = element_line(color = "grey90", linewidth = 0.5),
    panel.grid.minor = element_blank(),
    legend.position = "right"
  )

# Display the figure
p

# Save the figure
ggsave("kmeans_homework1.png", plot = p, width = 12, height = 8, dpi = 300,
       bg = "white")