# 1 ranger, vivid

## https://cran.r-project.org/web/packages/nestedcv/vignettes/nestedcv_shap.html
library(nestedcv)
library(mlbench)  # Boston housing dataset

# Boston housing data (BostonHousing2, loaded via mlbench)
data(BostonHousing2)
dat <- BostonHousing2
# Response: the `cmedv` column
y <- dat[["cmedv"]]
# Predictors: every column except `cmedv`, `medv`, `town` and `chas`
x <- dat[, setdiff(names(dat), c("cmedv", "medv", "town", "chas")),
         drop = FALSE]

# Nested CV fit of a glmnet model.
# Seeding selects the "L'Ecuyer-CMRG" generator; two cores are requested
# through `cv.cores`.
set.seed(seed = 1, kind = "L'Ecuyer-CMRG")
fit <- nestcv.glmnet(
  y, x,
  family = "gaussian",
  min_1se = 1,
  alphaSet = 1,
  cv.cores = 2
)

# Variable stability table; the bare name auto-prints it at the console
vs <- var_stability(fit)
vs
##                 mean           sd         sem frequency sign direction
## lon     48.393372145 13.359312172 4.027984176        11   -1  Negative
## rm      35.277524448 12.364884451 3.728152936        11    1  Positive
## ptratio  5.660722405  1.767597466 0.532950689        11   -1  Negative
## lstat    4.460941717  1.651898823 0.498066235        11   -1  Negative
## nox      3.632481722  5.345960957 1.611867876         4   -1  Negative
## dis      1.266762305  1.029387844 0.310372113         8   -1  Negative
## lat      1.135429569  3.084782762 0.930096998         2    1  Positive
## crim     0.120941548  0.085181803 0.025683280         9   -1  Negative
## b        0.044978058  0.009116406 0.002748700        11    1  Positive
## tax      0.005062703  0.005459025 0.001645958         8   -1  Negative
## zn       0.001783381  0.005914805 0.001783381         1    1  Positive
# Variable-stability plot of the fitted model with default settings
plot_var_stability(fit)

# Overlay directionality using colour
p1 <- plot_var_stability(fit, final = FALSE, direction = 1)

# Show directionality with the sign of the variable importance
# NOTE(review): this call only changes `percent`; confirm against the
# plot_var_stability() documentation whether a `direction` value is also
# needed to get signed importance.
# `FALSE` is spelled out because `F` is an ordinary variable that can be
# reassigned, whereas `FALSE` is a reserved constant.
p2 <- plot_var_stability(fit, final = FALSE, percent = FALSE)

# Arrange the two plots side by side
ggpubr::ggarrange(p1, p2, ncol = 2)

library(ggplot2)
# Change bubble colour scheme
p1 + scale_fill_manual(values = c("orange", "green3"))

library(fastshap)
# SHAP values for the fitted model via fastshap::explain.
# nsim = 5 keeps this quick; higher values of nsim are recommended.
sh <- explain(
  fit,
  X = x,
  pred_wrapper = pred_nestcv_glmnet,
  nsim = 5
)

# Overall variable importance
p1 <- autoplot(sh)

# Importance bars with the main direction overlaid
p2 <- plot_shap_bar(sh, x)
## Variables with mean(|SHAP|)=0: tract, lat, zn, indus, age, rad

# Both importance plots side by side
ggpubr::ggarrange(p1, p2, ncol = 2)

# Beeswarm plot of the SHAP values
plot_shap_beeswarm(sh, x, size = 1)

# Nested CV of a "gbm" model through nestcv.train.
# Restricted to 3 outer folds to speed up the process.
fit <- nestcv.train(
  y, x,
  method = "gbm",
  n_outer_folds = 3,
  cv.cores = 2
)
## Loading required package: lattice

# nsim = 5 keeps this quick; higher values of nsim are recommended.
sh <- explain(
  fit,
  X = x,
  pred_wrapper = pred_train,
  nsim = 5
)
plot_shap_beeswarm(sh, x, size = 1)

# Switch to the iris data: `Species` is the response, the first four
# columns are the predictors.
data("iris")
dat <- iris
y <- dat[["Species"]]
x <- dat[, seq_len(4L), drop = FALSE]

# Multinomial glmnet with nested CV.
# Restricted to 3 outer folds to speed up the process.
fit <- nestcv.glmnet(
  y, x,
  family = "multinomial",
  n_outer_folds = 3,
  alphaSet = 0.6
)


# SHAP values for each of the 3 classes, one prediction wrapper per class
sh1 <- explain(fit, X = x, nsim = 5,
               pred_wrapper = pred_nestcv_glmnet_class1)
sh2 <- explain(fit, X = x, nsim = 5,
               pred_wrapper = pred_nestcv_glmnet_class2)
sh3 <- explain(fit, X = x, nsim = 5,
               pred_wrapper = pred_nestcv_glmnet_class3)

# One unsorted SHAP bar plot per class, each titled with its class name
s1 <- plot_shap_bar(sh1, x, sort = FALSE) + ggtitle("Setosa")
s2 <- plot_shap_bar(sh2, x, sort = FALSE) + ggtitle("Versicolor")
s3 <- plot_shap_bar(sh3, x, sort = FALSE) + ggtitle("Virginica")

# Three panels in one row with a single shared legend at the bottom
ggpubr::ggarrange(
  s1, s2, s3,
  ncol = 3,
  legend = "bottom",
  common.legend = TRUE
)

# One unsorted SHAP beeswarm plot per class, each titled with its class name
s1 <- plot_shap_beeswarm(sh1, x, sort = FALSE, cex = 0.7) + ggtitle("Setosa")
s2 <- plot_shap_beeswarm(sh2, x, sort = FALSE, cex = 0.7) +
  ggtitle("Versicolor")
s3 <- plot_shap_beeswarm(sh3, x, sort = FALSE, cex = 0.7) +
  ggtitle("Virginica")

# Three panels in one row with a single shared legend on the right
ggpubr::ggarrange(
  s1, s2, s3,
  ncol = 3,
  legend = "right",
  common.legend = TRUE
)