This notebook treats decision trees as segmentation tools rather than traditional clustering methods.
That means:
This file fits two segmentation trees:
- `HFRS_100` (regression tree)
- `HFRS_label` (classification tree)

The goal is to recover actionable rules such as:
library(here)
library(tidyverse)
library(rpart)
library(kableExtra)
library(scales)
# Seed the RNG so that randomised steps (e.g. rpart's cross-validation
# folds) repeat across runs.
set.seed(2024)

# Output locations, resolved by here() relative to the project root.
FIG_DIR <- here("02_clustering", "outputs", "figures")
DATA_DIR <- here("02_clustering", "outputs", "data")

# Create both output folders up front; warnings (such as "already exists")
# are suppressed so re-running the notebook stays quiet.
invisible(lapply(
  c(FIG_DIR, DATA_DIR),
  dir.create,
  recursive = TRUE,
  showWarnings = FALSE
))

# Predictor columns offered to both segmentation trees.
TREE_VARS <- c(
  "BBDI_sl_w_z", "SHUSH_sl_w_z", "MDI_sl_w_z", "FINAS_sl_w_z",
  "ODDI_sl_w_z", "PWNETWPT_sl_w_z", "PATTSTIN_z", "PNBEARG_z"
)
# Weighted mean of `x` with weights `w`, using only complete (x, w) pairs.
#
# `weighted.mean(..., na.rm = TRUE)` drops missing values of `x` only, so a
# single missing weight would turn the whole result into NA. Dropping the
# pair whenever either side is missing keeps this helper consistent with
# `sum(PWEIGHT, na.rm = TRUE)`, which is used for the weighted counts.
#
# x: numeric or logical vector of values.
# w: numeric vector of weights, same length as `x`.
# Returns a single number (NaN when no complete pair remains).
wmean <- function(x, w) {
  stopifnot(length(x) == length(w))
  keep <- !is.na(x) & !is.na(w)
  weighted.mean(x[keep], w[keep])
}
# Summarise every tree leaf ("segment") with survey-weighted statistics.
#
# df:       data frame holding the leaf assignment plus PWEIGHT, HFRS_100,
#           HFRS_label and every column named in TREE_VARS.
# leaf_col: name (string) of the column in `df` that identifies the leaf.
#
# Returns one row per segment with the unweighted row count (n), the summed
# weights (wt_n), the weighted mean of HFRS_100, the weighted percentage of
# each HFRS_label category, and the weighted mean of every TREE_VARS column,
# ordered from the largest to the smallest weighted segment.
extract_leaf_profile <- function(df, leaf_col) {
  by_leaf <- group_by(df, segment = .data[[leaf_col]])

  leaf_stats <- summarise(
    by_leaf,
    n = n(),
    wt_n = sum(PWEIGHT, na.rm = TRUE),
    HFRS_mean = wmean(HFRS_100, PWEIGHT),
    # Weighted share of a logical indicator, expressed in percent.
    pct_stressed = wmean(HFRS_label == "stressed", PWEIGHT) * 100,
    pct_coping = wmean(HFRS_label == "coping", PWEIGHT) * 100,
    pct_comfortable = wmean(HFRS_label == "comfortable", PWEIGHT) * 100,
    across(all_of(TREE_VARS), ~ wmean(.x, PWEIGHT)),
    .groups = "drop"
  )

  arrange(leaf_stats, desc(wt_n))
}
# Draw a fitted rpart tree with the given title.
#
# fit:        fitted rpart object; `fit$method` selects the plot options.
# main_title: plot title (string).
#
# Uses rpart.plot when that package is installed; otherwise falls back to the
# base plot()/text() methods and appends " (base plot)" to the title.
print_tree <- function(fit, main_title) {
  # Fallback first: rpart.plot is optional, so degrade to base graphics.
  if (!requireNamespace("rpart.plot", quietly = TRUE)) {
    plot(fit, margin = 0.1)
    text(fit, use.n = TRUE, cex = 0.7)
    return(title(main = paste0(main_title, " (base plot)")))
  }

  is_class_tree <- fit$method == "class"
  if (is_class_tree) {
    # Classification trees additionally pass extra = 104.
    rpart.plot::rpart.plot(
      fit,
      main = main_title,
      type = 2,
      extra = 104,
      under = TRUE
    )
  } else {
    rpart.plot::rpart.plot(
      fit,
      main = main_title,
      type = 2,
      under = TRUE
    )
  }
}
# Load the analysis data set and keep only rows usable by both trees.
df_raw <- readr::read_csv(here("data", "adf.csv"), show_col_types = FALSE)

df_tree <- df_raw %>%
  # Convert the label to a factor BEFORE the missing-value filter: a label
  # outside the three expected levels becomes NA here and is dropped below.
  # Filtering first would let such a row through with an NA response; rpart
  # would then discard it on its own and `where` would no longer line up
  # with the rows of df_tree.
  mutate(
    HFRS_label = factor(
      HFRS_label,
      levels = c("stressed", "coping", "comfortable")
    )
  ) %>%
  # Require both targets, the survey weight and every predictor.
  filter(
    !is.na(HFRS_100),
    !is.na(HFRS_label),
    !is.na(PWEIGHT),
    if_all(all_of(TREE_VARS), ~ !is.na(.x))
  )

cat("Rows used:", nrow(df_tree), "\n")
## Rows used: 24360
# Model formula: HFRS_100 explained by every column listed in TREE_VARS.
score_formula <- reformulate(TREE_VARS, response = "HFRS_100")

# Survey-weighted regression ("anova") tree. The control settings cap the
# tree depth and require a minimum number of observations per split and per
# leaf, so the resulting segments stay reasonably large.
tree_score <- rpart(
  formula = score_formula,
  data = df_tree,
  weights = PWEIGHT,
  method = "anova",
  control = rpart.control(
    minsplit = 300,
    minbucket = 150,
    cp = 0.002,
    maxdepth = 5
  )
)

# Print the complexity-parameter table for the fitted tree.
printcp(tree_score)
##
## Regression tree:
## rpart(formula = score_formula, data = df_tree, weights = PWEIGHT,
## method = "anova", control = rpart.control(minsplit = 300,
## minbucket = 150, cp = 0.002, maxdepth = 5))
##
## Variables actually used in tree construction:
## [1] FINAS_sl_w_z MDI_sl_w_z PATTSTIN_z PWNETWPT_sl_w_z
## [5] SHUSH_sl_w_z
##
## Root node error: 5354226604/24360 = 219796
##
## n= 24360
##
## CP nsplit rel error xerror xstd
## 1 0.3193949 0 1.00000 1.00019 2.6412e-04
## 2 0.1197615 1 0.68061 0.68509 1.7458e-04
## 3 0.0946411 2 0.56084 0.56798 1.5893e-04
## 4 0.0471670 3 0.46620 0.46988 1.2049e-04
## 5 0.0448078 4 0.41904 0.42532 1.1206e-04
## 6 0.0366181 5 0.37423 0.37586 1.0231e-04
## 7 0.0203730 6 0.33761 0.34095 9.4832e-05
## 8 0.0132699 8 0.29686 0.30137 8.6265e-05
## 9 0.0114927 9 0.28359 0.29069 8.4105e-05
## 10 0.0108773 10 0.27210 0.28349 8.2463e-05
## 11 0.0099324 11 0.26122 0.26766 7.9896e-05
## 12 0.0092734 12 0.25129 0.25180 7.4383e-05
## 13 0.0091053 13 0.24202 0.24863 7.3100e-05
## 14 0.0062852 14 0.23291 0.23520 6.9783e-05
## 15 0.0056804 15 0.22663 0.22677 6.7684e-05
## 16 0.0051358 16 0.22095 0.22210 6.6387e-05
## 17 0.0050266 17 0.21581 0.22040 6.6099e-05
## 18 0.0038006 18 0.21078 0.21496 6.4231e-05
## 19 0.0034338 19 0.20698 0.21080 6.2773e-05
## 20 0.0033187 20 0.20355 0.20993 6.2689e-05
## 21 0.0031782 21 0.20023 0.20741 6.2163e-05
## 22 0.0020000 22 0.19705 0.20331 6.1421e-05
# Plot the regression tree.
print_tree(tree_score, "Decision Tree Segmentation — HFRS_100")

# Tag every row with the leaf it fell into. Per the rpart documentation,
# `where` stores row numbers of `tree_score$frame`, so these segment IDs are
# not the node numbers printed on the tree;
# `rownames(tree_score$frame)[tree_score$where]` gives the node numbers if
# the IDs need to be matched to split rules.
df_tree <- mutate(df_tree, score_segment = factor(tree_score$where))

# Weighted profile of each regression-tree segment.
score_segment_profiles <- extract_leaf_profile(df_tree, "score_segment")

# Display table: wt_n is formatted with thousands separators for display
# only; `score_segment_profiles` itself keeps the numeric column.
score_segment_profiles %>%
  mutate(wt_n = comma(round(wt_n))) %>%
  kable(
    digits = 2,
    caption = "Regression-tree segments — weighted profiles"
  ) %>%
  kable_styling(
    bootstrap_options = c("striped", "hover", "condensed"),
    full_width = FALSE
  )
| segment | n | wt_n | HFRS_mean | pct_stressed | pct_coping | pct_comfortable | BBDI_sl_w_z | SHUSH_sl_w_z | MDI_sl_w_z | FINAS_sl_w_z | ODDI_sl_w_z | PWNETWPT_sl_w_z | PATTSTIN_z | PNBEARG_z |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 23 | 4299 | 4,598,544 | 51.94 | 6.23 | 58.11 | 35.66 | 0.26 | 0.05 | -0.49 | -0.49 | -0.15 | 0.42 | -0.37 | -0.18 |
| 10 | 2119 | 3,391,741 | 37.75 | 84.32 | 15.08 | 0.60 | -0.06 | -0.65 | -0.34 | -0.56 | 0.02 | 0.03 | 1.14 | 0.18 |
| 35 | 3131 | 2,794,452 | 58.84 | 2.28 | 22.95 | 74.77 | 0.32 | 0.25 | -0.55 | 0.29 | -0.18 | 0.15 | -0.66 | -0.20 |
| 21 | 1535 | 2,442,040 | 45.49 | 36.74 | 56.67 | 6.59 | -0.35 | -0.72 | -0.58 | -0.60 | -0.14 | -0.37 | -0.57 | -0.45 |
| 30 | 1406 | 1,746,686 | 48.03 | 18.17 | 66.82 | 15.01 | 0.03 | 0.16 | -0.52 | 0.36 | -0.03 | -0.08 | 1.14 | 0.10 |
| 37 | 1416 | 1,543,274 | 62.01 | 0.46 | 8.57 | 90.96 | 0.23 | 0.34 | -0.38 | 1.11 | -0.25 | 0.13 | -0.38 | -0.26 |
| 15 | 922 | 1,367,889 | 34.87 | 91.58 | 8.19 | 0.23 | -0.16 | -0.71 | 1.76 | -0.79 | 0.27 | 0.18 | -0.67 | 0.39 |
| 17 | 1027 | 1,294,503 | 46.42 | 28.99 | 61.41 | 9.59 | -0.43 | 1.05 | 1.40 | -0.44 | 0.36 | 0.33 | -0.70 | 0.39 |
| 24 | 1286 | 1,232,176 | 59.68 | 0.17 | 19.99 | 79.84 | 0.23 | 0.16 | -0.50 | -0.49 | -0.04 | 0.45 | -1.89 | 0.03 |
| 6 | 808 | 1,228,312 | 33.18 | 97.22 | 2.78 | 0.00 | -0.35 | 0.51 | 1.66 | -0.47 | 0.17 | 0.25 | 1.14 | 0.45 |
| 11 | 915 | 1,189,540 | 45.19 | 32.48 | 63.11 | 4.41 | -0.40 | 1.07 | -0.11 | -0.39 | 0.11 | 0.33 | 1.14 | 0.42 |
| 5 | 627 | 1,045,395 | 26.20 | 99.94 | 0.06 | 0.00 | 0.07 | -0.79 | 1.86 | -0.67 | 0.26 | 0.18 | 1.13 | 0.41 |
| 31 | 652 | 941,585 | 53.15 | 4.60 | 54.14 | 41.26 | 0.13 | 0.17 | -0.38 | 1.25 | -0.17 | -0.06 | 1.13 | -0.02 |
| 16 | 760 | 912,584 | 40.56 | 68.19 | 31.11 | 0.70 | 0.15 | -0.42 | 1.46 | -0.43 | 0.08 | 0.30 | -0.64 | 0.40 |
| 42 | 516 | 791,233 | 71.96 | 0.00 | 1.48 | 98.52 | 0.31 | 0.31 | -0.58 | 2.30 | -0.29 | 0.02 | -0.62 | -0.45 |
| 34 | 760 | 757,491 | 51.32 | 6.80 | 64.64 | 28.56 | -0.12 | 0.72 | 1.20 | 0.27 | 0.04 | 0.34 | -0.71 | 0.32 |
| 29 | 518 | 576,511 | 41.51 | 58.98 | 40.71 | 0.31 | -0.11 | 0.58 | 1.25 | 0.34 | 0.04 | 0.37 | 1.13 | 0.46 |
| 45 | 302 | 520,064 | 82.71 | 0.00 | 0.00 | 100.00 | 0.80 | 0.21 | -0.66 | 3.68 | -0.40 | 0.11 | -0.68 | -0.67 |
| 8 | 252 | 499,801 | 23.79 | 100.00 | 0.00 | 0.00 | -0.47 | -0.79 | -0.64 | -0.63 | 1.29 | -4.05 | 1.12 | -0.04 |
| 20 | 255 | 454,424 | 32.82 | 93.49 | 5.47 | 1.05 | -0.52 | -0.81 | -0.65 | -0.67 | 1.36 | -4.26 | -0.71 | -0.16 |
| 41 | 224 | 426,408 | 62.10 | 0.54 | 11.29 | 88.18 | 0.23 | 0.17 | -0.57 | 2.34 | -0.33 | 0.07 | 1.15 | -0.02 |
| 38 | 462 | 400,598 | 69.33 | 0.00 | 1.74 | 98.26 | 0.23 | 0.51 | -0.28 | 1.07 | -0.14 | 0.16 | -1.89 | 0.08 |
| 44 | 168 | 325,523 | 72.63 | 0.00 | 0.00 | 100.00 | 0.52 | 0.12 | -0.66 | 3.72 | -0.45 | 0.09 | 1.15 | -0.41 |
# Persist the regression-tree segment profiles.
readr::write_csv(
  x = score_segment_profiles,
  file = file.path(DATA_DIR, "decision_tree_score_segments.csv")
)

# Model formula: HFRS_label explained by every column listed in TREE_VARS.
label_formula <- reformulate(TREE_VARS, response = "HFRS_label")

# Survey-weighted classification tree, fitted with the same control settings
# as the regression tree.
tree_label <- rpart(
  formula = label_formula,
  data = df_tree,
  weights = PWEIGHT,
  method = "class",
  control = rpart.control(
    minsplit = 300,
    minbucket = 150,
    cp = 0.002,
    maxdepth = 5
  )
)

# Print the complexity-parameter table for the fitted tree.
printcp(tree_label)
##
## Classification tree:
## rpart(formula = label_formula, data = df_tree, weights = PWEIGHT,
## method = "class", control = rpart.control(minsplit = 300,
## minbucket = 150, cp = 0.002, maxdepth = 5))
##
## Variables actually used in tree construction:
## [1] BBDI_sl_w_z FINAS_sl_w_z MDI_sl_w_z PATTSTIN_z
## [5] PWNETWPT_sl_w_z SHUSH_sl_w_z
##
## Root node error: 19809239/24360 = 813.19
##
## n= 24360
##
## CP nsplit rel error xerror xstd
## 1 0.3106578 0 1.00000 1.00000 0.00013294
## 2 0.0645168 1 0.68934 0.69318 0.00013867
## 3 0.0609996 2 0.62483 0.63984 0.00013736
## 4 0.0446075 3 0.56383 0.55843 0.00013401
## 5 0.0361270 4 0.51922 0.54062 0.00013305
## 6 0.0337603 5 0.48309 0.48984 0.00012983
## 7 0.0284701 6 0.44933 0.44031 0.00012596
## 8 0.0171723 7 0.42086 0.43281 0.00012531
## 9 0.0119892 8 0.40369 0.41051 0.00012327
## 10 0.0110119 9 0.39170 0.40304 0.00012254
## 11 0.0092049 11 0.36968 0.39415 0.00012166
## 12 0.0081240 13 0.35127 0.36621 0.00011869
## 13 0.0071334 14 0.34314 0.36260 0.00011828
## 14 0.0059503 15 0.33601 0.35873 0.00011784
## 15 0.0040179 16 0.33006 0.34690 0.00011646
## 16 0.0025294 17 0.32604 0.34261 0.00011595
## 17 0.0022594 18 0.32351 0.34170 0.00011584
## 18 0.0020298 20 0.31899 0.34183 0.00011586
## 19 0.0020000 21 0.31696 0.34171 0.00011584
# Plot the classification tree.
print_tree(tree_label, "Decision Tree Segmentation — HFRS_label")

# Tag every row with the leaf it fell into. Per the rpart documentation,
# `where` stores row numbers of `tree_label$frame`, so these segment IDs are
# not the node numbers printed on the tree;
# `rownames(tree_label$frame)[tree_label$where]` gives the node numbers if
# the IDs need to be matched to split rules.
df_tree <- mutate(df_tree, label_segment = factor(tree_label$where))

# Weighted profile of each classification-tree segment.
label_segment_profiles <- extract_leaf_profile(df_tree, "label_segment")

# Display table: wt_n is formatted with thousands separators for display
# only; `label_segment_profiles` itself keeps the numeric column.
label_segment_profiles %>%
  mutate(wt_n = comma(round(wt_n))) %>%
  kable(
    digits = 2,
    caption = "Classification-tree segments — weighted profiles"
  ) %>%
  kable_styling(
    bootstrap_options = c("striped", "hover", "condensed"),
    full_width = FALSE
  )
| segment | n | wt_n | HFRS_mean | pct_stressed | pct_coping | pct_comfortable | BBDI_sl_w_z | SHUSH_sl_w_z | MDI_sl_w_z | FINAS_sl_w_z | ODDI_sl_w_z | PWNETWPT_sl_w_z | PATTSTIN_z | PNBEARG_z |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 43 | 5577 | 5,550,800 | 65.05 | 0.00 | 9.05 | 90.95 | 0.45 | 0.30 | -0.63 | 1.01 | -0.30 | 0.29 | -0.67 | -0.31 |
| 21 | 4538 | 5,313,340 | 49.80 | 13.23 | 63.80 | 22.97 | 0.26 | -0.50 | -0.49 | -0.59 | -0.14 | 0.27 | -0.66 | -0.16 |
| 5 | 2609 | 4,456,447 | 32.18 | 97.05 | 2.95 | 0.00 | -0.37 | -0.63 | 0.28 | -0.63 | 0.19 | -0.36 | 1.13 | 0.27 |
| 31 | 2010 | 2,450,157 | 49.79 | 10.11 | 70.45 | 19.44 | 0.10 | 0.28 | -0.47 | 0.45 | -0.13 | 0.21 | 1.14 | 0.16 |
| 15 | 1389 | 1,943,744 | 35.99 | 90.81 | 9.07 | 0.12 | -0.09 | -0.49 | 1.65 | -0.65 | 0.20 | 0.23 | -0.36 | 0.40 |
| 22 | 1443 | 1,499,993 | 57.18 | 0.00 | 31.20 | 68.80 | -0.42 | 1.31 | -0.42 | -0.46 | 0.06 | 0.38 | -0.71 | -0.11 |
| 19 | 681 | 1,221,364 | 38.91 | 76.88 | 22.13 | 0.99 | -0.53 | -0.93 | -0.65 | -0.67 | 0.29 | -2.04 | -0.61 | -0.53 |
| 32 | 587 | 1,063,095 | 63.25 | 1.42 | 15.63 | 82.95 | 0.30 | 0.12 | -0.56 | 2.54 | -0.34 | 0.01 | 1.15 | -0.16 |
| 10 | 667 | 937,336 | 37.22 | 84.47 | 15.53 | 0.00 | -0.44 | 1.02 | 1.23 | -0.47 | 0.17 | 0.26 | 1.15 | 0.58 |
| 36 | 928 | 929,176 | 49.71 | 9.55 | 72.99 | 17.47 | -0.07 | 0.40 | 1.02 | 0.13 | -0.05 | 0.37 | -0.66 | 0.44 |
| 17 | 635 | 787,511 | 47.19 | 23.53 | 66.73 | 9.74 | -0.45 | 1.28 | 1.45 | -0.49 | 0.46 | 0.33 | -0.75 | 0.35 |
| 11 | 470 | 630,230 | 46.72 | 17.72 | 75.59 | 6.69 | -0.38 | 1.06 | -0.60 | -0.47 | 0.22 | 0.12 | 1.14 | 0.22 |
| 38 | 580 | 558,586 | 59.16 | 0.41 | 23.02 | 76.57 | -0.22 | 1.00 | 0.92 | 1.00 | -0.01 | 0.43 | -0.77 | 0.33 |
| 27 | 479 | 531,617 | 38.51 | 83.42 | 16.58 | 0.00 | -0.03 | 0.12 | 1.34 | 0.26 | -0.06 | 0.34 | 1.13 | 0.48 |
| 16 | 377 | 474,499 | 43.30 | 44.27 | 52.74 | 2.99 | -0.08 | -0.45 | 1.69 | -0.65 | 0.29 | 0.26 | -1.87 | 0.38 |
| 8 | 279 | 401,834 | 44.96 | 37.51 | 59.79 | 2.70 | 1.64 | -0.71 | -0.67 | -0.53 | 0.29 | 0.22 | 1.15 | -0.32 |
| 30 | 168 | 341,780 | 38.45 | 69.65 | 28.31 | 2.05 | -0.32 | -0.62 | -0.66 | 0.41 | 0.57 | -2.03 | 1.13 | -0.23 |
| 41 | 156 | 328,433 | 47.17 | 20.58 | 71.84 | 7.58 | -0.29 | -0.73 | -0.67 | 0.18 | 0.31 | -1.60 | -0.48 | -0.61 |
| 7 | 213 | 307,738 | 31.32 | 98.15 | 1.85 | 0.00 | 1.40 | -0.74 | 1.78 | -0.57 | 0.27 | 0.38 | 1.14 | 0.32 |
| 42 | 150 | 289,166 | 56.88 | 6.78 | 37.94 | 55.27 | -0.07 | -0.57 | -0.67 | 1.41 | 0.16 | -1.64 | -0.59 | -0.65 |
| 28 | 207 | 251,269 | 45.82 | 24.82 | 71.73 | 3.45 | -0.50 | 1.75 | 1.38 | 0.40 | 0.24 | 0.38 | 1.13 | 0.36 |
| 37 | 217 | 212,662 | 56.76 | 1.31 | 30.63 | 68.06 | -0.51 | 2.17 | 1.07 | 0.16 | 0.36 | 0.33 | -0.66 | 0.16 |
# Persist the classification-tree segment profiles.
readr::write_csv(
  x = label_segment_profiles,
  file = file.path(DATA_DIR, "decision_tree_label_segments.csv")
)

# Variable importance of the classification tree, rescaled so the shares sum
# to 1 and sorted from most to least important.
importance_tbl <- tibble(
  variable = names(tree_label$variable.importance),
  importance = as.numeric(tree_label$variable.importance) /
    sum(as.numeric(tree_label$variable.importance))
) %>%
  arrange(desc(importance))

importance_tbl %>%
  kable(
    digits = 3,
    caption = "Classification tree — normalized variable importance"
  ) %>%
  kable_styling(full_width = FALSE)
| variable | importance |
|---|---|
| FINAS_sl_w_z | 0.324 |
| PATTSTIN_z | 0.237 |
| MDI_sl_w_z | 0.157 |
| SHUSH_sl_w_z | 0.136 |
| PWNETWPT_sl_w_z | 0.108 |
| BBDI_sl_w_z | 0.027 |
| PNBEARG_z | 0.006 |
| ODDI_sl_w_z | 0.005 |
# Persist the normalised variable-importance table.
readr::write_csv(
  x = importance_tbl,
  file = file.path(DATA_DIR, "decision_tree_variable_importance.csv")
)

# Emit the interpretation notes as a markdown bullet list.
cat("
- This is a **supervised segmentation** method, not unsupervised clustering.
- That means the target (`HFRS_100` or `HFRS_label`) directly shapes the groups.
- The advantage is interpretability: you get explicit split rules that can be turned into actionable segment descriptions.
- The tradeoff is that these segments should not be described as \"natural clusters\".
- In this project, decision trees are most useful when the research goal is **actionable groups** rather than latent structure.
")
That means the target (`HFRS_100` or `HFRS_label`) directly shapes the groups.
saveRDS(df_tree, file.path(DATA_DIR, "decision_tree_segmented_adf.rds"))
# Print a confirmation message naming the saved .rds file.
cat("Saved: outputs/data/decision_tree_segmented_adf.rds\n")
## Saved: outputs/data/decision_tree_segmented_adf.rds
# Print the R version, platform and package versions used for this run.
sessionInfo()
## R version 4.3.2 (2023-10-31)
## Platform: aarch64-apple-darwin20 (64-bit)
## Running under: macOS 26.2
##
## Matrix products: default
## BLAS: /Library/Frameworks/R.framework/Versions/4.3-arm64/Resources/lib/libRblas.0.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.3-arm64/Resources/lib/libRlapack.dylib; LAPACK version 3.11.0
##
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
##
## time zone: America/Toronto
## tzcode source: internal
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] scales_1.4.0 kableExtra_1.4.0 rpart_4.1.21 lubridate_1.9.3
## [5] forcats_1.0.0 stringr_1.5.1 dplyr_1.1.4 purrr_1.0.4
## [9] readr_2.1.5 tidyr_1.3.1 tibble_3.2.1 ggplot2_3.5.2
## [13] tidyverse_2.0.0 here_1.0.1
##
## loaded via a namespace (and not attached):
## [1] sass_0.4.10 utf8_1.2.4 generics_0.1.3 xml2_1.3.6
## [5] stringi_1.8.3 hms_1.1.3 digest_0.6.34 magrittr_2.0.3
## [9] evaluate_0.23 grid_4.3.2 timechange_0.3.0 RColorBrewer_1.1-3
## [13] fastmap_1.2.0 rprojroot_2.0.4 jsonlite_1.8.8 rpart.plot_3.1.2
## [17] fansi_1.0.6 viridisLite_0.4.2 textshaping_0.3.7 jquerylib_0.1.4
## [21] cli_3.6.5 crayon_1.5.2 rlang_1.1.6 bit64_4.0.5
## [25] withr_3.0.2 cachem_1.1.0 yaml_2.3.8 parallel_4.3.2
## [29] tools_4.3.2 tzdb_0.5.0 vctrs_0.6.5 R6_2.5.1
## [33] lifecycle_1.0.4 bit_4.0.5 vroom_1.6.5 pkgconfig_2.0.3
## [37] pillar_1.9.0 bslib_0.6.1 gtable_0.3.6 glue_1.8.0
## [41] systemfonts_1.2.3 highr_0.11 xfun_0.52 tidyselect_1.2.1
## [45] rstudioapi_0.15.0 knitr_1.48 farver_2.1.1 htmltools_0.5.8.1
## [49] rmarkdown_2.25 svglite_2.2.1 compiler_4.3.2