library(mdsr)
## Warning: package 'mdsr' was built under R version 4.4.3
library(dplyr)
## Warning: package 'dplyr' was built under R version 4.4.3
##
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
##
## filter, lag
## The following objects are masked from 'package:base':
##
## intersect, setdiff, setequal, union
library(mosaic)
## Warning: package 'mosaic' was built under R version 4.4.3
## Registered S3 method overwritten by 'mosaic':
## method from
## fortify.SpatialPolygonsDataFrame ggplot2
##
## The 'mosaic' package masks several functions from core packages in order to add
## additional features. The original behavior of these functions should not be affected by this.
##
## Attaching package: 'mosaic'
## The following object is masked from 'package:Matrix':
##
## mean
## The following object is masked from 'package:ggplot2':
##
## stat
## The following objects are masked from 'package:dplyr':
##
## count, do, tally
## The following objects are masked from 'package:stats':
##
## binom.test, cor, cor.test, cov, fivenum, IQR, median, prop.test,
## quantile, sd, t.test, var
## The following objects are masked from 'package:base':
##
## max, mean, min, prod, range, sample, sum
# Download the UCI "adult" census extract. The file has no header row, so the
# fifteen column names are supplied while reading.
census <- read.csv(
  "http://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data",
  header = FALSE,
  col.names = c(
    "age", "workclass", "fnlwgt", "education", "education.num",
    "marital.status", "occupation", "relationship", "race", "sex",
    "capital.gain", "capital.loss", "hours.per.week", "native.country",
    "income"
  )
)
# Per-column overview: quantiles for numeric columns, length/class for text ones.
print(summary(census))
## age workclass fnlwgt education
## Min. :17.00 Length:32561 Min. : 12285 Length:32561
## 1st Qu.:28.00 Class :character 1st Qu.: 117827 Class :character
## Median :37.00 Mode :character Median : 178356 Mode :character
## Mean :38.58 Mean : 189778
## 3rd Qu.:48.00 3rd Qu.: 237051
## Max. :90.00 Max. :1484705
## education.num marital.status occupation relationship
## Min. : 1.00 Length:32561 Length:32561 Length:32561
## 1st Qu.: 9.00 Class :character Class :character Class :character
## Median :10.00 Mode :character Mode :character Mode :character
## Mean :10.08
## 3rd Qu.:12.00
## Max. :16.00
## race sex capital.gain capital.loss
## Length:32561 Length:32561 Min. : 0 Min. : 0.0
## Class :character Class :character 1st Qu.: 0 1st Qu.: 0.0
## Mode :character Mode :character Median : 0 Median : 0.0
## Mean : 1078 Mean : 87.3
## 3rd Qu.: 0 3rd Qu.: 0.0
## Max. :99999 Max. :4356.0
## hours.per.week native.country income
## Min. : 1.00 Length:32561 Length:32561
## 1st Qu.:40.00 Class :character Class :character
## Median :40.00 Mode :character Mode :character
## Mean :40.44
## 3rd Qu.:45.00
## Max. :99.00
# Compact listing of every column with its type and first few values.
dplyr::glimpse(census)
## Rows: 32,561
## Columns: 15
## $ age <int> 39, 50, 38, 53, 28, 37, 49, 52, 31, 42, 37, 30, 23, 32,…
## $ workclass <chr> " State-gov", " Self-emp-not-inc", " Private", " Privat…
## $ fnlwgt <int> 77516, 83311, 215646, 234721, 338409, 284582, 160187, 2…
## $ education <chr> " Bachelors", " Bachelors", " HS-grad", " 11th", " Bach…
## $ education.num <int> 13, 13, 9, 7, 13, 14, 5, 9, 14, 13, 10, 13, 13, 12, 11,…
## $ marital.status <chr> " Never-married", " Married-civ-spouse", " Divorced", "…
## $ occupation <chr> " Adm-clerical", " Exec-managerial", " Handlers-cleaner…
## $ relationship <chr> " Not-in-family", " Husband", " Not-in-family", " Husba…
## $ race <chr> " White", " White", " White", " Black", " Black", " Whi…
## $ sex <chr> " Male", " Male", " Male", " Male", " Female", " Female…
## $ capital.gain <int> 2174, 0, 0, 0, 0, 0, 0, 0, 14084, 5178, 0, 0, 0, 0, 0, …
## $ capital.loss <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ hours.per.week <int> 40, 13, 40, 40, 40, 40, 16, 45, 50, 40, 80, 40, 30, 50,…
## $ native.country <chr> " United-States", " United-States", " United-States", "…
## $ income <chr> " <=50K", " <=50K", " <=50K", " <=50K", " <=50K", " <=5…
# Reproducible 80/20 partition of the census rows.
set.seed(364)
n <- nrow(census)
test_idx <- sample.int(n, size = round(0.2 * n))

# Every row that was not sampled goes into the training set.
train <- census[-test_idx, ]
print(nrow(train))
## [1] 26049

# The sampled rows are held out as the test set.
test <- census[test_idx, ]
print(nrow(test))
## [1] 6512

# Share of each income class in the training set, in percent.
tally(~income, data = train, format = "percent")
## income
## <=50K >50K
## 76.17567 23.82433
library(magrittr)
## Warning: package 'magrittr' was built under R version 4.4.3
library(dplyr)
library(ggplot2)
library(rpart)
## Warning: package 'rpart' was built under R version 4.4.3
# Fit and display a tree that predicts income from capital.gain alone.
print(rpart(formula = income ~ capital.gain, data = train))
## n= 26049
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 26049 6206 <=50K (0.76175669 0.23824331)
## 2) capital.gain< 5119 24805 5030 <=50K (0.79721830 0.20278170) *
## 3) capital.gain>=5119 1244 68 >50K (0.05466238 0.94533762) *
# Visualise the single-variable capital.gain split on the training data.
# The threshold is read from the fitted one-variable tree instead of being
# typed in by hand, so the flag and the dashed line always agree with the
# model (the hand-entered 5095.5 did not match the 5119 split rpart reports).
cap_gain_tree <- rpart(income ~ capital.gain, data = train)
# The first row of `splits` is the root node's primary split; `index` holds
# its numeric cut point.
split <- cap_gain_tree$splits[1, "index"]

# Mark the rows that fall on the high side of that threshold.
train <- train %>%
  mutate(hi_cap_gains = capital.gain >= split)

# The x axis is log-scaled, so rows with capital.gain == 0 are dropped from
# the plot and ggplot2 warns about non-finite values.
ggplot(data = train, aes(x = capital.gain, y = income)) +
  geom_count(
    aes(color = hi_cap_gains),
    position = position_jitter(width = 0, height = 0.1),
    alpha = 0.5
  ) +
  geom_vline(xintercept = split, color = "dodgerblue", lty = 2) +
  scale_x_log10(labels = scales::dollar)
## Warning in scale_x_log10(labels = scales::dollar): log-10 transformation
## introduced infinite values.
## Warning: Removed 23904 rows containing non-finite outside the scale range
## (`stat_sum()`).

# Response and predictors for the full model.
form <- income ~ age + workclass + education + marital.status + occupation +
  relationship + race + sex + capital.gain + capital.loss + hours.per.week

# Fit the classification tree on the training rows and show its splits.
mod_tree <- rpart(formula = form, data = train)
print(mod_tree)
## n= 26049
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 26049 6206 <=50K (0.76175669 0.23824331)
## 2) relationship= Not-in-family, Other-relative, Own-child, Unmarried 14310 940 <=50K (0.93431167 0.06568833)
## 4) capital.gain< 7073.5 14055 694 <=50K (0.95062255 0.04937745) *
## 5) capital.gain>=7073.5 255 9 >50K (0.03529412 0.96470588) *
## 3) relationship= Husband, Wife 11739 5266 <=50K (0.55140983 0.44859017)
## 6) education= 10th, 11th, 12th, 1st-4th, 5th-6th, 7th-8th, 9th, Assoc-acdm, Assoc-voc, HS-grad, Preschool, Some-college 8199 2717 <=50K (0.66861812 0.33138188)
## 12) capital.gain< 5095.5 7796 2321 <=50K (0.70228322 0.29771678) *
## 13) capital.gain>=5095.5 403 7 >50K (0.01736973 0.98263027) *
## 7) education= Bachelors, Doctorate, Masters, Prof-school 3540 991 >50K (0.27994350 0.72005650) *
# Draw the tree skeleton, then label all nodes (not only the leaves) and add
# observation counts to the labels, in a reduced font size.
plot(x = mod_tree)
text(x = mod_tree, all = TRUE, use.n = TRUE, cex = 0.7)

library(partykit)
## Warning: package 'partykit' was built under R version 4.4.3
## Loading required package: grid
## Loading required package: libcoin
## Warning: package 'libcoin' was built under R version 4.4.2
## Loading required package: mvtnorm
## Warning: package 'mvtnorm' was built under R version 4.4.2
# Convert the rpart fit to a party object and draw it with partykit's plot method.
plot(x = as.party(mod_tree))

# Flag the groups the fitted tree splits on, attach its predictions, and plot
# them against capital gains.
# The text columns keep the leading blank from the raw file (values look like
# " Husband" and " Bachelors"), so they are trimmed before matching. Without
# the trim every %in% test below is FALSE for all rows.
train <- train %>%
  mutate(
    husband_or_wife = trimws(relationship) %in% c("Husband", "Wife"),
    college_degree = husband_or_wife & trimws(education) %in%
      c("Bachelors", "Doctorate", "Masters", "Prof-school"),
    income_dtree = predict(mod_tree, type = "class")
  )

# capital.gain thresholds printed for mod_tree: 5095.5 inside the Husband/Wife
# branch and 7073.5 inside the other branch.
cg_splits <- data.frame(
  husband_or_wife = c(TRUE, FALSE),
  vals = c(5095.5, 7073.5)
)

# One panel per husband_or_wife value, each with its own dashed threshold.
# The x axis is log-scaled, so rows with capital.gain == 0 are dropped (with a
# warning).
ggplot(data = train, aes(x = capital.gain, y = income)) +
  geom_count(
    aes(color = income_dtree, shape = college_degree),
    position = position_jitter(width = 0, height = 0.1),
    alpha = 0.5
  ) +
  facet_wrap(~husband_or_wife) +
  geom_vline(
    data = cg_splits, aes(xintercept = vals),
    color = "dodgerblue", lty = 2
  ) +
  scale_x_log10()
## Warning in scale_x_log10(): log-10 transformation introduced infinite values.
## Warning: Removed 23904 rows containing non-finite outside the scale range
## (`stat_sum()`).

# Complexity-parameter table for the fitted tree.
printcp(x = mod_tree)
##
## Classification tree:
## rpart(formula = form, data = train)
##
## Variables actually used in tree construction:
## [1] capital.gain education relationship
##
## Root node error: 6206/26049 = 0.23824
##
## n= 26049
##
## CP nsplit rel error xerror xstd
## 1 0.125524 0 1.00000 1.00000 0.0110790
## 2 0.062681 2 0.74895 0.74895 0.0099573
## 3 0.038189 3 0.68627 0.68627 0.0096178
## 4 0.010000 4 0.64808 0.64808 0.0093970
# Store the tree's predicted class for every training row.
train <- mutate(train, income_dtree = predict(mod_tree, type = "class"))

# Cross-tabulate predicted class (rows) against recorded class (columns).
confusion <- tally(income_dtree ~ income, data = train, format = "count")
print(confusion)
## income
## income_dtree <=50K >50K
## <=50K 18836 3015
## >50K 1007 3191

# Training accuracy: the diagonal (correct predictions) over all training rows.
sum(diag(confusion)) / nrow(train)
## [1] 0.8455987
Comment: The decision tree classifier is 84.6% accurate on the training data, compared with 76.2% for always predicting the majority class (<=50K).