We will investigate the OKCupid data and see whether we can predict if a user is a STEM major, based on the variables provided. First we will perform some EDA to explore the data, and then we will use the tidymodels framework to build machine learning models. If you wish to skip straight to the results, use the contents tab on the left to jump to the Performance on Test Set section.

library(modeldata)
library(tidyverse)
library(skimr)
library(tidymodels)

1 EDA

1.1 Basic EDA

data("okc")

skim(okc)
Data summary
Name okc
Number of rows 59855
Number of columns 6
_______________________
Column type frequency:
character 2
Date 1
factor 1
numeric 2
________________________
Group variables None

Variable type: character

skim_variable n_missing complete_rate min max empty n_unique whitespace
diet 24360 0.59 5 19 0 18 0
location 0 1.00 4 19 0 135 0

Variable type: Date

skim_variable n_missing complete_rate min max median n_unique
date 0 1 2011-06-27 2012-07-01 2012-06-27 371

Variable type: factor

skim_variable n_missing complete_rate ordered n_unique top_counts
Class 0 1 FALSE 2 oth: 50316, ste: 9539

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
age 0 1 32.34 9.46 18 26 30 37 110 ▇▂▁▁▁
height 2 1 68.29 3.99 1 66 68 71 95 ▁▁▁▇▁
okc %>%
  ggplot(aes(x=age))+
  geom_histogram(binwidth = 1,fill="pink")+
  theme_minimal()+
  labs(title="OKC Users Tends To Be Young, The Most common Age Being 26",subtitle="Age Distribution of OKC Users",x="Age",y="Count")+
  geom_vline(xintercept = 26, col="red",lty=2)

okc %>%
  filter(height > 30)%>%
  mutate(height = height/0.39370 )%>% # convert height from inches to cm (1 in = 2.54 cm)
  ggplot(aes(x=height))+
  geom_histogram(bins=30,fill="pink")+
  theme_minimal()+
  labs(title="The Average OKC User was 175cm tall",subtitle="Height Distribution of OKC Users",x="Height (cm)",y="Count")

okc %>%
  mutate(diet = str_remove(diet,"strictly"),
         diet = str_remove(diet,"mostly"),
         diet = str_remove(diet, " "),
         diet = str_to_title(diet),
         )%>%
  count(diet)%>%
  arrange(desc(n))%>%
  ggplot(aes(x=n,y=reorder(diet,n),fill=diet))+
  geom_col(show.legend = F)+
  geom_text(aes(label=n),hjust=-.1)+
  theme_minimal()+
  lims(x=c(0,31000))+
  labs(title = "OKC Users Were Not Picky Eaters",subtitle="Total Count of Diets",x="Count",y="Diet")

okc %>%
  mutate(location = fct_lump_n(location,n=10) )%>%
  mutate (location =  str_to_title(location))%>%
  count(location) %>%
  arrange(-n)%>%
  ggplot(aes(x=n,y=reorder(location,n),fill=location))+
  geom_col(show.legend = F)+
  geom_text(aes(label=n),hjust=-.1)+
  theme_minimal()+
  lims(x=c(0,35000))+
  labs(title = "Most OKC Users were from San Francisco",subtitle="Top 10 Most Common Locations",x="Count",y="Location")

okc %>%
  count(Class)%>%
  mutate(prop=n/sum(n)) %>%
  ggplot(aes(x=Class,y=prop,fill=Class))+
  geom_col(show.legend = FALSE)+
  theme_minimal()+
  scale_y_continuous(labels=scales::percent_format(),limits = c(0,1))+
  geom_text(aes(label = paste(round(prop*100),"%")), vjust = -.1)+
  labs(title="STEM Students were in the Minority",subtitle="Percent of Users in STEM",x="Major",y="Percent")

Were the age demographics different for different locations?

okc %>%
  mutate(location = fct_lump_n(location,n=10) )%>%
  mutate (location =  str_to_title(location)) %>%
  ggplot(aes(x=reorder(location,age),y=age,fill=location))+
  #geom_violin(show.legend = F,draw_quantiles = c(.25,.50,.75))+
  geom_boxplot(show.legend = F)+
  theme_minimal()+
  labs(title="Age Distribution of 10 Most Populated Locations",x="Location",y="Age")

okc %>%
  mutate(diet = str_remove(diet,"strictly"),
         diet = str_remove(diet,"mostly"),
         diet = str_remove(diet, " "),
         diet = str_to_title(diet),
     location = fct_lump_n(location,n=15),
     location =  str_to_title(location)
  )%>%
  group_by(location)%>%
  count(diet)%>%
  ungroup()%>%
  ggplot(aes(x=n,   y=tidytext::reorder_within(diet,n,location),fill=diet ))+
  geom_col(show.legend = F)+
  facet_wrap(.~location,scales = "free")+
  tidytext::scale_y_reordered()+
  #geom_text(aes(label=n,),hjust=.5)+
  theme_minimal()+
  labs(title="Most Common Diets by Top 16 Most Populated Areas",x=NULL,y=NULL)

okc %>%
  filter(height > 30)%>%
  na.omit(height)%>%
  group_by(location)%>%
mutate( avg_height = mean(height),
          avg_age = mean(age)
  )%>%
  ungroup()%>%
  distinct(location,.keep_all = TRUE) %>%
  ggplot(aes(x=avg_age,y=avg_height))+
  geom_point(col="red")+
  geom_text(aes(label=location))+
  theme_minimal()+
  labs(title="Average Age and Height of OKC Users by Location",x="Average Age",y="Average Height")

okc %>%
  #filter( !is.na(diet)) %>%
  mutate(diet = str_remove(diet,"strictly"),
         diet = str_remove(diet,"mostly"),
         diet = str_remove(diet, " "),
         diet = str_to_title(diet),
     location = fct_lump_n(location,n=15),
     location =  str_to_title(location)
  ) %>%
  group_by(diet,Class)%>%
  count(diet)%>%
  ungroup(diet)%>%
  mutate(prop=n/sum(n)) %>%
  ggplot(aes(x=reorder(diet,-prop),y=prop,fill=Class))+
  geom_col(position = position_dodge())+
  geom_text(aes(label=  paste0(  round(prop*100,1),"%" ) ),position = position_dodge(width = 1),vjust=-.5)+
  scale_y_continuous(labels = scales::percent_format())+
  labs(title="Dietary Difference Between Major",subtitle = "Proportion of Diet Amongst Major",x="Diet",y="Percent",fill="Class:")+
    theme_minimal()+
  theme(legend.position = "top",
        panel.grid.major.x = element_blank() )

ggplot(okc,aes(x=Class,y=age,fill=Class))+
  geom_violin(draw_quantiles = c(.25,.5,.75),show.legend = F)+
  theme_minimal()+
  labs(title="Age Distribution By Major")
## Warning in regularize.values(x, y, ties, missing(ties), na.rm = na.rm):
## collapsing to unique 'x' values

okc %>%
  filter(height > 25) %>%
ggplot(aes(x=Class,y=height,fill=Class))+
  geom_violin(draw_quantiles = c(.25,.5,.75),show.legend = F)+
  #geom_boxplot(show.legend = F)+
  theme_minimal()+
  labs(title="Height Distribution By Major",y="Height (Inches)")
## Warning in regularize.values(x, y, ties, missing(ties), na.rm = na.rm):
## collapsing to unique 'x' values

okc %>%
  filter(height > 25) %>%
  mutate(diet = str_remove(diet,"strictly"),
         diet = str_remove(diet,"mostly"),
         diet = str_remove(diet, " "),
         diet = str_to_title(diet),
  )%>%
    ggplot(aes(x=reorder(diet,height),y=height,fill=diet))+
    geom_violin(  draw_quantiles = c(.25,.5,.75),show.legend = FALSE)+
    theme_minimal()+
   labs(title="Height Difference Across Diets",x="Diet",y="Height (Inches)")
## Warning in regularize.values(x, y, ties, missing(ties), na.rm = na.rm):
## collapsing to unique 'x' values

okc %>%
  na.omit() %>%
  filter(height > 25) %>%
  mutate(diet = str_remove(diet,"strictly"),
         diet = str_remove(diet,"mostly"),
         diet = str_remove(diet, " "),
         diet = str_to_title(diet),
  )%>%
    ggplot(aes(x=diet,y=height,fill=interaction(diet,Class) )   ) + #Use fill=interaction(Class,diet) for better colouring - doesn't work here due to NAs
    geom_violin(  draw_quantiles = c(.25,.5,.75))+
    theme_minimal()+
   labs(title="Height Difference Majors Across Diets",subtitle ="Left is STEM, Right is Others" ,x="Diet",y="Height (Inches)")+
   scale_fill_manual(values=c("darkslategray4","firebrick4","palegreen4","slateblue4","royalblue3","darkgoldenrod4",
              "darkslategray1","firebrick2","palegreen1","slateblue2","royalblue1","darkgoldenrod1"))+
  theme(legend.position = "right")+
  guides(fill=guide_legend(ncol=2))
## Warning in regularize.values(x, y, ties, missing(ties), na.rm = na.rm):
## collapsing to unique 'x' values

okc %>%
  na.omit() %>%
  filter(height > 25) %>%
  mutate(diet = str_remove(diet,"strictly"),
         diet = str_remove(diet,"mostly"),
         diet = str_remove(diet, " "),
         diet = str_to_title(diet),
  )%>%
    ggplot(aes(x= tidytext::reorder_within(diet,height,Class)
               ,y=height,fill=diet)) + #Use fill=interaction(Class,diet) for better colouring - doesn't work here due to NAs
    #geom_violin(  draw_quantiles = c(.25,.5,.75))+
  geom_boxplot()+
  facet_wrap(.~Class,nrow=2,scales = "free_x")+
    theme_bw()+
   labs(title="STEM Majors Who Followed A Kosher Diet Were Tallest For That Group,\nWhile Vegans Were The Shortest",x="Diet",y="Height (Inches)")+
   scale_fill_manual(values=c("darkslategray4","firebrick4","palegreen4","slateblue4","royalblue3","darkgoldenrod4"))+
  theme(legend.position = "none")+
  tidytext::scale_x_reordered()

Seeing which combination of diet and class has the tallest individuals on average.

okc %>%
  #na.omit() %>%
  filter(height > 25) %>%
  mutate(diet = str_remove(diet,"strictly"),
         diet = str_remove(diet,"mostly"),
         diet = str_remove(diet, " "),
         diet = str_to_title(diet),
  ) %>%
  group_by(Class,diet)%>%
  summarise(avg_height = mean(height),.groups="keep") %>%
  mutate( avg_height = avg_height/0.39370) %>%
  arrange(-avg_height) %>%
  ungroup() %>%
  pivot_wider(names_from = Class,values_from=avg_height)%>%
  mutate( avg = (stem+other)/2 )%>%
  arrange(-avg)
diet stem other avg
Halal 176.8234 176.7685 176.7960
Kosher 179.7054 172.9669 176.3362
Anything 176.8687 173.3305 175.0996
NA 176.5763 172.7735 174.6749
Other 175.9794 173.1469 174.5632
Vegetarian 176.0799 170.9882 173.5341
Vegan 173.6186 171.7596 172.6891
okc %>%
  na.omit() %>%
  filter(height > 25) %>%
  mutate(diet = str_remove(diet,"strictly"),
         diet = str_remove(diet,"mostly"),
         diet = str_remove(diet, " "),
         diet = str_to_title(diet),
  )%>%
    ggplot(aes(x=diet,y=age,fill=interaction(diet,Class) )   ) + #Use fill=interaction(Class,diet) for better colouring - doesn't work here due to NAs
    geom_violin(  draw_quantiles = c(.25,.5,.75))+
    theme_minimal()+
   labs(title="Age by Diet and Class",x="Diet",y="Age")+
   scale_fill_manual(values=c("darkslategray4","firebrick4","palegreen4","slateblue4","royalblue3","darkgoldenrod4",
              "darkslategray1","firebrick2","palegreen1","slateblue2","royalblue1","darkgoldenrod1"))+
  theme(legend.position = "right")+
  guides(fill=guide_legend(ncol=2))
## Warning in regularize.values(x, y, ties, missing(ties), na.rm = na.rm):
## collapsing to unique 'x' values

Is there a difference between majors across locations?

okc %>%
  #filter( !is.na(diet)) %>%
  mutate(location = fct_lump_prop(location,prop=0.01),
         location =  str_to_title(location)
  ) %>%
  group_by(location,Class)%>%
  count(location)%>%
  ungroup(Class)%>%
  mutate(prop=n/sum(n)) %>%
  ungroup() %>%
  ggplot(aes(x= location,y=prop,fill=Class))+
  geom_col(position = position_stack() )+
  geom_text(aes(label=  paste0(  round(prop*100,1),"%" ) ), position = position_stack(vjust = 0.5)  )+
  scale_y_continuous(labels = scales::percent_format())+
  labs(title="Major Difference Between Top 99% of Most Populated Locations",x="Location",y="Percent",fill="Class:")+
  theme_minimal()+
  theme(legend.position = "top",
        panel.grid.major.x = element_blank() )+
  coord_flip()+
  guides(fill = guide_legend(reverse=TRUE))

okc %>%
  #filter( !is.na(diet)) %>%
  mutate(location = fct_lump_prop(location,prop=0.05),
         location =  str_to_title(location)
  ) %>%
  group_by(location,Class)%>%
  count(location)%>%
  ungroup(Class)%>%
  mutate(prop=n/sum(n)) %>%
  ungroup() %>%
  ggplot(aes(x= location,y=prop,fill=Class))+
  geom_col(position = position_stack() )+
  geom_text(aes(label=  paste0(  round(prop*100,1),"%" ) ), position = position_stack(vjust = 0.5)  )+
  scale_y_continuous(labels = scales::percent_format())+
  labs(title="Major Difference Between Top 95% of Most Populated Locations",x="Location",y="Percent",fill="Class:")+
  theme_minimal()+
  theme(legend.position = "top",
        panel.grid.major.x = element_blank() )+
  coord_flip()+
  guides(fill = guide_legend(reverse=TRUE))

2 Tidy Models

2.1 Creating Model Data

Armed with the information gathered from the EDA process, we can create a cleaned dataset to model on.

okc_clean = okc %>%
  filter( height>25) %>%
  mutate(diet = str_remove(diet,"strictly"),
         diet = str_remove(diet,"mostly"),
         diet = str_remove(diet, " "),
         diet = str_to_title(diet),
         diet = factor(diet),
         location = fct_lump_prop(location,prop=0.05), #Lump locations with fewer than 5% of users into Other
         location =  str_to_title(location),
         location = factor(location))%>%
  select(-date)
skim(okc_clean)
Data summary
Name okc_clean
Number of rows 59847
Number of columns 5
_______________________
Column type frequency:
factor 3
numeric 2
________________________
Group variables None

Variable type: factor

skim_variable n_missing complete_rate ordered n_unique top_counts
diet 24355 0.59 FALSE 6 Any: 27841, Veg: 4976, Oth: 1785, Veg: 698
location 0 1.00 FALSE 4 San: 31061, Oth: 17364, Oak: 7213, Ber: 4209
Class 0 1.00 FALSE 2 oth: 50310, ste: 9537

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
age 0 1 32.34 9.46 18 26 30 37 110 ▇▂▁▁▁
height 0 1 68.30 3.94 26 66 68 71 95 ▁▁▆▇▁
set.seed(8888)

okc_split = initial_split(okc_clean,strata = Class)

okc_train = training(okc_split)
okc_test = testing(okc_split)

2.2 Create A Recipe

okc_rec = recipe(Class~.,data=okc_train) %>%
  themis::step_downsample(Class) %>%
  step_unknown(diet) %>%
  step_dummy(all_nominal(),-all_outcomes()) %>%
  step_zv(all_predictors())
## Registered S3 methods overwritten by 'themis':
##   method               from   
##   bake.step_downsample recipes
##   bake.step_upsample   recipes
##   prep.step_downsample recipes
##   prep.step_upsample   recipes
##   tidy.step_downsample recipes
##   tidy.step_upsample   recipes
tree_rec = recipe(Class~.,data=okc_train) %>%
  step_unknown(diet)

2.3 Set Model Spec

lgst_mod = logistic_reg() %>%
  set_engine("glm")

tree_mod = decision_tree(cost_complexity = tune(),
                          min_n = tune(),
                          tree_depth = tune()) %>%
  set_engine("rpart")%>%
  set_mode("classification")

2.4 Creating Workflows

lgst_wf = workflow() %>%
    add_model(lgst_mod) %>%
    add_recipe(okc_rec)

tree_wf = workflow() %>%
  add_model(tree_mod)%>%
  add_recipe(tree_rec)

2.4.1 Finding the Hyperparameters for the Tree Model

#doParallel::registerDoParallel()

start = Sys.time()

set.seed(8888)
tree_res = tree_wf %>%
  tune_grid(    resamples = vfold_cv(okc_train,strata = Class),
                grid =  grid_max_entropy(cost_complexity(), min_n(),tree_depth(),size=20 ),
                control = control_grid(save_pred = TRUE))

#doParallel:::stopImplicitCluster()

end = Sys.time()

paste("Time Taken:",end-start)
## [1] "Time Taken: 1.93076254924138"
tree_res %>%
  collect_metrics()
cost_complexity tree_depth min_n .metric .estimator mean n std_err .config
0.0002728 8 17 accuracy binary 0.8406407 10 0.0000232 Model01
0.0002728 8 17 roc_auc binary 0.5000000 10 0.0000000 Model01
0.0000001 1 24 accuracy binary 0.8406407 10 0.0000232 Model02
0.0000001 1 24 roc_auc binary 0.5000000 10 0.0000000 Model02
0.0000042 6 3 accuracy binary 0.8405962 10 0.0000492 Model03
0.0000042 6 3 roc_auc binary 0.6079312 10 0.0033329 Model03
0.0787769 12 20 accuracy binary 0.8406407 10 0.0000232 Model04
0.0787769 12 20 roc_auc binary 0.5000000 10 0.0000000 Model04
0.0000000 2 39 accuracy binary 0.8406407 10 0.0000232 Model05
0.0000000 2 39 roc_auc binary 0.5000000 10 0.0000000 Model05
0.0425398 3 20 accuracy binary 0.8406407 10 0.0000232 Model06
0.0425398 3 20 roc_auc binary 0.5000000 10 0.0000000 Model06
0.0000000 13 2 accuracy binary 0.8339795 10 0.0005051 Model07
0.0000000 13 2 roc_auc binary 0.6326370 10 0.0028302 Model07
0.0000028 11 4 accuracy binary 0.8379674 10 0.0001622 Model08
0.0000028 11 4 roc_auc binary 0.6370359 10 0.0037158 Model08
0.0001168 13 39 accuracy binary 0.8406407 10 0.0000232 Model09
0.0001168 13 39 roc_auc binary 0.5000000 10 0.0000000 Model09
0.0000150 8 34 accuracy binary 0.8406407 10 0.0000232 Model10
0.0000150 8 34 roc_auc binary 0.5000000 10 0.0000000 Model10
0.0000000 14 33 accuracy binary 0.8397497 10 0.0002963 Model11
0.0000000 14 33 roc_auc binary 0.6263418 10 0.0040858 Model11
0.0000000 13 16 accuracy binary 0.8381233 10 0.0001983 Model12
0.0000000 13 16 roc_auc binary 0.6349699 10 0.0032979 Model12
0.0344537 7 31 accuracy binary 0.8406407 10 0.0000232 Model13
0.0344537 7 31 roc_auc binary 0.5000000 10 0.0000000 Model13
0.0356874 11 37 accuracy binary 0.8406407 10 0.0000232 Model14
0.0356874 11 37 roc_auc binary 0.5000000 10 0.0000000 Model14
0.0000000 8 35 accuracy binary 0.8406407 10 0.0000232 Model15
0.0000000 8 35 roc_auc binary 0.5000000 10 0.0000000 Model15
0.0000001 8 16 accuracy binary 0.8404848 10 0.0000888 Model16
0.0000001 8 16 roc_auc binary 0.6133944 10 0.0067255 Model16
0.0185553 5 2 accuracy binary 0.8406407 10 0.0000232 Model17
0.0185553 5 2 roc_auc binary 0.5000000 10 0.0000000 Model17
0.0000337 15 15 accuracy binary 0.8380788 10 0.0002459 Model18
0.0000337 15 15 roc_auc binary 0.6258610 10 0.0048791 Model18
0.0000000 1 4 accuracy binary 0.8406407 10 0.0000232 Model19
0.0000000 1 4 roc_auc binary 0.5000000 10 0.0000000 Model19
0.0000019 12 26 accuracy binary 0.8397496 10 0.0001654 Model20
0.0000019 12 26 roc_auc binary 0.6187251 10 0.0138764 Model20
tree_res %>%
  collect_predictions() %>%
  group_by(id) %>%
  roc_curve(truth=Class,.pred_stem) %>%
  ggplot(aes(x=1-sensitivity,y=specificity,col=id))+
  geom_path(show.legend = FALSE, alpha = 0.6, size = 1.2) +
  geom_abline(lty = 2, color = "gray80", size = 1.5) +
  theme_minimal()+
  coord_equal()

We do not see great results.
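A quick way to see how the resampled metrics vary across the tuned hyperparameters is tune's built-in autoplot (a small sketch):

# Marginal effect of each hyperparameter on the resampled metrics.
autoplot(tree_res)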

tree_res %>%
  select_best(metric="roc_auc")
cost_complexity tree_depth min_n .config
2.8e-06 11 4 Model08

2.4.2 Estimate Logistic Regression

lgst_fit = fit(lgst_wf,data=okc_train)
coefs = lgst_fit %>%
  pull_workflow_fit() %>%
  tidy()

coefs
term estimate std.error statistic p.value
(Intercept) 7.1115601 0.3380275 21.0384057 0.0000000
age -0.0002797 0.0019638 -0.1424038 0.8867610
height -0.1024688 0.0046800 -21.8949886 0.0000000
diet_Halal 0.0515111 0.4466523 0.1153270 0.9081859
diet_Kosher 0.8325171 0.4272224 1.9486740 0.0513344
diet_Other 0.2191746 0.1094641 2.0022506 0.0452578
diet_Vegan 0.3649844 0.1702970 2.1432231 0.0320952
diet_Vegetarian 0.1336157 0.0652959 2.0463123 0.0407257
diet_unknown 0.1185543 0.0368114 3.2205865 0.0012793
location_Oakland 0.2397544 0.0868536 2.7604429 0.0057723
location_Other -0.0805340 0.0742326 -1.0848876 0.2779714
location_San.Francisco -0.2217720 0.0710128 -3.1229854 0.0017903
coefs %>%
  filter(term!="(Intercept)")%>%
  mutate(Significant = if_else(p.value < 0.05,"TRUE","FALSE")  ) %>%
ggplot(aes(y=term,x=estimate,col=Significant))+
  geom_point()+
  geom_errorbar(aes(xmax=estimate+1.96*std.error,
                xmin=estimate-1.96*std.error
  ))+
  geom_vline(xintercept=0,lty=4,col="grey50")+
  theme_minimal()+
  labs(title="Coefficients Estimates",subtitle="Log Odds Ratio",col="Signifcant\nat 5% level")

What this is telling us is that a one-inch increase in height changes the log odds of someone having a STEM major by about -0.10. The corresponding odds ratio is roughly 0.90, which means we expect about a 10% decrease in the odds of someone being STEM for every additional inch of height.
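We can double-check that figure directly from the coefs tibble created above:

# Convert the height coefficient to an odds ratio and a percent change in odds.
coefs %>%
  filter(term == "height") %>%
  mutate(odds_ratio = exp(estimate),
         pct_change_in_odds = 100 * (odds_ratio - 1)) %>%
  select(term, estimate, odds_ratio, pct_change_in_odds)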

We see that age is not a significant predictor of a person's major.

The model's baseline is someone from Berkeley with a diet of Anything. From a diet perspective, we see that being Vegetarian, Vegan, Unknown, or Other increases the odds of having a STEM major. We do not see evidence that Kosher or Halal differ from Anything.

Location-wise, being in Oakland rather than Berkeley corresponds to roughly a 27% increase in the odds of being STEM. Being in San Francisco corresponds to roughly a 20% decrease in the odds of having a STEM major, and being from somewhere else corresponds to about an 8% decrease, although that last effect is not statistically significant.

coefs %>%
  filter(term !="(Intercept)")%>%
  mutate(Odds_Ratio = exp(estimate)) %>%
  select(term,Odds_Ratio)
term Odds_Ratio
age 0.9997204
height 0.9026063
diet_Halal 1.0528608
diet_Kosher 2.2990986
diet_Other 1.2450486
diet_Vegan 1.4404915
diet_Vegetarian 1.1429536
diet_unknown 1.1258680
location_Oakland 1.2709370
location_Other 0.9226235
location_San.Francisco 0.8010980
coefs %>%
  filter(term != "(Intercept)") %>%
  mutate(Significant = if_else(p.value < 0.05, "TRUE", "FALSE")) %>%
  ggplot(aes(
    y = term,
    x = exp(estimate),
    col = Significant
  )) +
  geom_point() +
  geom_errorbar(aes(xmax = exp(estimate + 1.96 * std.error),
                xmin = exp(estimate - 1.96 * std.error))  
                )+
  geom_vline(xintercept = 1,
             lty = 4,
             col = "grey50") +
  theme_minimal() +
  labs(title = "Odds Ratio",col="Signifcant\nat 5% level")

2.4.3 Evaluating Logistic Regression Performance

We will use resampling to get an estimate on how the model performed on the training set.

set.seed(8888)
lgst_train_rs = lgst_wf %>%
  fit_resamples(
   resamples = vfold_cv(okc_train,strata = Class),
   metrics = metric_set(accuracy,roc_auc, sens, spec),
    control = control_resamples(save_pred = TRUE)
 )
lgst_train_rs %>%
  collect_metrics()
.metric .estimator mean n std_err
accuracy binary 0.5713578 10 0.0031953
roc_auc binary 0.6221339 10 0.0021738
sens binary 0.6260325 10 0.0048708
spec binary 0.5609937 10 0.0035646
lgst_train_rs %>%
  collect_predictions()%>%
  group_by(id)  %>%
  roc_curve(truth=Class,.pred_stem) %>%
  ggplot(aes(x=1-sensitivity,y=specificity,col=id))+
  geom_path(show.legend = FALSE, alpha = 0.6, size = 1.2) +
  geom_abline(lty = 2, color = "gray80", size = 1.5) +
  theme_minimal()+
  coord_equal()

Below we can see the ROC curves for both models; we do not get good results from the tree model.

lgst_train_rs %>%
  collect_predictions()%>%
  group_by(id)%>% 
  roc_curve(truth=Class,.pred_stem) %>%
    mutate(model="Logistic") %>%
  
  bind_rows(
    tree_res %>%
  collect_predictions()%>%
  group_by(id)%>% 
  roc_curve(truth=Class,.pred_stem) %>%
    mutate(model="Decision Tree")
) %>%
  ggplot(aes(x=1-sensitivity,y=specificity,col=model))+
  geom_path(alpha = 0.7, size = 1) +
  geom_abline(lty = 2, color = "gray80", size = 1.5) +
  theme_minimal()+
  coord_equal()

2.4.4 Using Best Tree

best_tree = tree_res %>% select_best("roc_auc")

final_tree_wf = 
  tree_wf %>%
  finalize_workflow(best_tree)

final_tree_wf
## == Workflow ================================================================================================================================================================
## Preprocessor: Recipe
## Model: decision_tree()
## 
## -- Preprocessor ------------------------------------------------------------------------------------------------------------------------------------------------------------
## 1 Recipe Step
## 
## * step_unknown()
## 
## -- Model -------------------------------------------------------------------------------------------------------------------------------------------------------------------
## Decision Tree Model Specification (classification)
## 
## Main Arguments:
##   cost_complexity = 2.81310986908408e-06
##   tree_depth = 11
##   min_n = 4
## 
## Computational engine: rpart
library(vip)

final_tree = final_tree_wf %>% fit(data=okc_train)

pull_final_tree= final_tree %>%
  pull_workflow_fit()

pull_final_tree %>%
  vip()

It appears from the plot that height is the most important variable for the decision tree, followed by age.
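The underlying importance scores can also be pulled out as a tibble rather than a plot (a small sketch using vip::vi() on the same fitted object):

# Variable importance scores as a tibble for closer inspection.
pull_final_tree %>%
  vip::vi()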

3 Performance on Test Set

Below we compute the roc_auc and accuracy of both models on the test set.

predict(lgst_fit,new_data = okc_test,type="prob") %>%
   bind_cols(  okc_test %>% select(Class)  ) %>%
   roc_auc(truth=Class, .pred_stem) %>%
  mutate(model="Logistic") %>%
  bind_rows(

predict(final_tree,new_data = okc_test,type="prob") %>%
   bind_cols(  okc_test %>% select(Class)  ) %>%
   roc_auc(truth=Class, .pred_stem)%>%
  mutate(model="Decision Tree")
) %>%
  select(model,.estimate,.metric) %>%
  
  bind_rows(


predict(lgst_fit,new_data = okc_test)   %>%
   bind_cols(  okc_test %>% select(Class)  ) %>%
   accuracy(truth=Class, .pred_class) %>%
  mutate(model="Logistic") %>%
  bind_rows(

predict(final_tree,new_data = okc_test)   %>%
   bind_cols(  okc_test %>% select(Class)  ) %>%
   accuracy(truth=Class, .pred_class) %>%
  mutate(model="Decision Tree")
) %>%
  select(model,.estimate,.metric)

) %>%
  pivot_wider(names_from = .metric,values_from=.estimate)
model roc_auc accuracy
Logistic 0.6343331 0.5860571
Decision Tree 0.6589016 0.8389145

In terms of roc_auc and accuracy, the decision tree appears superior to the logistic model.
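With this much class imbalance, accuracy alone can be misleading, so it is worth also checking sensitivity and specificity on the test set. A small sketch for the tree model (the same can be done for the logistic fit):

# Sensitivity and specificity on the test set; "stem" is the first factor level,
# so yardstick treats it as the event by default.
class_metrics = metric_set(sens, spec)

predict(final_tree, new_data = okc_test) %>%
  bind_cols(okc_test %>% select(Class)) %>%
  class_metrics(truth = Class, estimate = .pred_class)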

predict(lgst_fit,new_data = okc_test,type="prob") %>%
   bind_cols(  okc_test %>% select(Class)  ) %>%
   roc_curve(truth=Class, .pred_stem)%>%
  mutate(model="Logistic") %>%
  bind_rows(

predict(final_tree,new_data = okc_test,type="prob") %>%
   bind_cols(  okc_test %>% select(Class)  ) %>%
   roc_curve(truth=Class, .pred_stem)%>%
  mutate(model="Decision Tree")
) %>%
  ggplot(aes(x=1-specificity,y=sensitivity,col=model))+
  geom_path(size=2)+
  geom_abline(slope=1,col="grey50",size=1.5,lty=8)+
  theme_minimal()+
  coord_equal()

3.0.1 Confusion Matrices

predict(lgst_fit,new_data = okc_test)   %>%
   bind_cols(  okc_test %>% select(Class)  ) %>%
  conf_mat(truth=Class,.pred_class) %>%
    autoplot(type="heatmap")

predict(final_tree,new_data = okc_test)   %>%
   bind_cols(  okc_test %>% select(Class)  ) %>%
  conf_mat(truth=Class,.pred_class) %>%
  autoplot(type="heatmap")

However, upon further inspection, we see that the tree is mostly just predicting other. The baseline accuracy is:

paste0("Base Accuracy: ",round((12546+31)/nrow(okc_test) *100,1),"%")
## [1] "Base Accuracy: 84.1%"

So, in fact, we did slightly worse than always predicting the most common class in the training set, which would mean predicting that no user is a STEM major.
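The same conclusion can be reached without hard-coding counts, by tallying the tree's predictions and the class distribution of the test set directly (a small sketch):

# How often does the tree actually predict "stem" on the test set?
predict(final_tree, new_data = okc_test) %>%
  count(.pred_class)

# Baseline accuracy: the proportion of the majority class ("other") in the test set.
okc_test %>%
  count(Class) %>%
  mutate(prop = n / sum(n))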

4 Post-Hoc Analysis

summary(okc_train)
##       age                 diet           height              location    
##  Min.   : 18.00   Anything  :20839   Min.   :26.0   Berkeley     : 3156  
##  1st Qu.: 26.00   Halal     :   58   1st Qu.:66.0   Oakland      : 5391  
##  Median : 30.00   Kosher    :   90   Median :68.0   Other        :13040  
##  Mean   : 32.36   Other     : 1319   Mean   :68.3   San Francisco:23299  
##  3rd Qu.: 37.00   Vegan     :  532   3rd Qu.:71.0                        
##  Max.   :110.00   Vegetarian: 3745   Max.   :95.0                        
##                   NA's      :18303                                       
##    Class      
##  stem : 7153  
##  other:37733  
##               
##               
##               
##               
## 

We create a fake dataset to see what the models are predicting.

fake_data_2 = expand_grid(
  age = (18:70),
  height = (60:80),
  location = c("San Francisco", "Oakland", "Other", "Berkeley"),
  diet = c("Halal", "Kosher","Anything","Other",NA, "Vegan", "Vegetarian")
)
predict(lgst_fit,new_data=fake_data_2) %>%
  rename(Pred =  .pred_class) %>%
  bind_cols(fake_data_2) %>%
  ggplot(aes(x=age,y=height,fill=Pred))+
  geom_tile()+
  theme_bw()+
  facet_grid(diet~location)+
  labs(title="Logistic Model Predictions of Various Users\nWith Different Age, Height, Location and Diet")

predict(final_tree,new_data=fake_data_2) %>%
  rename(Pred =  .pred_class) %>%
  bind_cols(fake_data_2) %>%
  ggplot(aes(x=age,y=height,fill=Pred))+
  geom_tile()+
  theme_bw()+
  facet_grid(diet~location)+
  labs(title="Decision Tree Model Predictions of Various Users\nWith Different Age, Height, Location and Diet")

lgst_cmb_preds= predict(lgst_fit,new_data = okc_test) %>%
  bind_cols(okc_test) %>%
  mutate( Result =   case_when( .pred_class == "stem"  & Class=="stem"  ~ "TP",
                                 .pred_class == "other"  & Class=="other"  ~ "TN",
                                 .pred_class == "stem"  & Class=="other"  ~ "FP",
                                  .pred_class == "other"  & Class=="stem"  ~ "FN"
                                )) %>%
  rename(Pred = .pred_class)

ggplot()+
  geom_point( data= lgst_cmb_preds%>%filter(Result %in% c("FN","FP")) ,mapping=aes(x=age,y=height),col="grey",alpha=0.8)+
  geom_point( data= lgst_cmb_preds%>%filter(!Result %in% c("FN","FP")) ,mapping=aes(x=age,y=height,col=Result),alpha=0.8)+
  theme_bw()+
  facet_grid(diet~location)+
  scale_color_manual(values = c("royalblue4","indianred4"))+
  labs(title="Logistic Model with True Positives and True Negatives")

ggplot()+
  geom_point( data= lgst_cmb_preds%>%filter(!Result %in% c("FN","FP")) ,mapping=aes(x=age,y=height),col="grey",alpha=0.8)+
  geom_point( data= lgst_cmb_preds%>%filter(Result %in% c("FN","FP")) ,mapping=aes(x=age,y=height,col=Result),alpha=0.8)+
  theme_bw()+
  facet_grid(diet~location)+
  scale_color_manual(values = c("royalblue1","indianred1"))+
  labs(title="Logistic Model with False Positives and False Negatives")

We see that the Logistic Model classifies almost entirely on height, with a hard cutoff.
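As a rough back-of-the-envelope check, we can solve for the height at which the model's linear predictor crosses zero for a baseline user (Berkeley, diet of Anything); the age of 30 used below is an illustrative assumption.

# Height at which the predicted class flips for the baseline categories
# (all dummy terms are zero there) and an assumed age of 30.
b = coefs %>% select(term, estimate) %>% deframe()
cutoff_height = -(b["(Intercept)"] + 30 * b["age"]) / b["height"]
cutoff_height # in inches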

ggplot(okc_train,aes(x=age,y=height,col=Class))+
  geom_point(alpha=0.8)+
  theme_bw()+
  facet_grid(diet~location)+
  scale_color_manual(values = c("seagreen","mediumpurple3"))+
  labs(title="OKC Training Set")

ggplot(okc_test,aes(x=age,y=height,col=Class))+
  geom_point(alpha=0.8)+
  theme_bw()+
  facet_grid(diet~location)+
  scale_color_manual(values = c("seagreen","mediumpurple3"))+
  labs(title="OKC Test Set")

predict(lgst_fit,new_data=fake_data_2,type="prob") %>%
  #rename(Pred =  .pred_class) %>%
  bind_cols(fake_data_2) %>%
  ggplot(aes(x=age,y=height,fill=.pred_other))+
  geom_tile()+
  theme_bw()+
  facet_grid(diet~location)+
  scale_fill_gradientn(colours =  RColorBrewer::brewer.pal(n = 11, name = 'PRGn'))+
  #scale_fill_brewer(palette = RColorBrewer::brewer.pal(n = 10, name = 'PRGn'))+
  labs(title="Logistic Model Probabilities of Various Users\nWith Different Age, Height, Location and Diet")

predict(final_tree,new_data=fake_data_2,type="prob") %>%
  #rename(Pred =  .pred_class) %>%
  bind_cols(fake_data_2) %>%
  ggplot(aes(x=age,y=height,fill=.pred_stem))+
  geom_tile()+
  theme_bw()+
  facet_grid(diet~location)+
  scale_fill_gradientn(colours =  RColorBrewer::brewer.pal(n=9,name="Reds"))+
  labs(title="Tree Model Probabilities of Various Users Being STEM\nWith Different Age, Height, Location and Diet")

5 Conclusion

It appears that it is difficult to predict whether a user has a STEM major. Perhaps something like a Random Forest or KNN may yield better results, as they can model complex interactions. Either way, simple approaches do not give good results here.

We also have a large class imbalance; to remedy this, something like the SMOTE algorithm, which creates synthetic minority-class examples, may help.
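For example, a SMOTE step and a random forest could be slotted into the existing workflow structure as sketched below (assuming the themis and ranger packages are installed; untuned and not evaluated here):

# Sketch only: SMOTE needs numeric predictors with no missing values,
# so dummy-code after handling unknown diets, then oversample the minority class.
rf_rec = recipe(Class~.,data=okc_train) %>%
  step_unknown(diet) %>%
  step_dummy(all_nominal(),-all_outcomes()) %>%
  themis::step_smote(Class)

rf_mod = rand_forest(trees = 500) %>%
  set_engine("ranger") %>%
  set_mode("classification")

rf_wf = workflow() %>%
  add_model(rf_mod) %>%
  add_recipe(rf_rec)

# rf_fit = fit(rf_wf, data = okc_train)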

Another approach would be to combine the two models into an ensemble, and other model types could be added to improve predictions. The Logistic model was much better at finding true positives, while the Decision Tree was poor in this regard.
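One very simple version of this is to average the two fitted models' predicted probabilities (a sketch only; no claim that it improves matters here):

# Naive ensemble: average the predicted probability of "stem" from both models.
ens_pred = bind_cols(
  predict(lgst_fit, new_data = okc_test, type = "prob") %>% select(lgst = .pred_stem),
  predict(final_tree, new_data = okc_test, type = "prob") %>% select(tree = .pred_stem),
  okc_test %>% select(Class)
) %>%
  mutate(.pred_stem = (lgst + tree)/2)

ens_pred %>%
  roc_auc(truth = Class, .pred_stem)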

It could also be the case that we simply do not have enough good-quality data to predict such a rare outcome. With additional features, we would perhaps be in a better position to predict well.