library(networkD3)
library(dplyr)
# Flow table: each of the three source groups appears once per target
# outcome (first all "Benefits Funded" rows, then all "Benefits Cut" rows).
links <- data.frame(
  source = rep(c("Retirees", "TVs", "Actives"), times = 2),
  target = rep(c("Benefits Funded", "Benefits Cut"), each = 3),
  value = c(6.45, 0.15, 0.5, 0.75, 1.55, 4.9),
  stringsAsFactors = FALSE
)

# Node table: every distinct label, sources first, then targets.
nodes <- data.frame(
  name = unique(c(links[["source"]], links[["target"]])),
  stringsAsFactors = FALSE
)

# Row positions of each endpoint within `nodes`, shifted to start at 0
# (networkD3 expects 0-based ids).
links[["IDsource"]] <- match(links[["source"]], nodes[["name"]]) - 1
links[["IDtarget"]] <- match(links[["target"]], nodes[["name"]]) - 1

# -----------------------------
# Colour-group keys
# -----------------------------
# Node keys carry an "N_" prefix and link keys an "L_" prefix, so the two
# sets never collide.

# Node key: looked up by node name; any name not in the table (including
# NA) falls back to "N_Other".
nodes$ngroup <- local({
  key_by_name <- c(
    "Retirees" = "N_Retirees",
    "TVs" = "N_TVs",
    "Actives" = "N_Actives",
    "Benefits Funded" = "N_Funded",
    "Benefits Cut" = "N_Cut"
  )
  keys <- unname(key_by_name[nodes$name])
  keys[is.na(keys)] <- "N_Other"
  keys
})

# Link key: determined solely by the link's target node.
links$lgroup <- ifelse(links$target != "Benefits Funded", "L_Cut", "L_Funded")

# -----------------------------
# One colour scale that includes BOTH node + link keys
# -----------------------------
# Each group key is paired with its hex colour in a single named vector, so
# the scale's domain and range cannot drift out of step with each other.
group_colours <- c(
  N_Retirees = "#7C3AED",  # Retirees node: violet
  N_TVs      = "#2563EB",  # TVs node: blue
  N_Actives  = "#F59E0B",  # Actives node: amber
  N_Funded   = "#10B981",  # Funded node: green
  N_Cut      = "#EF4444",  # Cut node: red
  L_Funded   = "#6EE7B7",  # Funded flow: light green
  L_Cut      = "#FCA5A5"   # Cut flow: light red
)

# JavaScript source for a d3 ordinal scale: domain = the keys above,
# range = their colours, both emitted in the same order. Kept as a plain
# length-1 character string.
colourScale <- sprintf(
  "d3.scaleOrdinal().domain([%s]).range([%s]);",
  paste0('"', names(group_colours), '"', collapse = ","),
  paste0('"', unname(group_colours), '"', collapse = ",")
)

# Assemble the Sankey widget. Links and nodes each name the column that
# holds their colour key; both key sets are resolved by `colourScale`.
p <- sankeyNetwork(
  # link table and its columns
  Links = links,
  Source = "IDsource",
  Target = "IDtarget",
  Value = "value",
  LinkGroup = "lgroup",
  # node table and its columns
  Nodes = nodes,
  NodeID = "name",
  NodeGroup = "ngroup",
  # appearance
  colourScale = colourScale,
  fontSize = 12,
  sinksRight = FALSE
)

# Evaluate the widget object at top level so it is displayed.
p