# title: "my k-means clustering"
# author: "bpfisher"
# date: "2025-11-06"
# output: html_document
#' K-means clustering implemented without R's built-in kmeans() function
#'
#' Picks `k` distinct rows of `data` at random as initial centers, then
#' alternates between assigning every observation to its nearest center
#' (Euclidean distance) and recomputing each center as the mean of its
#' members. Iteration stops when the centers move less than `tol` or after
#' `max_iter` iterations.
#'
#' @param data Matrix or data frame of finite numeric values (rows are
#'   observations, columns are features); coerced with `as.matrix()`.
#' @param k Number of clusters, between 1 and `nrow(data)`.
#' @param max_iter Maximum number of assign/update iterations (at least 1).
#' @param tol Convergence threshold on the Euclidean norm of the change in
#'   the whole center matrix between two consecutive iterations.
#' @param seed Optional RNG seed. When not `NULL`, `set.seed(seed)` is called,
#'   which changes the global RNG state.
#' @return A list with elements `clusters` (cluster label per observation),
#'   `centers` (k-row matrix), `withinss` and `tot.withinss` (both the total
#'   within-cluster sum of squares), `betweenss`, `totss`, `size` (length-k
#'   vector of cluster sizes, 0 for an empty cluster), `iter` (iterations
#'   performed) and `converged`.
my_kmeans <- function(data, k, max_iter = 100, tol = 1e-4, seed = NULL) {
  # Preprocess: work on a plain matrix.
  data <- as.matrix(data)
  n <- nrow(data) # number of observations

  # Validate early so failures are explicit instead of surfacing later as
  # obscure subscript or dimension errors.
  if (!(is.numeric(data) || is.logical(data)) || !all(is.finite(data))) {
    stop("`data` must contain only finite numeric values.", call. = FALSE)
  }
  if (!is.numeric(k) || length(k) != 1L || is.na(k) || k < 1 || k > n) {
    stop("`k` must be a single number between 1 and nrow(data).",
         call. = FALSE)
  }
  if (!is.numeric(max_iter) || length(max_iter) != 1L ||
        is.na(max_iter) || max_iter < 1) {
    stop("`max_iter` must be a single number >= 1.", call. = FALSE)
  }
  k <- as.integer(k)
  max_iter <- as.integer(max_iter)

  # Set the random seed only when the caller asked for one.
  if (!is.null(seed)) {
    set.seed(seed)
  }

  # Randomly choose k distinct observations as the initial centers.
  # drop = FALSE keeps `centers` a matrix even when k == 1 or there is a
  # single feature column; without it `centers[i, ]` below would fail.
  centers <- data[sample.int(n, k), , drop = FALSE]

  clusters <- integer(n) # cluster label of every observation
  converged <- FALSE
  iter <- 0L

  for (iter in seq_len(max_iter)) {
    # Assignment step: Euclidean distance from every observation to every
    # center, one column per center.
    distances <- matrix(0, n, k)
    for (i in seq_len(k)) {
      diff <- sweep(data, 2, centers[i, ], "-")
      distances[, i] <- sqrt(rowSums(diff^2))
    }
    # Label each observation with its nearest center (first one on ties).
    clusters <- apply(distances, 1, which.min)

    # Update step: move each center to the mean of its members. A center
    # that attracted no observations is left where it was.
    prev_centers <- centers
    for (i in seq_len(k)) {
      members <- clusters == i
      if (any(members)) {
        centers[i, ] <- colMeans(data[members, , drop = FALSE])
      }
    }

    # Convergence check: total movement of the centers below the threshold.
    center_shift <- sqrt(sum((centers - prev_centers)^2))
    if (center_shift < tol) {
      converged <- TRUE
      break
    }
  }

  # Within-cluster sum of squares, accumulated cluster by cluster.
  within_ss <- 0
  for (i in seq_len(k)) {
    cluster_data <- data[clusters == i, , drop = FALSE]
    if (nrow(cluster_data) > 0) {
      diff <- sweep(cluster_data, 2, centers[i, ], "-")
      within_ss <- within_ss + sum(diff^2)
    }
  }

  # Total sum of squares around the overall mean.
  total_center <- colMeans(data)
  total_ss <- sum(sweep(data, 2, total_center, "-")^2)

  # Between-cluster sum of squares.
  between_ss <- total_ss - within_ss

  list(
    clusters = clusters,
    centers = centers,
    # Kept equal to the total for backward compatibility; this is not a
    # per-cluster vector as in stats::kmeans().
    withinss = within_ss,
    tot.withinss = within_ss,
    betweenss = between_ss,
    totss = total_ss,
    # tabulate() with nbins = k reports 0 for empty clusters, so `size`
    # always has length k and lines up with the rows of `centers`.
    size = tabulate(clusters, nbins = k),
    iter = iter,
    converged = converged
  )
}
#' Calinski-Harabasz (CH) index of a clustering result
#'
#' @param kmeans_result List with elements `betweenss` and `tot.withinss`.
#' @param data The clustered data; only its number of rows is used.
#' @param k Number of clusters.
#' @return `0` when `k == 1`; otherwise the between-cluster sum of squares per
#'   `k - 1` divided by the within-cluster sum of squares per `n - k`.
calculate_ch_index <- function(kmeans_result, data, k) {
  n_obs <- nrow(data)
  bss <- kmeans_result$betweenss
  wss <- kmeans_result$tot.withinss
  if (k == 1) {
    # The between-cluster term would be divided by zero for one cluster.
    0
  } else {
    (bss / (k - 1)) / (wss / (n_obs - k))
  }
}
# Load required packages
library(ggplot2)
library(dplyr)
##
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
##
## filter, lag
## The following objects are masked from 'package:base':
##
## intersect, setdiff, setequal, union
# Cluster the built-in USArrests data set with the custom k-means function.
data(USArrests)
usarrests_data <- USArrests
# my_kmeans() is called without `seed`, so this global seed is what fixes its
# random choice of initial centers.
set.seed(123)
k <- 4
kmeans_arrests <- my_kmeans(usarrests_data, k = k, max_iter = 100)
ch_index <- calculate_ch_index(kmeans_arrests, usarrests_data, k)
# Print a summary of the fit. Each `##` line below is the output recorded for
# the cat() call directly above it. In order, the printed labels are: a
# "USArrests clustering results" header, the iteration count, the cluster
# sizes, the total within-cluster sum of squares, and the CH index.
cat("USArrests数据集聚类结果:\n")
## USArrests数据集聚类结果:
cat("迭代次数:", kmeans_arrests$iter, "\n")
## 迭代次数: 3
cat("簇大小:", kmeans_arrests$size, "\n")
## 簇大小: 12 10 21 7
cat("总簇内平方和:", kmeans_arrests$tot.withinss, "\n")
## 总簇内平方和: 39641.35
cat("CH指数:", ch_index, "\n")
## CH指数: 122.2937
# Prepare the plotting data: the Murder and Assault columns plus the cluster
# label and state name of every observation.
arrests_plot_data <- data.frame(
  Murder = usarrests_data$Murder,
  Assault = usarrests_data$Assault,
  # Build the labelled factor in one step; the levels cover every cluster
  # number from 1 to k.
  Cluster = factor(
    kmeans_arrests$clusters,
    levels = seq_len(k),
    labels = paste("Cluster", seq_len(k))
  ),
  State = rownames(usarrests_data)
)

# Cluster centers in the two plotted dimensions (per-cluster means).
cluster_centers <- arrests_plot_data %>%
  group_by(Cluster) %>%
  summarize(
    Murder = mean(Murder),
    Assault = mean(Assault)
  )

# Build the figure: observations as points, cluster centers as triangles.
p <- ggplot(arrests_plot_data, aes(x = Murder, y = Assault, color = Cluster)) +
  geom_point(size = 4, alpha = 0.9) +
  # Draw a slightly larger black triangle first so it shows as an outline.
  geom_point(
    data = cluster_centers,
    aes(x = Murder, y = Assault, color = Cluster),
    shape = 17, size = 6, stroke = 2, color = "black"
  ) +
  # Then draw the cluster-coloured triangle on top of it.
  geom_point(
    data = cluster_centers,
    aes(x = Murder, y = Assault, color = Cluster),
    shape = 17, size = 5
  ) +
  scale_color_manual(
    name = "Cluster",
    values = c(
      "Cluster 1" = "#e8c559", "Cluster 2" = "#99cbeb",
      "Cluster 3" = "#db6968", "Cluster 4" = "#a3d393"
    )
  ) +
  labs(
    title = paste0(
      "K-means (k=", k, ") Final Clustering | ",
      "Iterations: ", kmeans_arrests$iter, " | ",
      "WCSS: ", round(kmeans_arrests$tot.withinss, 1), " | ",
      "CH: ", round(ch_index, 3)
    ),
    x = "Murder Rate",
    y = "Assault Rate"
  ) +
  theme_bw() +
  theme(
    plot.title = element_text(
      hjust = 0.5,
      size = 14,
      face = "bold",
      margin = margin(b = 10)
    ),
    # `linewidth` replaces the `size` argument of element_line(), which is
    # deprecated as of ggplot2 3.4.0.
    panel.grid.major = element_line(color = "grey90", linewidth = 0.5),
    panel.grid.minor = element_blank(),
    legend.position = "right"
  )

# Display the figure.
p

# Save the figure to disk.
ggsave("kmeans_homework1.png", plot = p, width = 12, height = 8, dpi = 300, bg = "white")