library(data.table)
library(xgboost)
library(glmnet)
library(SHAPforxgboost)
library(ggplot2)
library(gridExtra)

# Directory holding the fitted models and the train/test splits. Set the
# FREDDIE_MAC_DIR environment variable to run the script from another machine;
# when it is unset (or empty) the original hard-coded location is used.
OUTPUT_DIR <- Sys.getenv("FREDDIE_MAC_DIR")
if (!nzchar(OUTPUT_DIR)) {
  OUTPUT_DIR <- "/Users/amalianimeskern/Library/CloudStorage/OneDrive-ErasmusUniversityRotterdam/Freddie Mac Data"
}
if (!dir.exists(OUTPUT_DIR)) {
  stop("OUTPUT_DIR does not exist: ", OUTPUT_DIR, call. = FALSE)
}

# --- Load models and test data ---
xgb_model      <- readRDS(file.path(OUTPUT_DIR, "xgb_model.rds"))
logistic_model <- readRDS(file.path(OUTPUT_DIR, "logistic_model.rds"))
best_lambda    <- readRDS(file.path(OUTPUT_DIR, "best_lambda.rds"))

test_xgb <- readRDS(file.path(OUTPUT_DIR, "test_xgb.rds"))
test_woe <- readRDS(file.path(OUTPUT_DIR, "test_woe.rds"))
train_xgb <- readRDS(file.path(OUTPUT_DIR, "train_xgb.rds"))
train_woe <- readRDS(file.path(OUTPUT_DIR, "train_woe.rds"))

# --- Feature names ---
# Every column except the identifiers and the target is a model feature.
id_target_cols <- c("loan_sequence_number", "monthly_reporting_period",
                    "default_next_12m")
xgb_features <- setdiff(names(test_xgb), id_target_cols)
woe_features <- setdiff(names(test_woe), id_target_cols)


# --- Subsample for explanations ---
# A common random subsample of 10,000 test observations is used for both
# models, so their explanations are computed on exactly the same loans.
# This only holds if test_xgb and test_woe list the same observations in the
# same row order, which is checked before sampling.

X_test_xgb <- as.matrix(test_xgb[, ..xgb_features])
X_test_woe <- as.matrix(test_woe[, ..woe_features])

if (nrow(X_test_xgb) != nrow(X_test_woe)) {
  stop("test_xgb and test_woe have different row counts; ",
       "a common SHAP subsample is not possible.", call. = FALSE)
}
if ("loan_sequence_number" %in% names(test_xgb) &&
    "loan_sequence_number" %in% names(test_woe) &&
    !identical(as.character(test_xgb$loan_sequence_number),
               as.character(test_woe$loan_sequence_number))) {
  stop("test_xgb and test_woe rows are not aligned by loan_sequence_number; ",
       "the common SHAP subsample would pair different loans.", call. = FALSE)
}

# Cap the subsample size so a small test set does not make sample() error.
n_shap <- min(10000, nrow(X_test_xgb))

set.seed(42)
shap_idx   <- sample(nrow(X_test_xgb), n_shap)

X_shap_xgb <- X_test_xgb[shap_idx, , drop = FALSE]
X_shap_woe <- X_test_woe[shap_idx, , drop = FALSE]


# --- Xgboost SHAP (TreeSHAP) ---

# TreeSHAP contributions for every row of the common subsample.
shap_xgb <- shap.values(xgb_model = xgb_model, X_train = X_shap_xgb)

# Per-observation SHAP table and per-feature mean |SHAP|.
shap_values_xgb <- shap_xgb$shap_score
mean_shap_xgb   <- shap_xgb$mean_shap_score

# Global importance: features sorted by mean |SHAP|, rank 1 = most important.
importance_xgb <- data.table(feature       = names(mean_shap_xgb),
                             mean_abs_shap = as.numeric(mean_shap_xgb))
importance_xgb <- importance_xgb[order(-mean_abs_shap)]
importance_xgb[, rank := seq_len(.N)]
importance_xgb


# --- Logistic SHAP (analytical, on subsample for comparability) ---

# SHAP for a linear model (Lundberg & Lee 2017, Corollary 1):
#   phi_i = beta_i * (x_i - E[x_i])
# with E[x_i] taken as the training-set mean; values are on the log-odds scale.
coef_obj <- coef(logistic_model, s = best_lambda)
coef_vec <- as.numeric(coef_obj)
names(coef_vec) <- if (is.null(dim(coef_obj))) names(coef_obj) else rownames(coef_obj)
coefs <- coef_vec[-1]  # drop the intercept

# Align coefficients with woe_features by name when the model carries feature
# names. A column order in test_woe that differs from the training design matrix
# then cannot silently pair a coefficient with the wrong feature. Fall back to
# position only when names are unusable, and then require matching lengths.
if (!is.null(names(coefs)) && all(woe_features %in% names(coefs))) {
  coefs <- coefs[woe_features]
} else {
  if (length(coefs) != length(woe_features)) {
    stop("Logistic model has ", length(coefs), " coefficients but there are ",
         length(woe_features), " WoE features.", call. = FALSE)
  }
  names(coefs) <- woe_features
}

# Training set means for centering (the SHAP baseline E[x_i])
X_train_woe <- as.matrix(train_woe[, ..woe_features])
train_means  <- colMeans(X_train_woe, na.rm = TRUE)

# Compute SHAP on the SAME 10,000 subsample as XGBoost
shap_values_logistic <- sweep(X_shap_woe, 2, train_means, "-")
shap_values_logistic <- sweep(shap_values_logistic, 2, coefs, "*")
shap_values_logistic <- as.data.table(shap_values_logistic)

# Mean absolute SHAP for logistic (on the same subsample)
mean_shap_logistic <- sort(colMeans(abs(shap_values_logistic)),
                           decreasing = TRUE)
importance_logistic <- data.table(
  feature = names(mean_shap_logistic),
  mean_abs_shap = as.numeric(mean_shap_logistic)
)[order(-mean_abs_shap)]
importance_logistic[, rank := .I]
importance_logistic


# --- Feature Importance Rankings ---

# XGBoost sees one-hot dummies (e.g. "channel_R"). Map each dummy back to its
# base variable so the ranking is comparable with the WoE logistic model.
categorical_bases <- c("channel", "loan_purpose", "occupancy_status",
                       "property_type", "first_time_homebuyer")

# Base variable for one XGBoost feature name. If several bases match, the last
# one in categorical_bases wins; non-dummy features map to themselves.
to_base_feature <- function(f) {
  hits <- categorical_bases[startsWith(f, paste0(categorical_bases, "_"))]
  if (length(hits) > 0) hits[length(hits)] else f
}

importance_xgb_agg <- copy(importance_xgb)[
  , base_feature := vapply(feature, to_base_feature, character(1),
                           USE.NAMES = FALSE)
]

# Sum mean absolute SHAP within each base feature, then rank.
importance_xgb_agg <- importance_xgb_agg[
  , .(shap_xgb = sum(mean_abs_shap)), by = base_feature
][order(-shap_xgb)]
importance_xgb_agg[, rank_xgb := .I]
setnames(importance_xgb_agg, "base_feature", "feature")

# Logistic features carry a "_woe" suffix; strip it to get the base name.
importance_logistic_clean <- copy(importance_logistic)[
  , feature := sub("_woe$", "", feature)
]

# Side-by-side ranking, keeping features present in only one model.
xgb_side <- importance_xgb_agg[, .(feature, rank_xgb, shap_xgb)]
logistic_side <- importance_logistic_clean[, .(feature,
                                               rank_logistic = rank,
                                               shap_logistic = mean_abs_shap)]
comparison <- merge(xgb_side, logistic_side,
                    by = "feature", all = TRUE)[order(rank_xgb)]
comparison

# --- Global feature importance Table ---
library(stargazer)

# Display labels for base feature names; features missing here fall back to
# their raw name below.
feature_labels <- c(
  "credit_score"               = "Credit Score",
  "orig_interest_rate"         = "Origination Interest Rate",
  "orig_ltv"                   = "Original LTV",
  "num_borrowers"              = "Number of Borrowers",
  "orig_dti"                   = "Origination DTI",
  "hpi_qoq_qlag1"              = "HPI Growth (Lagged)",
  "current_upb"                = "Current UPB",
  "loan_purpose"               = "Loan Purpose",
  "loan_age"                   = "Loan Age",
  "unemployment_rate_lag4"     = "Unemployment Rate (Lagged)",
  "channel"                    = "Channel",
  "orig_loan_term"             = "Original Loan Term",
  "current_delinquency_status" = "Current Delinquency Status",
  "mi_pct"                     = "Mortgage Insurance %",
  "property_type"              = "Property Type",
  "occupancy_status"           = "Occupancy Status",
  "first_time_homebuyer"       = "First-Time Homebuyer",
  "delta_interest_rate"        = "Delta Interest Rate",
  "mod_flag_12m"               = "Modification Flag (12m)",
  "current_deferred_upb"       = "Current Deferred UPB"
)

table_dt <- copy(comparison)
table_dt[, feature_label := feature_labels[feature]]
table_dt[is.na(feature_label), feature_label := feature]
table_dt[, shap_xgb     := round(shap_xgb, 4)]
table_dt[, shap_logistic := round(shap_logistic, 4)]
# setorder() defaults to na.last = FALSE. That would move features used only by
# the logistic model (rank_xgb = NA) to the TOP of the table; keep them at the
# bottom, consistent with the order of `comparison`.
setorder(table_dt, rank_xgb, na.last = TRUE)

table_df <- as.data.frame(table_dt[, .(feature_label, rank_xgb, shap_xgb,
                                       rank_logistic, shap_logistic)])
colnames(table_df) <- c("Feature", "Rank (XGBoost)", "Mean |SHAP| (XGBoost)",
                        "Rank (Logistic)", "Mean |SHAP| (Logistic)")

stargazer(table_df,
          type = "html",
          summary = FALSE,
          rownames = FALSE,
          title = "Table 4.2: Global Feature Importance Rankings Based on Mean Absolute SHAP Values",
          notes = c("Mean absolute SHAP values are computed on a common random subsample",
                    "of 10,000 test observations for both models. For the XGBoost model,",
                    "one-hot encoded categorical variables are aggregated to the base",
                    "variable level by summing mean absolute SHAP values across dummy",
                    "indicators. Higher mean |SHAP| indicates a greater average contribution",
                    "to predictions on the log-odds scale."),
          notes.align = "l",
          out = file.path(OUTPUT_DIR, "table_4_2_feature_importance.html"))


# --- Plots ---

# --- Beeswarm plot (XGBoost, top 15 features) ---
# Long-format SHAP data restricted to the 15 features with the largest mean |SHAP|.
shap_long_xgb <- shap.prep(shap_contrib = shap_xgb$shap_score,
                           X_train = X_shap_xgb,
                           top_n = 15)

# Styling layers kept separate from the plot construction for readability.
beeswarm_theme <- theme(plot.title = element_text(size = 14, face = "bold"),
                        axis.text.y = element_text(size = 10),
                        legend.position = "bottom",
                        legend.direction = "horizontal",
                        legend.title = element_text(size = 10),
                        legend.text = element_text(size = 9),
                        legend.key.width = unit(1.5, "cm"),
                        legend.key.height = unit(0.3, "cm"),
                        legend.margin = margin(t = 10, b = 5))
beeswarm_guides <- guides(colour = guide_colourbar(title = "Feature value",
                                                   title.position = "top",
                                                   title.hjust = 0.5,
                                                   barwidth = 8,
                                                   barheight = 0.5))

p_beeswarm <- shap.plot.summary(shap_long_xgb) +
  ggtitle("XGBoost SHAP Beeswarm Plot") +
  theme_classic() +
  beeswarm_theme +
  beeswarm_guides

ggsave(file.path(OUTPUT_DIR, "shap_beeswarm_xgb.pdf"), p_beeswarm,
       width = 10, height = 7.5)


# --- Beeswarm plot Logistic Regression ---
# Computed on the same 10,000 subsample as XGBoost. Layout and colouring mirror
# the XGBoost beeswarm: the most important feature is on top, and points are
# coloured by the feature's WoE value, min-max scaled within each feature.

# head() rather than [1:15] so a model with fewer than 15 features does not
# produce NA feature names.
top15_log <- head(importance_logistic$feature, 15)

shap_long_log <- melt(shap_values_logistic[, ..top15_log],
                      measure.vars = top15_log,
                      variable.name = "variable",
                      value.name = "value")

# Add feature values (WoE-transformed) for colouring. Both tables are melted
# from the same columns in the same order, so their rows line up one-to-one.
feature_vals_log <- melt(as.data.table(X_shap_woe)[, ..top15_log],
                         measure.vars = top15_log,
                         variable.name = "variable",
                         value.name = "rfvalue")

shap_long_log[, rfvalue := feature_vals_log$rfvalue]

# Scale each feature's WoE values to [0, 1] so one colour bar serves every row;
# a constant feature is mapped to the midpoint 0.5.
shap_long_log[, stdfvalue := {
  rng <- range(rfvalue, na.rm = TRUE)
  span <- rng[2] - rng[1]
  if (is.finite(span) && span > 0) (rfvalue - rng[1]) / span else rep(0.5, .N)
}, by = variable]

# Order variables by mean absolute SHAP. ggplot draws the first level of a
# discrete y axis at the bottom, so the levels are reversed to put the most
# important feature at the top, matching the XGBoost beeswarm.
order_levels <- shap_long_log[, .(mean_value = mean(abs(value))),
                              by = variable][order(-mean_value)]$variable
shap_long_log[, variable := factor(variable,
                                   levels = rev(as.character(order_levels)))]

# Subsample for visual readability (from the 10,000 subsample)
set.seed(42)
shap_long_log <- shap_long_log[sample(.N, min(.N, 5000))]

p_beeswarm_log <- ggplot(shap_long_log,
                         aes(x = value, y = variable, colour = stdfvalue)) +
  geom_jitter(width = 0, height = 0.4, size = 0.7, alpha = 0.5) +
  scale_colour_gradient(low = "#FFCC33", high = "#6600CC",
                        limits = c(0, 1), breaks = c(0, 1),
                        labels = c("Low", "High"),
                        name = "Feature value (WoE)") +
  labs(title = "Logistic Regression SHAP Beeswarm Plot",
       x = "SHAP value (log-odds contribution)", y = NULL) +
  theme_classic() +
  theme(legend.position = "bottom")

ggsave(file.path(OUTPUT_DIR, "shap_beeswarm_logistic.pdf"),
       p_beeswarm_log,
       width = 10, height = 7.5)

# --- Dependence plots for XGB as grid ---

# Long-format SHAP data without a top_n filter, so any feature can be plotted.
shap_long_full <- shap.prep(shap_contrib = shap_xgb$shap_score, X_train = X_shap_xgb)

# One styled SHAP dependence panel for `feature`, titled `panel_title`.
xgb_dep_panel <- function(feature, panel_title) {
  shap.plot.dependence(shap_long_full, x = feature) +
    ggtitle(panel_title) +
    theme_classic() +
    theme(plot.title = element_text(size = 11, face = "bold"),
          axis.title = element_text(size = 9))
}

p_dep_a <- xgb_dep_panel("credit_score",       "(a) Credit Score")
p_dep_b <- xgb_dep_panel("orig_interest_rate", "(b) Origination Interest Rate")
p_dep_c <- xgb_dep_panel("loan_purpose_P",     "(c) Loan Purpose: Purchase")
p_dep_d <- xgb_dep_panel("orig_dti",           "(d) Origination DTI")

dep_grid_title <- grid::textGrob("SHAP Dependence Plots",
                                 gp = grid::gpar(fontsize = 13, fontface = "bold"))
p_dep_grid <- grid.arrange(grobs = list(p_dep_a, p_dep_b, p_dep_c, p_dep_d),
                           ncol = 2, nrow = 2,
                           top = dep_grid_title)
ggsave(file.path(OUTPUT_DIR, "shap_dependence_grid.pdf"), p_dep_grid,
       width = 14, height = 10)


# --- Logistic Dependence plots ---
# Computed on the same 10,000 subsample.

log_dep_features <- c("credit_score_woe", "orig_interest_rate_woe",
                      "loan_purpose_woe", "orig_dti_woe")
log_dep_titles <- c("(a) Credit Score (WoE)", "(b) Origination Interest Rate (WoE)",
                    "(c) Loan Purpose (WoE)", "(d) Origination DTI (WoE)")

# Scatter of a WoE feature against its analytical SHAP value, with an OLS line.
# The seed is reset per panel, so every panel thins the same rows.
build_log_dep_panel <- function(feat, panel_title) {
  panel_dt <- data.table(
    feature_value = X_shap_woe[, feat],
    shap_value    = as.numeric(shap_values_logistic[[feat]])
  )

  # Thin to at most 5,000 points; the discrete WoE levels are otherwise too dense.
  set.seed(42)
  panel_dt <- panel_dt[sample(.N, min(.N, 5000))]

  # Jitter by 3% of each axis span to separate over-plotted points.
  jitter_w <- 0.03 * diff(range(panel_dt$feature_value, na.rm = TRUE))
  jitter_h <- 0.03 * diff(range(panel_dt$shap_value, na.rm = TRUE))

  ggplot(panel_dt, aes(x = feature_value, y = shap_value)) +
    geom_jitter(width = jitter_w, height = jitter_h,
                size = 1, alpha = 0.5, color = "#4575b4") +
    geom_smooth(method = "lm", se = FALSE, color = "#d73027", linewidth = 0.7) +
    labs(title = panel_title,
         x = "Feature value (WoE)", y = "SHAP value") +
    theme_classic() +
    theme(plot.title = element_text(size = 11, face = "bold"),
          axis.title = element_text(size = 9))
}

log_dep_plots <- mapply(build_log_dep_panel, log_dep_features, log_dep_titles,
                        SIMPLIFY = FALSE, USE.NAMES = FALSE)

log_dep_grid <- grid.arrange(
  grobs = log_dep_plots,
  ncol = 2, nrow = 2,
  top = grid::textGrob("Logistic Regression SHAP Dependence Plots",
                       gp = grid::gpar(fontsize = 13, fontface = "bold"))
)

ggsave(file.path(OUTPUT_DIR, "logistic_shap_dependence_grid.pdf"),
       log_dep_grid,
       width = 14, height = 10)

# --- Local explanation plot (both models combined) ---

# Predicted PD for every subsample row; used to pick three representative loans.
dtest_shap <- xgb.DMatrix(data = X_shap_xgb,
                          label = test_xgb$default_next_12m[shap_idx])
test_preds <- predict(xgb_model, dtest_shap)

idx_high <- which.max(test_preds)                              # highest PD
idx_low  <- which.min(test_preds)                              # lowest PD
idx_med  <- which.min(abs(test_preds - median(test_preds)))    # closest to median


# Top-k features by |SHAP| for one observation, ready for a horizontal bar chart.
#   shap_vec - named numeric vector of SHAP values (names = feature names)
#   k        - number of features to keep (fewer if shap_vec is shorter)
# Returns a data.table(feature, shap_value) sorted by |shap_value| descending.
# The feature factor levels are reversed, so the largest bar is drawn at the top.
top_local_shap <- function(shap_vec, k = 10L) {
  dt <- data.table(feature = names(shap_vec), shap_value = unname(shap_vec))
  # head() rather than [1:k] so short inputs do not gain NA rows.
  dt <- head(dt[order(-abs(shap_value))], k)
  dt[, feature := factor(feature, levels = rev(feature))]
  dt[]
}

# XGBoost (rows of the subsample SHAP table)
xgb_feature_names <- names(shap_xgb$shap_score)

shap_high <- setNames(as.numeric(shap_xgb$shap_score[idx_high, ]), xgb_feature_names)
shap_low  <- setNames(as.numeric(shap_xgb$shap_score[idx_low, ]),  xgb_feature_names)
shap_med  <- setNames(as.numeric(shap_xgb$shap_score[idx_med, ]),  xgb_feature_names)

shap_high_dt <- top_local_shap(shap_high)
shap_low_dt  <- top_local_shap(shap_low)
shap_med_dt  <- top_local_shap(shap_med)

# Logistic (same within-subsample positions as XGBoost) -----------------------
# Index directly with idx_high/med/low (not shap_idx[idx_*]) because
# shap_values_logistic lives on the subsample.
shap_high_log <- setNames(as.numeric(shap_values_logistic[idx_high, ]), woe_features)
shap_low_log  <- setNames(as.numeric(shap_values_logistic[idx_low, ]),  woe_features)
shap_med_log  <- setNames(as.numeric(shap_values_logistic[idx_med, ]),  woe_features)

shap_high_log_dt <- top_local_shap(shap_high_log)
shap_low_log_dt  <- top_local_shap(shap_low_log)
shap_med_log_dt  <- top_local_shap(shap_med_log)

# Common x-axis limits across ALL six panels. Zero is always inside the range:
# geom_col() bars are anchored at 0, and xlim() censors out-of-limit values,
# so a bar whose base lies outside the limits would be silently dropped.
x_range_all <- range(0,
                     shap_high_dt$shap_value, shap_low_dt$shap_value,
                     shap_med_dt$shap_value,
                     shap_high_log_dt$shap_value, shap_low_log_dt$shap_value,
                     shap_med_log_dt$shap_value)
x_lim_all <- c(min(x_range_all) - 0.15, max(x_range_all) + 0.15)

# Shared theme and colour scale
waterfall_theme <- theme_classic(base_size = 11) +
  theme(plot.title = element_text(size = 11, face = "bold", hjust = 0.5),
        axis.text.y = element_text(size = 9),
        axis.text.x = element_text(size = 9),
        axis.title.x = element_text(size = 10),
        legend.position = "none",
        plot.margin = margin(t = 5, r = 10, b = 5, l = 5))

fill_scale <- scale_fill_manual(
  values = c("TRUE" = "#D32F2F", "FALSE" = "#1976D2"),
  labels = c("TRUE" = "Increases risk", "FALSE" = "Decreases risk")
)

# One local-explanation bar panel sharing the common x limits, fill and theme.
local_panel <- function(panel_dt, panel_title) {
  ggplot(panel_dt, aes(x = shap_value, y = feature, fill = shap_value > 0)) +
    geom_col(width = 0.7) + fill_scale +
    labs(title = panel_title, x = "SHAP Value", y = NULL) +
    xlim(x_lim_all) + waterfall_theme
}

p_a <- local_panel(shap_high_dt,     "(a) XGBoost \u2014 Highest PD")
p_b <- local_panel(shap_med_dt,      "(b) XGBoost \u2014 Median PD")
p_c <- local_panel(shap_low_dt,      "(c) XGBoost \u2014 Lowest PD")
p_d <- local_panel(shap_high_log_dt, "(d) Logistic Regression \u2014 Highest PD")
p_e <- local_panel(shap_med_log_dt,  "(e) Logistic Regression \u2014 Median PD")
p_f <- local_panel(shap_low_log_dt,  "(f) Logistic Regression \u2014 Lowest PD")

# Extract shared legend from one panel
p_legend_source <- ggplot(shap_high_dt, aes(x = shap_value, y = feature, fill = shap_value > 0)) +
  geom_col() + fill_scale +
  theme(legend.position = "bottom",
        legend.direction = "horizontal",
        legend.title = element_blank(),
        legend.text = element_text(size = 10),
        legend.key.size = unit(0.5, "cm"))

# Extract the legend grob. ggplot2 < 3.5 names the layout cell "guide-box";
# ggplot2 >= 3.5 uses "guide-box-<position>" and fills the unused positions with
# zeroGrob placeholders, so match by prefix and keep the first real grob.
g <- ggplotGrob(p_legend_source)
legend_cells <- g$grobs[grepl("^guide-box", g$layout$name)]
legend_cells <- Filter(function(x) !inherits(x, "zeroGrob"), legend_cells)
if (length(legend_cells) == 0L) {
  stop("Could not locate the legend in p_legend_source.", call. = FALSE)
}
legend_grob <- legend_cells[[1]]

p_combined <- grid.arrange(
  arrangeGrob(p_a, p_b, p_c, p_d, p_e, p_f, ncol = 3, nrow = 2),
  legend_grob,
  nrow = 2, heights = c(10, 1),
  top = grid::textGrob("SHAP Local Explanations",
                       gp = grid::gpar(fontsize = 14, fontface = "bold"))
)

ggsave(file.path(OUTPUT_DIR, "shap_local_explanations_combined.pdf"), p_combined,
       width = 18, height = 10)

# ---SHAP interaction for the three local explanation observations featured---
# TreeSHAP interaction values for the highest/median/lowest-PD loans, and the
# five strongest feature pairs for each.

n_feat <- length(xgb_features)

X_three <- X_shap_xgb[c(idx_high, idx_med, idx_low), , drop = FALSE]
shap_int_three <- predict(xgb_model,
                          newdata = X_three,
                          predinteraction = TRUE)

# Strip bias row/column (the trailing slot after the n_feat features);
# drop = FALSE keeps the (observation, feature, feature) array shape.
shap_int_three <- shap_int_three[, seq_len(n_feat), seq_len(n_feat), drop = FALSE]

# Top-k off-diagonal interaction pairs for a single observation.
#   mat   - feature x feature interaction matrix for one observation
#   label - profile name stored in the `profile` column
#   k     - number of pairs to keep (fewer if there are fewer pairs)
# Returns data.table(profile, pair, interaction) sorted by |interaction|.
# Each unordered pair appears once (first occurrence kept).
top_interactions_obs <- function(mat, label, k = 5) {
  diag(mat) <- NA  # exclude main effects
  dt <- as.data.table(as.table(mat))
  setnames(dt, c("feature_1", "feature_2", "interaction"))
  dt <- dt[!is.na(interaction)]
  dt[, pair := paste(pmin(as.character(feature_1), as.character(feature_2)),
                     pmax(as.character(feature_1), as.character(feature_2)),
                     sep = " x ")]
  dt <- unique(dt, by = "pair")
  # head() rather than [1:k]: fewer than k pairs must not add NA rows.
  dt <- head(dt[order(-abs(interaction))], k)
  dt[, profile := label]
  dt[, .(profile, pair, interaction)]
}

top_high <- top_interactions_obs(shap_int_three[1, , ], "Highest PD")
top_med  <- top_interactions_obs(shap_int_three[2, , ], "Median PD")
top_low  <- top_interactions_obs(shap_int_three[3, , ], "Lowest PD")

top_three <- rbindlist(list(top_high, top_med, top_low))
top_three

# --- Table for top interactions ---

# Display copy: underscores in feature names become spaces, values to 4 dp.
top_three_table <- copy(top_three)[, `:=`(
  pair        = gsub("_", " ", pair, fixed = TRUE),
  interaction = round(interaction, 4)
)]

interaction_df <- as.data.frame(top_three_table[, .(profile, pair, interaction)])
stargazer(interaction_df,
          type = "html", summary = FALSE, rownames = FALSE,
          title = "Top SHAP Interaction Values for the Three Local Explanation Observations",
          out = file.path(OUTPUT_DIR, "appendix_shap_interactions_local.html"))

# Subsample validation XGB
# Re-run TreeSHAP on a second, independently drawn subsample (seed 99) and
# compare its global importance ranking with the main one via Spearman's rho.
set.seed(99)
shap_idx2   <- sample(nrow(X_test_xgb), min(10000, nrow(X_test_xgb)))
X_shap_xgb2 <- X_test_xgb[shap_idx2, , drop = FALSE]

shap_xgb2 <- shap.values(xgb_model = xgb_model, X_train = X_shap_xgb2)

mean_shap2 <- shap_xgb2$mean_shap_score
importance2 <- data.table(
  feature = names(mean_shap2),
  mean_abs_shap = as.numeric(mean_shap2)
)[order(-mean_abs_shap)]
importance2[, rank2 := .I]

rank_check <- merge(
  importance_xgb[, .(feature, rank1 = rank)],
  importance2[, .(feature, rank2)],
  by = "feature"
)

spearman_rho <- cor(rank_check$rank1, rank_check$rank2, method = "spearman")
# Print the result so it shows up in the script output.
spearman_rho

# Subsample validation logistic
# Recompute the analytical logistic SHAP values on the second subsample
# (shap_idx2) and compare the resulting ranking with the main one.
X_shap_woe2 <- X_test_woe[shap_idx2, ]

centered_woe2 <- sweep(X_shap_woe2, 2, train_means, "-")
shap_values_logistic2 <- as.data.table(sweep(centered_woe2, 2, coefs, "*"))

mean_shap_log2 <- sort(colMeans(abs(shap_values_logistic2)), decreasing = TRUE)
importance_logistic2 <- data.table(feature       = names(mean_shap_log2),
                                   mean_abs_shap = as.numeric(mean_shap_log2))
importance_logistic2 <- importance_logistic2[order(-mean_abs_shap)]
importance_logistic2[, rank2 := seq_len(.N)]

ranks_main <- importance_logistic[, .(feature, rank1 = rank)]
ranks_alt  <- importance_logistic2[, .(feature, rank2)]
rank_check_log <- merge(ranks_main, ranks_alt, by = "feature")

spearman_rho_log <- cor(rank_check_log$rank1, rank_check_log$rank2,
                        method = "spearman")
spearman_rho_log

# --- Save ---
# Output file name -> object, written to OUTPUT_DIR in this order.
rds_outputs <- list(
  "shap_xgb.rds"              = shap_values_xgb,
  "shap_logistic.rds"         = shap_values_logistic,
  "importance_xgb.rds"        = importance_xgb,
  "importance_logistic.rds"   = importance_logistic,
  "importance_comparison.rds" = comparison
)
for (rds_name in names(rds_outputs)) {
  saveRDS(rds_outputs[[rds_name]], file.path(OUTPUT_DIR, rds_name))
}

# Exploratory XGBoost dependence plot for loan age (displayed, not saved).
p_dep_loan_age <- shap.plot.dependence(shap_long_full, x = "loan_age") +
  ggtitle("Loan Age") +
  theme_classic()
print(p_dep_loan_age)

# --- Check Feature Order: Local Explanations ---

# Feature names of one local-explanation table, re-sorted by |SHAP| (largest
# first) independently of the factor levels used for plotting.
get_order <- function(dt) {
  sorted_rows <- order(-abs(dt$shap_value))
  as.character(dt$feature)[sorted_rows]
}

local_order_inputs <- list(
  xgb_high = shap_high_dt,
  xgb_med  = shap_med_dt,
  xgb_low  = shap_low_dt,
  log_high = shap_high_log_dt,
  log_med  = shap_med_log_dt,
  log_low  = shap_low_log_dt
)
order_table <- as.data.table(c(list(rank = 1:10),
                               lapply(local_order_inputs, get_order)))

order_table