Objective

This notebook identifies the key drivers of a target variable in a dataset by building exploratory models: a decision tree and a random forest.

Load data

# Load the Titanic training data (the message below is readr's column spec).
# clean_names() turns the CSV's CamelCase headers (PassengerId, Survived, ...)
# into the snake_case names the rest of the notebook relies on
# (passenger_id, survived, ...); already-clean names are left unchanged.
df_raw <- read_csv("~/projects/data/datasets/titanic/train.csv") %>%
  janitor::clean_names()
## Parsed with column specification:
## cols(
##   PassengerId = col_double(),
##   Survived = col_double(),
##   Pclass = col_double(),
##   Name = col_character(),
##   Sex = col_character(),
##   Age = col_double(),
##   SibSp = col_double(),
##   Parch = col_double(),
##   Ticket = col_character(),
##   Fare = col_double(),
##   Cabin = col_character(),
##   Embarked = col_character()
## )

Explore data

## ── Data Summary ────────────────────────
##                            Values
## Name                       df_raw
## Number of rows             891   
## Number of columns          12    
## _______________________          
## Column type frequency:           
##   character                5     
##   numeric                  7     
## ________________________         
## Group variables            None  
## 
## ── Variable type: character ────────────────────────────────────────────────────
## # A tibble: 5 × 8
##   skim_variable n_missing complete_rate   min   max empty n_unique whitespace
## * <chr>             <int>         <dbl> <int> <int> <int>    <int>      <int>
## 1 name                  0         1        12    82     0      891          0
## 2 sex                   0         1         4     6     0        2          0
## 3 ticket                0         1         3    18     0      681          0
## 4 cabin               687         0.229     1    15     0      147          0
## 5 embarked              2         0.998     1     1     0        3          0
## 
## ── Variable type: numeric ──────────────────────────────────────────────────────
## # A tibble: 7 × 11
##   skim_variable n_missing complete_rate    mean      sd    p0    p25   p50   p75
## * <chr>             <int>         <dbl>   <dbl>   <dbl> <dbl>  <dbl> <dbl> <dbl>
## 1 passenger_id          0         1     446     257.     1    224.   446    668.
## 2 survived              0         1       0.384   0.487  0      0      0      1 
## 3 pclass                0         1       2.31    0.836  1      2      3      3 
## 4 age                 177         0.801  29.7    14.5    0.42  20.1   28     38 
## 5 sib_sp                0         1       0.523   1.10   0      0      0      1 
## 6 parch                 0         1       0.382   0.806  0      0      0      0 
## 7 fare                  0         1      32.2    49.7    0      7.91  14.5   31 
##    p100 hist 
## * <dbl> <chr>
## 1  891  ▇▇▇▇▇
## 2    1  ▇▁▁▁▅
## 3    3  ▃▁▃▁▇
## 4   80  ▂▇▅▂▁
## 5    8  ▇▁▁▁▁
## 6    6  ▇▁▁▁▁
## 7  512. ▇▁▁▁▁

Target variable/label

# Name of the label column that every later step models.
target <- "survived"

# Class balance of the label: row count per level and its share of all rows.
df_raw %>%
  count(!!sym(target)) %>%
  mutate(prop = n / sum(n))
## # A tibble: 2 × 3
##   survived     n  prop
##      <dbl> <int> <dbl>
## 1        0   549 0.616
## 2        1   342 0.384

Preprocess raw data

# Convert every character column to a factor so mlr treats it as categorical.
df_model <- df_raw %>% mutate(across(where(is.character), as.factor))

# Remove constant (single-valued) columns.
df_model <- mlr::removeConstantFeatures(df_model)

# Keep all non-factor columns plus factors with 2-10 levels. Factors with more
# levels (name, ticket and cabin have 891/681/147 unique values) behave like
# identifiers and are dropped.
is_fct <- vapply(df_model, is.factor, logical(1))
n_lvls <- vapply(df_model, nlevels, integer(1))
df_model <- df_model[, !is_fct | between(n_lvls, 2, 10), drop = FALSE]

# Drop identifier columns. Logical indexing is used instead of -which():
# when none of the names is present, -which() yields -integer(0), which
# silently selects zero columns.
feat_drop <- c("passenger_id")
df_model <- df_model[, !names(df_model) %in% feat_drop, drop = FALSE]

# Impute missing values by column class (mode for factors, median for
# integers, mean for numerics). A numeric 0/1 indicator column is added for
# each imputed column; these later appear as e.g. age_dummy, embarked_dummy.
imp <- mlr::impute(df_model,
                   classes = list(factor = mlr::imputeMode(),
                                  integer = mlr::imputeMedian(),
                                  numeric = mlr::imputeMean()),
                   dummy.classes = c("integer", "factor", "numeric"),
                   dummy.type = "numeric")
df_model <- imp$data

# One-hot encode factors against their reference (first) level, leaving the
# target column untouched.
df_model <- mlr::createDummyFeatures(df_model, target = target,
                                     method = "reference")

# Normalise the generated column names to snake_case.
df_model <- df_model %>% janitor::clean_names()

Build an exploratory model to understand the relationships

Check correlation

# Pairwise Pearson correlations between all (numeric) model columns.
res <- cor(df_model)
# p-values for every pair. Without p.mat, corrplot's sig.level has no effect.
res_p <- corrplot::cor.mtest(df_model, conf.level = 0.95)$p

# Diverging red-white-blue palette, actually passed to corrplot via `col`.
col <- colorRampPalette(c("#BB4444", "#EE9988", "#FFFFFF", "#77AADD", "#4477AA"))
# Upper triangle with printed coefficients; pairs with p > 0.05 are blanked.
corrplot::corrplot(res,
                   type = "upper",
                   method = "number",
                   col = col(200),
                   p.mat = res_p,
                   sig.level = 0.05,
                   insig = "blank")

Build a simple decision tree

# Fit a fully grown classification tree (cp = 0, minsplit = 10) on all
# predictors. reformulate() builds `survived ~ .` from the target name, so the
# response is resolved inside `data` rather than passed as an external vector.
tree <- rpart::rpart(reformulate(".", response = target),
                     data = df_model,
                     method = "class",
                     model = TRUE,
                     control = rpart::rpart.control(minsplit = 10, cp = 0))

# Prune back to the complexity parameter stored in row `depth` of the CP
# table. CP-table rows are successively larger subtrees (see its nsplit
# column), not literal tree depths. min() avoids an out-of-bounds index when
# the grown tree has fewer CP rows than requested.
depth <- 4
cp_row <- min(depth, nrow(tree$cptable))
tree_prune <- rpart::prune(tree, cp = tree$cptable[cp_row, "CP"])
rpart.plot::rpart.plot(tree_prune,
                       type = 4,
                       clip.right.labs = FALSE,
                       branch = .3,
                       under = TRUE,
                       under.cex = 1.1,
                       tweak = 1.3)

# visNetwork::visTree(tree_prune,fallenLeaves=T,height="1000px",width = "200%",
#                         export = F,legendPosition = "right",
#                         highlightNearest = list(enabled = TRUE, degree = list(from = 50000, to = 0), 
#                                                 hover = TRUE, algorithm = "hierarchical"))

Variable Importance

 # Variable-importance table for the full (unpruned) tree, one row per
# variable. An explicit two-column tibble keeps the header, the data and the
# rounded column aligned; the former one-column data frame plus row names
# made `columns = 1:2` refer to a column that did not exist.
tree_imp <- tree$variable.importance
DT::datatable(
  tibble::tibble(variable = names(tree_imp), importance = unname(tree_imp)),
  caption = "Table: Variable importance summary",
  colnames = c("Variable name", "Importance"),
  rownames = FALSE,
  selection = "single",
  editable = FALSE
) %>%
  DT::formatRound(columns = 2, digits = 3)

Build a random forest model

# Fix the RNG so the forest (and its OOB error estimate) is reproducible.
set.seed(42)
# survived is stored as numeric 0/1, so as.factor() makes this a
# classification forest. importance = TRUE computes the permutation
# (MeanDecreaseAccuracy) importance used further down.
rf <- randomForest(as.factor(survived) ~ ., data = df_model,
                   ntree = 500, importance = TRUE)
  • Diagnostics
# Forest summary: OOB error estimate and confusion matrix.
print(rf)
## 
## Call:
##  randomForest(formula = as.factor(survived) ~ ., data = df_model,      ntree = 500, importance = T) 
##                Type of random forest: classification
##                      Number of trees: 500
## No. of variables tried at each split: 3
## 
##         OOB estimate of  error rate: 16.61%
## Confusion matrix:
##     0   1 class.error
## 0 503  46  0.08378871
## 1 102 240  0.29824561
plot(rf)

# Permutation importance (type = 1, MeanDecreaseAccuracy) as a tidy table,
# sorted descending, with each variable's share of the column total.
importance(rf, type = 1) %>%
  as_tibble(rownames = "var") %>%
  arrange(desc(MeanDecreaseAccuracy)) %>%
  mutate(imp_pct = round(MeanDecreaseAccuracy / sum(MeanDecreaseAccuracy),
                         digits = 3))
## # A tibble: 10 × 3
##    var            MeanDecreaseAccuracy imp_pct
##    <chr>                         <dbl>   <dbl>
##  1 male                          89.9    0.319
##  2 pclass                        44.4    0.158
##  3 fare                          34.3    0.122
##  4 age                           32.6    0.116
##  5 sib_sp                        19.8    0.07 
##  6 embarked_s                    16.7    0.059
##  7 parch                         15.2    0.054
##  8 age_dummy                     15.1    0.054
##  9 embarked_q                    12.2    0.043
## 10 embarked_dummy                 1.42   0.005
varImpPlot(rf, type = 1)

Random forest partial dependence plots (PDPs)

# Order predictors by permutation importance (MeanDecreaseAccuracy), the same
# measure ranked in the importance table. Column 1 of the default
# importance(rf) matrix is the class-"0" accuracy drop, not the overall
# measure. A separate name avoids overwriting the imputation object `imp`.
rf_imp <- importance(rf, type = 1)
impvar <- rownames(rf_imp)[order(rf_imp[, "MeanDecreaseAccuracy"],
                                 decreasing = TRUE)]
# One partial-dependence plot per predictor, for class "1" (survived).
# NOTE: impvar[i] is passed as an expression on purpose: partialPlot()
# deparses a bare symbol argument instead of evaluating it.
for (i in seq_along(impvar)) {
  randomForest::partialPlot(rf, df_model, impvar[i],
                            which.class = "1",
                            xlab = impvar[i],
                            main = paste("Partial Dependence on", impvar[i]))
}

PDP on probability

library(pdp)
## 
## Attaching package: 'pdp'
## The following object is masked from 'package:purrr':
## 
##     partial
# Predictor names: every model column except the target.
features <- setdiff(colnames(df_model), target)
pdps <- setNames(vector("list", length(features)), features)
for (feature in features) {
  # Partial dependence on the predicted probability (prob = TRUE) of the
  # second class column, i.e. survived = "1". pdp:: is spelled out because
  # pdp::partial and purrr::partial mask each other depending on load order.
  pdps[[feature]] <- pdp::partial(rf,
                                  pred.var = feature,
                                  which.class = 2,
                                  train = df_model,
                                  prob = TRUE,
                                  type = "classification",
                                  plot = TRUE,
                                  rug = TRUE,
                                  alpha = 0.8,
                                  plot.engine = "ggplot2")
  print(pdps[[feature]])
}
## Warning: Use of `object[[1L]]` is discouraged. Use `.data[[1L]]` instead.
## Warning: Use of `object[["yhat"]]` is discouraged. Use `.data[["yhat"]]`
## instead.
## Warning: Use of `x.rug[[1L]]` is discouraged. Use `.data[[1L]]` instead.

## Warning: Use of `object[[1L]]` is discouraged. Use `.data[[1L]]` instead.
## Warning: Use of `object[["yhat"]]` is discouraged. Use `.data[["yhat"]]`
## instead.
## Warning: Use of `x.rug[[1L]]` is discouraged. Use `.data[[1L]]` instead.

## Warning: Use of `object[[1L]]` is discouraged. Use `.data[[1L]]` instead.
## Warning: Use of `object[["yhat"]]` is discouraged. Use `.data[["yhat"]]`
## instead.
## Warning: Use of `x.rug[[1L]]` is discouraged. Use `.data[[1L]]` instead.

## Warning: Use of `object[[1L]]` is discouraged. Use `.data[[1L]]` instead.
## Warning: Use of `object[["yhat"]]` is discouraged. Use `.data[["yhat"]]`
## instead.
## Warning: Use of `x.rug[[1L]]` is discouraged. Use `.data[[1L]]` instead.

## Warning: Use of `object[[1L]]` is discouraged. Use `.data[[1L]]` instead.
## Warning: Use of `object[["yhat"]]` is discouraged. Use `.data[["yhat"]]`
## instead.
## Warning: Use of `x.rug[[1L]]` is discouraged. Use `.data[[1L]]` instead.

## Warning: Use of `object[[1L]]` is discouraged. Use `.data[[1L]]` instead.
## Warning: Use of `object[["yhat"]]` is discouraged. Use `.data[["yhat"]]`
## instead.
## Warning: Use of `x.rug[[1L]]` is discouraged. Use `.data[[1L]]` instead.

## Warning: Use of `object[[1L]]` is discouraged. Use `.data[[1L]]` instead.
## Warning: Use of `object[["yhat"]]` is discouraged. Use `.data[["yhat"]]`
## instead.
## Warning: Use of `x.rug[[1L]]` is discouraged. Use `.data[[1L]]` instead.

## Warning: Use of `object[[1L]]` is discouraged. Use `.data[[1L]]` instead.
## Warning: Use of `object[["yhat"]]` is discouraged. Use `.data[["yhat"]]`
## instead.
## Warning: Use of `x.rug[[1L]]` is discouraged. Use `.data[[1L]]` instead.

## Warning: Use of `object[[1L]]` is discouraged. Use `.data[[1L]]` instead.
## Warning: Use of `object[["yhat"]]` is discouraged. Use `.data[["yhat"]]`
## instead.
## Warning: Use of `x.rug[[1L]]` is discouraged. Use `.data[[1L]]` instead.

## Warning: Use of `object[[1L]]` is discouraged. Use `.data[[1L]]` instead.
## Warning: Use of `object[["yhat"]]` is discouraged. Use `.data[["yhat"]]`
## instead.
## Warning: Use of `x.rug[[1L]]` is discouraged. Use `.data[[1L]]` instead.

#grid.arrange(grobs = pdps, ncol = 1)

Use Boruta to assess variable importance

How does it work? Below is a step-by-step description of the Boruta algorithm:

  • First, it adds randomness to the data set by creating shuffled copies of all features (called shadow features).

  • Then, it trains a random forest on the extended data set and applies a feature importance measure (by default, a Z score derived from permutation importance, i.e. Mean Decrease Accuracy) to evaluate each feature; higher means more important.

  • At every iteration, it checks whether each real feature has a higher importance than the best shadow feature (i.e. whether the feature's Z score is higher than the maximum Z score among the shadow features), and progressively removes features that are deemed clearly unimportant.

  • Finally, the algorithm stops either when every feature has been confirmed or rejected, or when it reaches a specified limit of random forest runs.

# Fix the RNG so Boruta's shadow-feature shuffles are reproducible.
set.seed(123)

# Run Boruta on all model columns against the target.
# NOTE(review): `survived` is numeric here, so Boruta presumably fits a
# regression forest, unlike the classification `rf` above; consider
# `as.factor(survived) ~ .` if classification importance is intended.
boruta.train <- Boruta(survived ~ ., data = df_model, doTrace = 0)
print(boruta.train)
## Boruta performed 27 iterations in 4.396305 secs.
##  9 attributes confirmed important: age, age_dummy, embarked_q,
## embarked_s, fare and 4 more;
##  1 attributes confirmed unimportant: embarked_dummy;

# Importance box plot. The default x-axis labels are suppressed and redrawn
# below, rotated (las = 2) and shrunk so long names fit.
plot(boruta.train, xlab = "", xaxt = "n")
# Finite importance values per attribute (non-finite entries, e.g. -Inf
# recorded after rejection, are dropped before taking medians).
lz <- lapply(seq_len(ncol(boruta.train$ImpHistory)), function(i) {
  boruta.train$ImpHistory[is.finite(boruta.train$ImpHistory[, i]), i]
})
names(lz) <- colnames(boruta.train$ImpHistory)
# Labels in ascending median order. Assumes plot.Boruta draws the boxes in
# that order (its default sorting); confirm if the plot options change.
Labels <- sort(vapply(lz, median, numeric(1)))
axis(side = 1, las = 2, labels = names(Labels),
     at = seq_len(ncol(boruta.train$ImpHistory)), cex.axis = 0.7)

# Resolve any Tentative attributes; here there are none, so the original
# object is returned (see the warning below).
final.boruta <- TentativeRoughFix(boruta.train)
## Warning in TentativeRoughFix(boruta.train): There are no Tentative attributes!
## Returning original object.
print(final.boruta)
## Boruta performed 27 iterations in 4.396305 secs.
##  9 attributes confirmed important: age, age_dummy, embarked_q,
## embarked_s, fare and 4 more;
##  1 attributes confirmed unimportant: embarked_dummy;
getSelectedAttributes(final.boruta, withTentative = FALSE)
## [1] "pclass"     "age"        "sib_sp"     "parch"      "fare"      
## [6] "age_dummy"  "male"       "embarked_q" "embarked_s"
# Per-attribute importance statistics and decisions, most important first.
boruta.df <- attStats(final.boruta)
print(boruta.df %>% arrange(desc(meanImp)))
##                  meanImp medianImp    minImp    maxImp  normHits  decision
## male           77.934124 77.915829 71.782546 82.622922 1.0000000 Confirmed
## pclass         34.308028 34.070058 31.164675 38.328537 1.0000000 Confirmed
## fare           28.250358 28.143484 24.998817 31.075931 1.0000000 Confirmed
## age            23.718310 23.831899 21.119384 26.430083 1.0000000 Confirmed
## sib_sp         16.758912 16.755684 14.910549 19.652687 1.0000000 Confirmed
## parch          13.714816 13.714328 12.094403 15.582184 1.0000000 Confirmed
## embarked_s      9.818906  9.779211  6.717029 12.418427 1.0000000 Confirmed
## embarked_q      6.873162  6.691482  4.275020  8.892728 1.0000000 Confirmed
## age_dummy       6.547832  6.467146  4.671393  9.680775 0.9629630 Confirmed
## embarked_dummy  1.264387  1.460648 -1.042565  2.053141 0.1851852  Rejected