# =============================================================================
# 3D CCA Cluster Plot (Darker Axes + Stronger 3D View)
# =============================================================================

# Load libraries
library(readxl)
## Warning: package 'readxl' was built under R version 4.5.3
library(dplyr)
## Warning: package 'dplyr' was built under R version 4.5.3
## 
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
## 
##     filter, lag
## The following objects are masked from 'package:base':
## 
##     intersect, setdiff, setequal, union
library(clustrd)
## Warning: package 'clustrd' was built under R version 4.5.3
## Loading required package: ggplot2
## Warning: package 'ggplot2' was built under R version 4.5.3
## Loading required package: grid
library(plotly)
## Warning: package 'plotly' was built under R version 4.5.2
## 
## Attaching package: 'plotly'
## The following object is masked from 'package:ggplot2':
## 
##     last_plot
## The following object is masked from 'package:stats':
## 
##     filter
## The following object is masked from 'package:graphics':
## 
##     layout
library(htmlwidgets)

# =============================================================================
# 1. SET WORKING DIRECTORY
# =============================================================================

# Project folder: "Final Dataset_V1.xlsx" is read from here and the HTML plot
# is written here (both use relative paths). On another machine, point the
# CCA_PROJECT_DIR environment variable (e.g. in ~/.Renviron) at the folder
# instead of editing this path.
project_dir <- Sys.getenv(
  "CCA_PROJECT_DIR",
  unset = file.path(
    "D:/TXST-onedrive",
    "OneDrive - Texas State University",
    "00_Spring 2026",
    "Das Highway Safety",
    "Final Project",
    "0_Analysis",
    "2_CCA",
    "3D plot trying"
  )
)

# Fail early with a readable message rather than setwd()'s generic error.
if (!dir.exists(project_dir)) {
  stop("Project directory not found: ", project_dir, call. = FALSE)
}
setwd(project_dir)

# =============================================================================
# 2. LOAD DATA
# =============================================================================

# Read the analysis dataset (first sheet of the workbook) from the working
# directory.
dataset_path <- "Final Dataset_V1.xlsx"
data <- read_excel(dataset_path)

# Optional sanity check: list the column names that were read.
names(data) %>% print()
##  [1] "Wthr"         "Light"        "RoadAlgn"     "SurfCond"     "HarmEvnt"    
##  [6] "IntrsctRelat" "VehBody"      "PrsnInjrySev" "PrsnAge"      "PrsnGndr"    
## [11] "PrsnBacTest"
# =============================================================================
# 3. SELECT VARIABLES
# =============================================================================

# Columns used as categorical inputs to the cluster CA model.
model_vars <- c(
  "Wthr", "Light", "RoadAlgn", "SurfCond", "HarmEvnt", "IntrsctRelat",
  "VehBody", "PrsnInjrySev", "PrsnAge", "PrsnGndr", "PrsnBacTest"
)

# Keep only the model columns, drop any row with a missing value, then treat
# every column as a categorical factor.
# NOTE(review): PrsnAge is converted to a factor too, so each distinct value
# becomes its own category; confirm this is intended (vs. age bands).
data_clean <- data %>%
  select(all_of(model_vars)) %>%
  na.omit() %>%
  mutate(across(everything(), as.factor))

cat("Rows:", nrow(data_clean), "\n")
## Rows: 49346
cat("Columns:", ncol(data_clean), "\n")
## Columns: 11
# =============================================================================
# 4. RUN CLUSMCA IN 3 DIMENSIONS
# =============================================================================

# Fix the RNG so the random starts below are reproducible.
set.seed(1234)

# Cluster correspondence analysis ("clusCA"): 4 clusters embedded in a
# 3-dimensional space, using 10 random starts. Change nclus to explore other
# solutions; the colour palette further down has hand-picked colours for
# clusters C1-C6.
res <- clusmca(
  data_clean,
  nclus = 4,
  ndim = 3,
  method = "clusCA",
  nstart = 10
)
##   |                                                                              |                                                                      |   0%  |                                                                              |=======                                                               |  10%  |                                                                              |==============                                                        |  20%  |                                                                              |=====================                                                 |  30%  |                                                                              |============================                                          |  40%  |                                                                              |===================================                                   |  50%  |                                                                              |==========================================                            |  60%  |                                                                              |=================================================                     |  70%  |                                                                              |========================================================              |  80%  |                                                                              |===============================================================       |  90%  |                                                                              |======================================================================| 100%
res %>% print() # cluster sizes, centroids and within-cluster SS summary
## Solution with 4 clusters of sizes 23545 (47.7%), 15076 (30.6%), 6028 (12.2%), 4697 (9.5%) in 3 dimensions. 
## 
## Cluster centroids:
##             Dim.1   Dim.2   Dim.3
## Cluster 1  0.0010  0.0000  0.0028
## Cluster 2  0.0014 -0.0031 -0.0030
## Cluster 3  0.0013  0.0076 -0.0031
## Cluster 4 -0.0110 -0.0001 -0.0005
## 
## Within cluster sum of squares by cluster:
## [1] 0.0571 0.0592 0.1268 0.1296
##  (between_SS / total_SS =  80.15 %) 
## 
## Objective criterion value: 3.7272 
## 
## Available output:
## 
## [1] "obscoord"  "attcoord"  "centroid"  "cluster"   "criterion" "size"     
## [7] "odata"     "nstart"
# =============================================================================
# 5. PREPARE DATA FOR 3D PLOT
# =============================================================================

# The 3-D plot needs at least three dimensions of observation coordinates;
# fail with a clear message instead of "subscript out of bounds".
stopifnot(ncol(res$obscoord) >= 3)

# Cluster labels "C1", "C2", ... with levels in numeric order. Without
# explicit levels, factor() sorts labels alphabetically (C1, C10, C2, ...)
# once a solution has ten or more clusters, scrambling legend/colour order.
cluster_ids <- sort(unique(res$cluster))

# One row per observation: its first three coordinates and cluster label.
plot_data <- data.frame(
  Dim1 = res$obscoord[, 1],
  Dim2 = res$obscoord[, 2],
  Dim3 = res$obscoord[, 3],
  Cluster = factor(
    paste0("C", res$cluster),
    levels = paste0("C", cluster_ids)
  )
)

# Mean position of each cluster's points in the 3-D coordinate space.
centroids <- plot_data %>%
  group_by(Cluster) %>%
  summarise(
    Dim1 = mean(Dim1, na.rm = TRUE),
    Dim2 = mean(Dim2, na.rm = TRUE),
    Dim3 = mean(Dim3, na.rm = TRUE),
    .groups = "drop"
  )

# =============================================================================
# 6. DEFINE COLORS
# =============================================================================

# Hand-picked colours for the first six clusters (ColorBrewer Set1 hues,
# yellow skipped).
cluster_colors <- c(
  "C1" = "#E41A1C",
  "C2" = "#377EB8",
  "C3" = "#4DAF4A",
  "C4" = "#984EA3",
  "C5" = "#FF7F00",
  "C6" = "#A65628"
)

# Clusters present in the solution, as character labels. Character (not the
# factor itself) matters: indexing a named vector with a factor uses the
# factor's integer codes, not its labels.
present_clusters <- as.character(sort(unique(plot_data$Cluster)))

# Solutions with clusters beyond C6 get generated colours instead of NA.
extra_clusters <- base::setdiff(present_clusters, names(cluster_colors))
if (length(extra_clusters) > 0) {
  extra_colors <- grDevices::hcl.colors(
    length(extra_clusters),
    palette = "Dark 3"
  )
  names(extra_colors) <- extra_clusters
  cluster_colors <- c(cluster_colors, extra_colors)
}

# Keep only the colours of clusters that actually occur.
cluster_colors <- cluster_colors[present_clusters]

# =============================================================================
# 7. BUILD 3D PLOT
# =============================================================================

# Start from an empty plotly figure and add one scatter3d trace per cluster.
p <- plot_ly()

# Iterate clusters in sorted order (C1, C2, ...) so trace and legend order is
# stable; unique() would yield them in order of first appearance in the data.
for (cl in as.character(sort(unique(plot_data$Cluster)))) {

  df_sub <- plot_data %>% filter(Cluster == cl)

  p <- p %>%
    add_trace(
      data = df_sub,
      x = ~Dim1,
      y = ~Dim2,
      z = ~Dim3,
      type = "scatter3d",
      mode = "markers",
      name = cl,
      marker = list(
        size = 3.5,
        # unname(): pass a plain colour string, not a named vector, to plotly.
        color = unname(cluster_colors[cl]),
        opacity = 0.75,
        # Thin dark outline helps separate overlapping points.
        line = list(color = "rgba(40,40,40,0.5)", width = 0.3)
      ),
      # <extra></extra> suppresses plotly's secondary (trace-name) hover box.
      hovertemplate = paste0(
        "Cluster: ", cl,
        "<br>Dim 1: %{x:.3f}",
        "<br>Dim 2: %{y:.3f}",
        "<br>Dim 3: %{z:.3f}<extra></extra>"
      )
    )
}

# Add centroids
# Overlay cluster centroids as black diamonds labelled with the cluster name.
p <- p %>%
  add_trace(
    data = centroids,
    x = ~Dim1,
    y = ~Dim2,
    z = ~Dim3,
    type = "scatter3d",
    mode = "markers+text",
    text = ~Cluster,
    textposition = "top center",
    name = "Centroids",
    marker = list(
      size = 9,
      color = "black",
      symbol = "diamond",
      # White outline keeps the diamonds visible over dense point clouds.
      line = list(color = "white", width = 1)
    ),
    textfont = list(
      size = 12,
      color = "black"
    ),
    # paste0 (not paste) so no stray spaces are inserted between fragments,
    # matching the per-cluster hover templates. %{text} is the cluster label.
    hovertemplate = paste0(
      "Centroid",
      "<br>Cluster: %{text}",
      "<br>Dim 1: %{x:.3f}",
      "<br>Dim 2: %{y:.3f}",
      "<br>Dim 3: %{z:.3f}<extra></extra>"
    )
  )

# =============================================================================
# 8. LAYOUT: DARKER AXES + STRONGER 3D LOOK
# =============================================================================

# Shared styling for the three scene axes: light grey panes and grid, thick
# dark axis lines and outside ticks, dark zero line.
#
# title: axis title text.
# Returns a plotly scene-axis attribute list.
scene_axis_style <- function(title) {
  list(
    # title.font replaces the deprecated `titlefont` attribute.
    title = list(text = title, font = list(size = 14, color = "black")),
    tickfont = list(size = 11, color = "black"),
    showbackground = TRUE,
    backgroundcolor = "rgb(245,245,245)",
    gridcolor = "rgb(190,190,190)",
    zerolinecolor = "rgb(60,60,60)",
    linecolor = "rgb(40,40,40)",
    linewidth = 4,
    showline = TRUE,
    ticks = "outside",
    tickcolor = "rgb(40,40,40)",
    tickwidth = 2
  )
}

p <- p %>%
  layout(
    title = list(
      text = "3D CCA Cluster Plot",
      font = list(size = 22, color = "black")
    ),
    scene = list(
      xaxis = scene_axis_style("Dim 1"),
      yaxis = scene_axis_style("Dim 2"),
      zaxis = scene_axis_style("Dim 3"),
      # "cube" draws equal-length axes regardless of the data ranges.
      aspectmode = "cube",
      # Viewpoint off the diagonal and slightly raised for depth perception.
      camera = list(
        eye = list(x = 1.8, y = 1.6, z = 1.35)
      )
    ),
    legend = list(
      title = list(text = "Clusters"),
      font = list(size = 12, color = "black")
    )
  )

# Show plot: a bare top-level `p` auto-prints the widget.
# NOTE(review): this file looks like knitr/spin output (the `## ` lines), where
# auto-printing is how the widget gets embedded; keep it a bare expression
# rather than switching to print(p) without checking the rendered output.
p
# =============================================================================
# 9. SAVE OUTPUT
# =============================================================================

# Write the interactive plot as a single self-contained HTML file into the
# working directory, then report the file name.
out_html <- "CCA_3D_cluster_plot_darker_axes.html"
saveWidget(p, file = out_html, selfcontained = TRUE)

cat("Saved: ", out_html, "\n", sep = "")
## Saved: CCA_3D_cluster_plot_darker_axes.html