This notebook finds the key drivers of a target variable in a dataset by building exploratory models: a decision tree and a random forest.
# Load the raw Titanic training data; the parsed column specification is shown below.
df_raw <- read_csv(file = "~/projects/data/datasets/titanic/train.csv")
## Parsed with column specification:
## cols(
## PassengerId = col_double(),
## Survived = col_double(),
## Pclass = col_double(),
## Name = col_character(),
## Sex = col_character(),
## Age = col_double(),
## SibSp = col_double(),
## Parch = col_double(),
## Ticket = col_character(),
## Fare = col_double(),
## Cabin = col_character(),
## Embarked = col_character()
## )
## ── Data Summary ────────────────────────
## Values
## Name df_raw
## Number of rows 891
## Number of columns 12
## _______________________
## Column type frequency:
## character 5
## numeric 7
## ________________________
## Group variables None
##
## ── Variable type: character ────────────────────────────────────────────────────
## # A tibble: 5 × 8
## skim_variable n_missing complete_rate min max empty n_unique whitespace
## * <chr> <int> <dbl> <int> <int> <int> <int> <int>
## 1 name 0 1 12 82 0 891 0
## 2 sex 0 1 4 6 0 2 0
## 3 ticket 0 1 3 18 0 681 0
## 4 cabin 687 0.229 1 15 0 147 0
## 5 embarked 2 0.998 1 1 0 3 0
##
## ── Variable type: numeric ──────────────────────────────────────────────────────
## # A tibble: 7 × 11
## skim_variable n_missing complete_rate mean sd p0 p25 p50 p75
## * <chr> <int> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 passenger_id 0 1 446 257. 1 224. 446 668.
## 2 survived 0 1 0.384 0.487 0 0 0 1
## 3 pclass 0 1 2.31 0.836 1 2 3 3
## 4 age 177 0.801 29.7 14.5 0.42 20.1 28 38
## 5 sib_sp 0 1 0.523 1.10 0 0 0 1
## 6 parch 0 1 0.382 0.806 0 0 0 0
## 7 fare 0 1 32.2 49.7 0 7.91 14.5 31
## p100 hist
## * <dbl> <chr>
## 1 891 ▇▇▇▇▇
## 2 1 ▇▁▁▁▅
## 3 3 ▃▁▃▁▇
## 4 80 ▂▇▅▂▁
## 5 8 ▇▁▁▁▁
## 6 6 ▇▁▁▁▁
## 7 512. ▇▁▁▁▁
# Name of the outcome column analysed in this notebook.
target <- "survived"

# Class balance of the target: row count and share for each value.
# (.data[[target]] could equally be written as !!sym(target).)
df_raw %>%
  count(.data[[target]]) %>%
  mutate(prop = n / sum(n))
## # A tibble: 2 × 3
## survived n prop
## <dbl> <int> <dbl>
## 1 0 549 0.616
## 2 1 342 0.384
# Build the modelling table df_model from df_raw.

# Convert every character column to a factor so it is treated as categorical.
df_model <- df_raw %>% mutate(across(where(is.character), as.factor))

# Remove constant features.
df_model <- mlr::removeConstantFeatures(df_model)

# Keep all non-factor columns and only factors with 2-10 levels: high-cardinality
# factors (names, ticket numbers, cabin codes) would explode the dummy encoding.
# vapply() guarantees exactly one typed value per column.
is_fct <- vapply(df_model, is.factor, logical(1))
n_lvls <- vapply(df_model, nlevels, integer(1))
df_model <- df_model[, !is_fct | between(n_lvls, 2, 10), drop = FALSE]

# Drop identifier-like features. Logical indexing is a no-op for names that
# are absent, whereas -which(...) would select zero columns and empty the table.
feat_drop <- c("passenger_id")
df_model <- df_model[, !names(df_model) %in% feat_drop, drop = FALSE]

# Impute missing values (mode for factors, median for integers, mean for
# doubles) and add a numeric 0/1 indicator column for imputed features.
# The target is passed explicitly so the labels themselves are never imputed.
imp <- mlr::impute(
  df_model,
  target = target,
  classes = list(
    factor = mlr::imputeMode(),
    integer = mlr::imputeMedian(),
    numeric = mlr::imputeMean()
  ),
  dummy.classes = c("integer", "factor", "numeric"),
  dummy.type = "numeric"
)
df_model <- imp$data

# Dummy-encode factors with reference coding (leaving the target untouched),
# then normalise the resulting column names to snake_case.
df_model <- mlr::createDummyFeatures(
  df_model,
  target = target,
  method = "reference"
)
df_model <- df_model %>% janitor::clean_names()
# Pairwise Pearson correlations between all model columns (including the
# target), shown as numbers in the upper triangle.
res <- cor(df_model)

# Diverging red-white-blue palette used for the correlation cells.
col <- colorRampPalette(c("#BB4444", "#EE9988", "#FFFFFF", "#77AADD", "#4477AA"))

# corrplot only applies sig.level when p.mat is supplied, so compute the
# p-value of each pairwise correlation test; cells with p > 0.05 are blanked.
p_mat <- corrplot::cor.mtest(df_model, conf.level = 0.95)$p

corrplot::corrplot(
  res,
  type = "upper",
  method = "number",
  col = col(200),
  p.mat = p_mat,
  sig.level = 0.05,
  insig = "blank"
)
# Exploratory classification tree on the target.
# The formula is built from the column name rather than the vector
# df_model[[target]], so the response is named in the model terms and `.`
# expands to every other column of df_model (no ambiguous select() needed).
# cp = 0 disables the complexity stopping rule, so the tree grows as far as
# minsplit = 10 allows.
tree <- rpart::rpart(
  as.formula(paste(target, "~ .")),
  data = df_model,
  method = "class",
  model = TRUE,
  control = rpart::rpart.control(minsplit = 10, cp = 0)
)
## Note: Using an external vector in selections is ambiguous.
## ℹ Use `all_of(target)` instead of `target` to silence this message.
## ℹ See <https://tidyselect.r-lib.org/reference/faq-external-vector.html>.
## This message is displayed once per session.
# Prune the full tree back to a readable size and plot it.
# `depth` selects a ROW of the complexity table tree$cptable (row k is the
# subtree with cptable[k, "nsplit"] splits); it is not the depth of the tree.
# It is capped at the number of rows so a small tree cannot cause a
# "subscript out of bounds" error.
depth <- 4
cp_row <- min(depth, nrow(tree$cptable))
tree_prune <- rpart::prune(tree, cp = tree$cptable[cp_row, "CP"])

rpart.plot::rpart.plot(
  tree_prune,
  type = 4,
  clip.right.labs = FALSE,
  branch = .3,
  under = TRUE,
  under.cex = 1.1,
  tweak = 1.3
)
# Interactive tree alternative (disabled):
# visNetwork::visTree(tree_prune,fallenLeaves=T,height="1000px",width = "200%",
# export = F,legendPosition = "right",
# highlightNearest = list(enabled = TRUE, degree = list(from = 50000, to = 0),
# hover = TRUE, algorithm = "hierarchical"))

# Variable importance of the full (unpruned) tree as an interactive table.
# Names and scores are stored as two explicit columns with no row names, so
# column 2 is the only (numeric) column to round. as.character/as.numeric turn
# a NULL importance (tree without splits) into an empty table instead of an error.
vi <- tree$variable.importance
var_imp <- data.frame(
  variable = as.character(names(vi)),
  importance = as.numeric(vi)
)
DT::datatable(
  var_imp,
  caption = "Table: Variable importance summary",
  colnames = c("Variable name", "Importance"),
  rownames = FALSE,
  selection = "single",
  editable = FALSE
) %>%
  DT::formatRound(columns = 2, digits = 3)
# Random forest classifier on the target.
# The target is numeric 0/1 in df_model, so it is wrapped in as.factor() to get
# a classification (not regression) forest. The formula is built from `target`
# instead of hard-coding the column name, so the notebook stays parameterised.
set.seed(42)
rf_formula <- as.formula(paste0("as.factor(", target, ") ~ ."))
rf <- randomForest::randomForest(
  rf_formula,
  data = df_model,
  ntree = 500,
  importance = TRUE
)
print(rf)
##
## Call:
## randomForest(formula = as.factor(survived) ~ ., data = df_model, ntree = 500, importance = T)
## Type of random forest: classification
## Number of trees: 500
## No. of variables tried at each split: 3
##
## OOB estimate of error rate: 16.61%
## Confusion matrix:
## 0 1 class.error
## 0 503 46 0.08378871
## 1 102 240 0.29824561
# Error rates of the forest against the number of trees.
plot(rf)

# Permutation importance (type = 1, mean decrease in accuracy) as a sorted
# tibble, with each variable's share of the total in imp_pct.
# as_tibble() replaces the deprecated as.tibble().
randomForest::importance(rf, type = 1) %>%
  as_tibble(rownames = "var") %>%
  arrange(desc(MeanDecreaseAccuracy)) %>%
  mutate(
    imp_pct = round(
      MeanDecreaseAccuracy / sum(MeanDecreaseAccuracy),
      digits = 3
    )
  )
## Warning: `as.tibble()` was deprecated in tibble 2.0.0.
## Please use `as_tibble()` instead.
## The signature and semantics have changed, see `?as_tibble`.
## # A tibble: 10 × 3
## var MeanDecreaseAccuracy imp_pct
## <chr> <dbl> <dbl>
## 1 male 89.9 0.319
## 2 pclass 44.4 0.158
## 3 fare 34.3 0.122
## 4 age 32.6 0.116
## 5 sib_sp 19.8 0.07
## 6 embarked_s 16.7 0.059
## 7 parch 15.2 0.054
## 8 age_dummy 15.1 0.054
## 9 embarked_q 12.2 0.043
## 10 embarked_dummy 1.42 0.005
# Dot chart of permutation importance.
randomForest::varImpPlot(rf, type = 1)

# Partial dependence plot for each predictor, most important first.
# For a classification forest, importance() returns per-class columns before
# MeanDecreaseAccuracy, so imp[, 1] would rank by the first class only;
# rank by the overall MeanDecreaseAccuracy column explicitly.
imp <- randomForest::importance(rf)
impvar <- rownames(imp)[order(imp[, "MeanDecreaseAccuracy"], decreasing = TRUE)]

for (i in seq_along(impvar)) {
  # Keep `impvar[i]` as the x.var expression: partialPlot() captures x.var
  # with substitute(), so a bare variable name would be taken literally.
  randomForest::partialPlot(
    rf, df_model, impvar[i],
    which.class = 1,
    xlab = impvar[i],
    main = paste("Partial Dependence on", impvar[i])
  )
}
library(pdp)
##
## Attaching package: 'pdp'
## The following object is masked from 'package:purrr':
##
## partial
# pdp partial dependence plots for every predictor (all columns except the
# target). all_of() makes the external `target` vector unambiguous to tidyselect.
# which.class = 2 selects the second class; prob = TRUE reports the
# probability scale. Plots are kept in `pdps`, named by feature.
features <- df_model %>% select(-all_of(target)) %>% colnames()
pdps <- vector("list", length(features))
names(pdps) <- features

for (feature in features) {
  # Called as pdp::partial because attaching pdp masks purrr::partial.
  pdps[[feature]] <- pdp::partial(
    rf,
    pred.var = feature,
    which.class = 2,
    train = df_model,
    prob = TRUE,
    type = "classification",
    plot = TRUE,
    rug = TRUE,
    alpha = 0.8,
    plot.engine = "ggplot2"
  )
  print(pdps[[feature]])
}
## Warning: Use of `object[[1L]]` is discouraged. Use `.data[[1L]]` instead.
## Warning: Use of `object[["yhat"]]` is discouraged. Use `.data[["yhat"]]`
## instead.
## Warning: Use of `x.rug[[1L]]` is discouraged. Use `.data[[1L]]` instead.
## Warning: Use of `object[[1L]]` is discouraged. Use `.data[[1L]]` instead.
## Warning: Use of `object[["yhat"]]` is discouraged. Use `.data[["yhat"]]`
## instead.
## Warning: Use of `x.rug[[1L]]` is discouraged. Use `.data[[1L]]` instead.
## Warning: Use of `object[[1L]]` is discouraged. Use `.data[[1L]]` instead.
## Warning: Use of `object[["yhat"]]` is discouraged. Use `.data[["yhat"]]`
## instead.
## Warning: Use of `x.rug[[1L]]` is discouraged. Use `.data[[1L]]` instead.
## Warning: Use of `object[[1L]]` is discouraged. Use `.data[[1L]]` instead.
## Warning: Use of `object[["yhat"]]` is discouraged. Use `.data[["yhat"]]`
## instead.
## Warning: Use of `x.rug[[1L]]` is discouraged. Use `.data[[1L]]` instead.
## Warning: Use of `object[[1L]]` is discouraged. Use `.data[[1L]]` instead.
## Warning: Use of `object[["yhat"]]` is discouraged. Use `.data[["yhat"]]`
## instead.
## Warning: Use of `x.rug[[1L]]` is discouraged. Use `.data[[1L]]` instead.
## Warning: Use of `object[[1L]]` is discouraged. Use `.data[[1L]]` instead.
## Warning: Use of `object[["yhat"]]` is discouraged. Use `.data[["yhat"]]`
## instead.
## Warning: Use of `x.rug[[1L]]` is discouraged. Use `.data[[1L]]` instead.
## Warning: Use of `object[[1L]]` is discouraged. Use `.data[[1L]]` instead.
## Warning: Use of `object[["yhat"]]` is discouraged. Use `.data[["yhat"]]`
## instead.
## Warning: Use of `x.rug[[1L]]` is discouraged. Use `.data[[1L]]` instead.
## Warning: Use of `object[[1L]]` is discouraged. Use `.data[[1L]]` instead.
## Warning: Use of `object[["yhat"]]` is discouraged. Use `.data[["yhat"]]`
## instead.
## Warning: Use of `x.rug[[1L]]` is discouraged. Use `.data[[1L]]` instead.
## Warning: Use of `object[[1L]]` is discouraged. Use `.data[[1L]]` instead.
## Warning: Use of `object[["yhat"]]` is discouraged. Use `.data[["yhat"]]`
## instead.
## Warning: Use of `x.rug[[1L]]` is discouraged. Use `.data[[1L]]` instead.
## Warning: Use of `object[[1L]]` is discouraged. Use `.data[[1L]]` instead.
## Warning: Use of `object[["yhat"]]` is discouraged. Use `.data[["yhat"]]`
## instead.
## Warning: Use of `x.rug[[1L]]` is discouraged. Use `.data[[1L]]` instead.
#grid.arrange(grobs = pdps, ncol = 1)
How does Boruta work? The algorithm proceeds in the following steps:
First, it adds randomness to the data set by creating shuffled copies of all features, called shadow features.
Next, it trains a random forest on the extended data set and computes a feature importance measure (by default, Mean Decrease Accuracy) for every feature; higher values mean more important.
At every iteration, it checks whether each real feature has a higher importance than the best shadow feature (i.e. whether its Z score exceeds the maximum Z score among all shadow features), and it progressively removes features that are deemed significantly unimportant.
Finally, the algorithm stops when every feature has been either confirmed or rejected, or when it reaches the specified limit of random forest runs.
# Boruta all-relevant feature selection on the target.
# The target is numeric 0/1 in df_model; Boruta would then fit regression
# forests. A copy with a factor target makes it fit classification forests,
# consistent with the random forest classifier above.
set.seed(123)
df_boruta <- df_model
df_boruta[[target]] <- as.factor(df_boruta[[target]])
boruta.train <- Boruta::Boruta(
  as.formula(paste(target, "~ .")),
  data = df_boruta,
  doTrace = 0
)
print(boruta.train)
## Boruta performed 27 iterations in 4.396305 secs.
## 9 attributes confirmed important: age, age_dummy, embarked_q,
## embarked_s, fare and 4 more;
## 1 attributes confirmed unimportant: embarked_dummy;
# Boruta importance boxplots with readable attribute labels.
# The default x-axis is suppressed (xaxt = "n") and redrawn with vertical
# labels sorted by median importance, ignoring non-finite history values.
# na.last = TRUE keeps attributes without finite values (median NA) at the end,
# so labels stay the same length as `at` instead of being silently dropped.
plot(boruta.train, xlab = "", xaxt = "n")
imp_history <- boruta.train$ImpHistory
lz <- lapply(seq_len(ncol(imp_history)), function(i) {
  imp_history[is.finite(imp_history[, i]), i]
})
names(lz) <- colnames(imp_history)
Labels <- sort(vapply(lz, median, numeric(1)), na.last = TRUE)
axis(
  side = 1,
  las = 2,
  labels = names(Labels),
  at = seq_len(ncol(imp_history)),
  cex.axis = 0.7
)
# Force a Confirmed/Rejected decision for any attributes still Tentative;
# when none are tentative the original object is returned unchanged.
final.boruta <- Boruta::TentativeRoughFix(boruta.train)
## Warning in TentativeRoughFix(boruta.train): There are no Tentative attributes!
## Returning original object.
# Summary of the final Boruta decisions.
final.boruta %>% print()
## Boruta performed 27 iterations in 4.396305 secs.
## 9 attributes confirmed important: age, age_dummy, embarked_q,
## embarked_s, fare and 4 more;
## 1 attributes confirmed unimportant: embarked_dummy;
# Names of the attributes confirmed important (tentative ones excluded).
# FALSE is spelled out because F is an ordinary, reassignable variable.
Boruta::getSelectedAttributes(final.boruta, withTentative = FALSE)
## [1] "pclass" "age" "sib_sp" "parch" "fare"
## [6] "age_dummy" "male" "embarked_q" "embarked_s"
# Per-attribute importance statistics and decisions, most important first.
boruta.df <- Boruta::attStats(final.boruta)
boruta.df %>%
  arrange(desc(meanImp)) %>%
  print()
## meanImp medianImp minImp maxImp normHits decision
## male 77.934124 77.915829 71.782546 82.622922 1.0000000 Confirmed
## pclass 34.308028 34.070058 31.164675 38.328537 1.0000000 Confirmed
## fare 28.250358 28.143484 24.998817 31.075931 1.0000000 Confirmed
## age 23.718310 23.831899 21.119384 26.430083 1.0000000 Confirmed
## sib_sp 16.758912 16.755684 14.910549 19.652687 1.0000000 Confirmed
## parch 13.714816 13.714328 12.094403 15.582184 1.0000000 Confirmed
## embarked_s 9.818906 9.779211 6.717029 12.418427 1.0000000 Confirmed
## embarked_q 6.873162 6.691482 4.275020 8.892728 1.0000000 Confirmed
## age_dummy 6.547832 6.467146 4.671393 9.680775 0.9629630 Confirmed
## embarked_dummy 1.264387 1.460648 -1.042565 2.053141 0.1851852 Rejected