library(mdsr)
## Warning: package 'mdsr' was built under R version 4.4.3
library(dplyr)
## Warning: package 'dplyr' was built under R version 4.4.3
## 
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
## 
##     filter, lag
## The following objects are masked from 'package:base':
## 
##     intersect, setdiff, setequal, union
library(mosaic)
## Warning: package 'mosaic' was built under R version 4.4.3
## Registered S3 method overwritten by 'mosaic':
##   method                           from   
##   fortify.SpatialPolygonsDataFrame ggplot2
## 
## The 'mosaic' package masks several functions from core packages in order to add 
## additional features.  The original behavior of these functions should not be affected by this.
## 
## Attaching package: 'mosaic'
## The following object is masked from 'package:Matrix':
## 
##     mean
## The following object is masked from 'package:ggplot2':
## 
##     stat
## The following objects are masked from 'package:dplyr':
## 
##     count, do, tally
## The following objects are masked from 'package:stats':
## 
##     binom.test, cor, cor.test, cov, fivenum, IQR, median, prop.test,
##     quantile, sd, t.test, var
## The following objects are masked from 'package:base':
## 
##     max, mean, min, prod, range, sample, sum
# Read the UCI "adult" extract. The file has no header row, so the column
# names are supplied here. Its fields are separated by ", ", which means that
# without strip.white = TRUE every character value keeps a leading blank
# (" Private", " Husband", " <=50K") and exact string comparisons later in the
# analysis silently fail to match.
census <- read.csv(
  "http://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data",
  header = FALSE,
  strip.white = TRUE,
  col.names = c(
    "age", "workclass", "fnlwgt", "education", "education.num",
    "marital.status", "occupation", "relationship", "race", "sex",
    "capital.gain", "capital.loss", "hours.per.week", "native.country",
    "income"
  )
)

# Per-column overview: numeric summaries, and length/class for text columns.
summary(census)
##       age         workclass             fnlwgt         education        
##  Min.   :17.00   Length:32561       Min.   :  12285   Length:32561      
##  1st Qu.:28.00   Class :character   1st Qu.: 117827   Class :character  
##  Median :37.00   Mode  :character   Median : 178356   Mode  :character  
##  Mean   :38.58                      Mean   : 189778                     
##  3rd Qu.:48.00                      3rd Qu.: 237051                     
##  Max.   :90.00                      Max.   :1484705                     
##  education.num   marital.status      occupation        relationship      
##  Min.   : 1.00   Length:32561       Length:32561       Length:32561      
##  1st Qu.: 9.00   Class :character   Class :character   Class :character  
##  Median :10.00   Mode  :character   Mode  :character   Mode  :character  
##  Mean   :10.08                                                           
##  3rd Qu.:12.00                                                           
##  Max.   :16.00                                                           
##      race               sex             capital.gain    capital.loss   
##  Length:32561       Length:32561       Min.   :    0   Min.   :   0.0  
##  Class :character   Class :character   1st Qu.:    0   1st Qu.:   0.0  
##  Mode  :character   Mode  :character   Median :    0   Median :   0.0  
##                                        Mean   : 1078   Mean   :  87.3  
##                                        3rd Qu.:    0   3rd Qu.:   0.0  
##                                        Max.   :99999   Max.   :4356.0  
##  hours.per.week  native.country        income         
##  Min.   : 1.00   Length:32561       Length:32561      
##  1st Qu.:40.00   Class :character   Class :character  
##  Median :40.00   Mode  :character   Mode  :character  
##  Mean   :40.44                                        
##  3rd Qu.:45.00                                        
##  Max.   :99.00
# Compact listing of every column with its type and first few values.
census %>% glimpse()
## Rows: 32,561
## Columns: 15
## $ age            <int> 39, 50, 38, 53, 28, 37, 49, 52, 31, 42, 37, 30, 23, 32,…
## $ workclass      <chr> " State-gov", " Self-emp-not-inc", " Private", " Privat…
## $ fnlwgt         <int> 77516, 83311, 215646, 234721, 338409, 284582, 160187, 2…
## $ education      <chr> " Bachelors", " Bachelors", " HS-grad", " 11th", " Bach…
## $ education.num  <int> 13, 13, 9, 7, 13, 14, 5, 9, 14, 13, 10, 13, 13, 12, 11,…
## $ marital.status <chr> " Never-married", " Married-civ-spouse", " Divorced", "…
## $ occupation     <chr> " Adm-clerical", " Exec-managerial", " Handlers-cleaner…
## $ relationship   <chr> " Not-in-family", " Husband", " Not-in-family", " Husba…
## $ race           <chr> " White", " White", " White", " Black", " Black", " Whi…
## $ sex            <chr> " Male", " Male", " Male", " Male", " Female", " Female…
## $ capital.gain   <int> 2174, 0, 0, 0, 0, 0, 0, 0, 14084, 5178, 0, 0, 0, 0, 0, …
## $ capital.loss   <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ hours.per.week <int> 40, 13, 40, 40, 40, 40, 16, 45, 50, 40, 80, 40, 30, 50,…
## $ native.country <chr> " United-States", " United-States", " United-States", "…
## $ income         <chr> " <=50K", " <=50K", " <=50K", " <=50K", " <=50K", " <=5…
# Fix the RNG state so the random partition below is reproducible.
set.seed(364)

# Hold out 20% of the rows (chosen at random) as the test set.
n <- dim(census)[1L]
test_idx <- sample.int(n = n, size = round(n * 0.2))

# Everything not drawn for the test set is used for training.
train <- census[-test_idx, ]
dim(train)[1L]
## [1] 26049
test <- census[test_idx, ]
dim(test)[1L]
## [1] 6512
# Class balance of the response in the training rows, as percentages.
tally(~ income, format = "percent", data = train)
## income
##    <=50K     >50K 
## 76.17567 23.82433
library(magrittr)
## Warning: package 'magrittr' was built under R version 4.4.3
library(dplyr)
library(ggplot2)
library(rpart)
## Warning: package 'rpart' was built under R version 4.4.3
# Fit a tree on capital.gain alone and keep the fitted object so that the
# split point can be read back from it rather than typed in by hand.
mod_cap_gain <- rpart(income ~ capital.gain, data = train)
mod_cap_gain
## n= 26049 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
## 1) root 26049 6206  <=50K (0.76175669 0.23824331)  
##   2) capital.gain< 5119 24805 5030  <=50K (0.79721830 0.20278170) *
##   3) capital.gain>=5119 1244   68  >50K (0.05466238 0.94533762) *
# The first row of the `splits` matrix is the root node's primary split; its
# "index" column holds the numeric cut point (5119 in the output above). The
# value was previously hard-coded as 5095.5, which did not match this tree.
# NOTE: `split` shadows the name of base::split(); the name is kept because it
# is an existing global of this script.
split <- unname(mod_cap_gain$splits[1, "index"])
train <- train %>%
  mutate(hi_cap_gains = capital.gain >= split)

# Income against capital gains on a log scale, coloured by which side of the
# split each observation falls on; the dashed line marks the split itself.
# Rows with capital.gain == 0 are dropped by the log transformation.
ggplot(data = train, aes(x = capital.gain, y = income)) +
  geom_count(
    aes(color = hi_cap_gains),
    position = position_jitter(width = 0, height = 0.1),
    alpha = 0.5
  ) +
  geom_vline(xintercept = split, color = "dodgerblue", lty = 2) +
  scale_x_log10(labels = scales::dollar)
## Warning in scale_x_log10(labels = scales::dollar): log-10 transformation
## introduced infinite values.
## Warning: Removed 23904 rows containing non-finite outside the scale range
## (`stat_sum()`).

# Model formula shared by the tree fits: income explained by demographic,
# employment and capital gain/loss variables.
form <- income ~ age + workclass + education + marital.status + occupation +
  relationship + race + sex + capital.gain + capital.loss + hours.per.week

# Classification tree on the training rows, using rpart's default settings.
mod_tree <- rpart(formula = form, data = train)
print(mod_tree)
## n= 26049 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
##  1) root 26049 6206  <=50K (0.76175669 0.23824331)  
##    2) relationship= Not-in-family, Other-relative, Own-child, Unmarried 14310  940  <=50K (0.93431167 0.06568833)  
##      4) capital.gain< 7073.5 14055  694  <=50K (0.95062255 0.04937745) *
##      5) capital.gain>=7073.5 255    9  >50K (0.03529412 0.96470588) *
##    3) relationship= Husband, Wife 11739 5266  <=50K (0.55140983 0.44859017)  
##      6) education= 10th, 11th, 12th, 1st-4th, 5th-6th, 7th-8th, 9th, Assoc-acdm, Assoc-voc, HS-grad, Preschool, Some-college 8199 2717  <=50K (0.66861812 0.33138188)  
##       12) capital.gain< 5095.5 7796 2321  <=50K (0.70228322 0.29771678) *
##       13) capital.gain>=5095.5 403    7  >50K (0.01736973 0.98263027) *
##      7) education= Bachelors, Doctorate, Masters, Prof-school 3540  991  >50K (0.27994350 0.72005650) *
# Draw the tree skeleton with base graphics, then label the nodes.
# all = TRUE labels interior nodes as well as leaves; use.n = TRUE adds the
# observation counts to the labels.
plot(mod_tree)
text(mod_tree, cex = 0.7, all = TRUE, use.n = TRUE)

library(partykit)
## Warning: package 'partykit' was built under R version 4.4.3
## Loading required package: grid
## Loading required package: libcoin
## Warning: package 'libcoin' was built under R version 4.4.2
## Loading required package: mvtnorm
## Warning: package 'mvtnorm' was built under R version 4.4.2
# Convert the rpart fit to a partykit object and use its plot method.
mod_tree %>%
  as.party() %>%
  plot()

# Flag the groups the fitted tree splits on, and attach the tree's own
# predictions for the training rows.
#
# The character columns of this data set carry a leading blank
# (" Husband", " Bachelors"), so comparing the raw values against
# "Husband"/"Wife" etc. never matches and both flags come out FALSE for every
# row. trimws() strips the padding first; it is a no-op on clean values.
train <- train %>%
  mutate(
    husband_or_wife = trimws(relationship) %in% c("Husband", "Wife"),
    # Degree holders among spouses only, mirroring the education split that
    # the tree makes inside the Husband/Wife branch.
    college_degree = husband_or_wife & trimws(education) %in%
      c("Bachelors", "Doctorate", "Masters", "Prof-school"),
    # No newdata is given, so these are the fitted classes for the rows the
    # tree was trained on.
    income_dtree = predict(mod_tree, type = "class")
  )

# Capital-gain cut points to draw in each facet, keyed by husband_or_wife.
cg_splits <- data.frame(
  husband_or_wife = c(TRUE, FALSE),
  vals = c(5095.5, 7073.5)
)

# Income against capital gains (log scale), one panel per husband_or_wife
# value. Colour shows the tree's prediction, shape the college_degree flag,
# and the dashed line the cut point for that panel.
train %>%
  ggplot(aes(x = capital.gain, y = income)) +
  geom_count(
    aes(color = income_dtree, shape = college_degree),
    alpha = 0.5,
    position = position_jitter(width = 0, height = 0.1)
  ) +
  facet_wrap(~ husband_or_wife) +
  geom_vline(
    data = cg_splits,
    aes(xintercept = vals),
    lty = 2,
    color = "dodgerblue"
  ) +
  scale_x_log10()
## Warning in scale_x_log10(): log-10 transformation introduced infinite values.
## Warning: Removed 23904 rows containing non-finite outside the scale range
## (`stat_sum()`).

# Complexity-parameter table for the fitted tree.
printcp(x = mod_tree)
## 
## Classification tree:
## rpart(formula = form, data = train)
## 
## Variables actually used in tree construction:
## [1] capital.gain education    relationship
## 
## Root node error: 6206/26049 = 0.23824
## 
## n= 26049 
## 
##         CP nsplit rel error  xerror      xstd
## 1 0.125524      0   1.00000 1.00000 0.0110790
## 2 0.062681      2   0.74895 0.74895 0.0099573
## 3 0.038189      3   0.68627 0.68627 0.0096178
## 4 0.010000      4   0.64808 0.64808 0.0093970
# Store the tree's fitted class for every training row.
train <- mutate(train, income_dtree = predict(mod_tree, type = "class"))

# Cross-tabulate predicted class (rows) against observed income (columns).
confusion <- tally(income_dtree ~ income, data = train, format = "count")
print(confusion)
##             income
## income_dtree  <=50K  >50K
##        <=50K  18836  3015
##        >50K    1007  3191
# Accuracy on the training rows: correctly classified (the diagonal) over all
# tabulated observations.
sum(diag(confusion)) / sum(confusion)
## [1] 0.8455987

Comment: The decision tree classifier's accuracy, measured on the training data, is 84.6%.

#  TUNING PARAMETERS


# Refit the same formula with the complexity parameter set to 0.002.
mod_tree2 <- rpart(
  formula = form,
  data = train,
  control = rpart.control(cp = 0.002)
)

# Question:  What is the accuracy of this more complex tree?
#  RANDOM FORESTS

library(randomForest)
## Warning: package 'randomForest' was built under R version 4.4.3
## randomForest 4.7-1.2
## Type rfNews() to see new features/changes/bug fixes.
## 
## Attaching package: 'randomForest'
## The following object is masked from 'package:ggplot2':
## 
##     margin
## The following object is masked from 'package:dplyr':
## 
##     combine
# Random forest for income on the training rows.
#
# The earlier call, randomForest(census, data = train, ...), passed the whole
# data set as the predictor matrix with no response. That fits an
# *unsupervised* forest on train and test rows together, leaves no confusion
# matrix, and makes the accuracy below come out as 0. The model formula is
# used instead so that income is the response.
#
# randomForest() fits a classification forest only when the response is a
# factor, so the character columns are converted first. The copy keeps
# `train` itself unchanged.
train_rf <- train %>%
  mutate(across(where(is.character), factor))

mod_forest <- randomForest(form, data = train_rf, ntree = 201, mtry = 3)
mod_forest

# Accuracy from the forest's stored confusion matrix. That matrix carries an
# extra class.error column, which diag() ignores.
sum(diag(mod_forest$confusion)) / nrow(train)
# NOTE: the output previously recorded here ("Type of random forest:
# unsupervised", accuracy 0) was produced by the incorrect call and has been
# removed; re-run the script to regenerate it.
library(tibble)
## Warning: package 'tibble' was built under R version 4.4.3
# Variable importance from the forest as a data frame, most important first.
# Wrapped in local() so the intermediate table does not become a global.
local({
  gini <- as.data.frame(importance(mod_forest))
  # Move the variable names out of the row names into a "rowname" column.
  gini <- rownames_to_column(gini)
  arrange(gini, desc(MeanDecreaseGini))
})
##           rowname MeanDecreaseGini
## 1    relationship        4474.5205
## 2  marital.status        4439.0852
## 3   education.num        4423.2409
## 4       education        3567.0855
## 5             age        2850.4123
## 6      occupation        2018.1912
## 7          fnlwgt        1762.8736
## 8             sex        1659.4928
## 9  hours.per.week        1388.3706
## 10      workclass        1160.5213
## 11         income         918.6979
## 12   capital.gain         587.5626
## 13 native.country         413.2558
## 14           race         372.8524
## 15   capital.loss         235.9131