We will investigate the OKCupid data and see whether we can predict if a user is a STEM major, based on the variables provided. First we will perform some EDA to explore the data, and then we will use the tidymodels framework to build machine learning models. If you wish to view the results, use the contents tab on the left to skip to the Performance on Test Set section.
library(modeldata)
library(tidyverse)
library(skimr)
library(tidymodels)
data("okc")
skim(okc)
| Name | okc |
|---|---|
| Number of rows | 59855 |
| Number of columns | 6 |
| _______________________ | |
| Column type frequency: | |
| character | 2 |
| Date | 1 |
| factor | 1 |
| numeric | 2 |
| ________________________ | |
| Group variables | None |
Variable type: character
| skim_variable | n_missing | complete_rate | min | max | empty | n_unique | whitespace |
|---|---|---|---|---|---|---|---|
| diet | 24360 | 0.59 | 5 | 19 | 0 | 18 | 0 |
| location | 0 | 1.00 | 4 | 19 | 0 | 135 | 0 |
Variable type: Date
| skim_variable | n_missing | complete_rate | min | max | median | n_unique |
|---|---|---|---|---|---|---|
| date | 0 | 1 | 2011-06-27 | 2012-07-01 | 2012-06-27 | 371 |
Variable type: factor
| skim_variable | n_missing | complete_rate | ordered | n_unique | top_counts |
|---|---|---|---|---|---|
| Class | 0 | 1 | FALSE | 2 | oth: 50316, ste: 9539 |
Variable type: numeric
| skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
|---|---|---|---|---|---|---|---|---|---|---|
| age | 0 | 1 | 32.34 | 9.46 | 18 | 26 | 30 | 37 | 110 | ▇▂▁▁▁ |
| height | 2 | 1 | 68.29 | 3.99 | 1 | 66 | 68 | 71 | 95 | ▁▁▁▇▁ |
okc %>%
ggplot(aes(x=age))+
geom_histogram(binwidth = 1,fill="pink")+
theme_minimal()+
labs(title="OKC Users Tends To Be Young, The Most common Age Being 26",subtitle="Age Distribution of OKC Users",x="Age",y="Count")+
geom_vline(xintercept = 26, col="red",lty=2)
okc %>%
filter(height > 30)%>%
mutate(height = height/0.39370 )%>%
ggplot(aes(x=height))+
geom_histogram(bins=30,fill="pink")+
theme_minimal()+
labs(title="The Average OKC User was 175cm tall",subtitle="Height Distribution of OKC Users",x="Height (cm)",y="Count")
okc %>%
mutate(diet = str_remove(diet,"strictly"),
diet = str_remove(diet,"mostly"),
diet = str_remove(diet, " "),
diet = str_to_title(diet),
)%>%
count(diet)%>%
arrange(desc(n))%>%
ggplot(aes(x=n,y=reorder(diet,n),fill=diet))+
geom_col(show.legend = F)+
geom_text(aes(label=n),hjust=-.1)+
theme_minimal()+
lims(x=c(0,31000))+
labs(title = "OKC Users Were Not Picky Eaters",subtitle="Total Count of Diets",x="Count",y="Diet")
okc %>%
mutate(location = fct_lump_n(location,n=10) )%>%
mutate (location = str_to_title(location))%>%
count(location) %>%
arrange(-n)%>%
ggplot(aes(x=n,y=reorder(location,n),fill=location))+
geom_col(show.legend = F)+
geom_text(aes(label=n),hjust=-.1)+
theme_minimal()+
lims(x=c(0,35000))+
labs(title = "Most OKC Users were from San Francisco",subtitle="Top 10 Most Common Locations",x="Count",y="Location")
okc %>%
count(Class)%>%
mutate(prop=n/sum(n)) %>%
ggplot(aes(x=Class,y=prop,fill=Class))+
geom_col(show.legend = FALSE)+
theme_minimal()+
scale_y_continuous(labels=scales::percent_format(),limits = c(0,1))+
geom_text(aes(label= paste(round(prop*100),"%"),vjust=-.1) )+
labs(title="STEM Students were in the Minority",subtitle="Percent of Users in STEM",x="Major",y="Percent")
Were the age demographics different for different locations?
okc %>%
mutate(location = fct_lump_n(location,n=10) )%>%
mutate (location = str_to_title(location)) %>%
ggplot(aes(x=reorder(location,age),y=age,fill=location))+
#geom_violin(show.legend = F,draw_quantiles = c(.25,.50,.75))+
geom_boxplot(show.legend = F)+
theme_minimal()+
labs(title="Age Distribution of 10 Most Populated Locations",x="Location",y="Age")
okc %>%
mutate(diet = str_remove(diet,"strictly"),
diet = str_remove(diet,"mostly"),
diet = str_remove(diet, " "),
diet = str_to_title(diet),
location = fct_lump_n(location,n=15),
location = str_to_title(location)
)%>%
group_by(location)%>%
count(diet)%>%
ungroup()%>%
ggplot(aes(x=n, y=tidytext::reorder_within(diet,n,location),fill=diet ))+
geom_col(show.legend = F)+
facet_wrap(.~location,scales = "free")+
tidytext::scale_y_reordered()+
#geom_text(aes(label=n,),hjust=.5)+
theme_minimal()+
labs(title="Most Common Diets by Top 16 Most Populated Areas",x=NULL,y=NULL)
okc %>%
filter(height > 30)%>%
na.omit(height)%>%
group_by(location)%>%
mutate( avg_height = mean(height),
avg_age = mean(age)
)%>%
ungroup()%>%
distinct(location,.keep_all = TRUE) %>%
ggplot(aes(x=avg_age,y=avg_height))+
geom_point(col="red")+
geom_text(aes(label=location))+
theme_minimal()+
labs(title="Average Age and Height of OKC Users by Location",x="Average Age",y="Average Height")
okc %>%
#filter( !is.na(diet)) %>%
mutate(diet = str_remove(diet,"strictly"),
diet = str_remove(diet,"mostly"),
diet = str_remove(diet, " "),
diet = str_to_title(diet),
location = fct_lump_n(location,n=15),
location = str_to_title(location)
) %>%
group_by(diet,Class)%>%
count(diet)%>%
ungroup(diet)%>%
mutate(prop=n/sum(n)) %>%
ggplot(aes(x=reorder(diet,-prop),y=prop,fill=Class))+
geom_col(position = position_dodge())+
geom_text(aes(label= paste0( round(prop*100,1),"%" ) ),position = position_dodge(width = 1),vjust=-.5)+
scale_y_continuous(labels = scales::percent_format())+
labs(title="Dietary Difference Between Major",subtitle = "Proportion of Diet Amongst Major",x="Diet",y="Percent",fill="Class:")+
theme_minimal()+
theme(legend.position = "top",
panel.grid.major.x = element_blank() )
ggplot(okc,aes(x=Class,y=age,fill=Class))+
geom_violin(draw_quantiles = c(.25,.5,.75),show.legend = F)+
theme_minimal()+
labs(title="Age Distribution By Major")
## Warning in regularize.values(x, y, ties, missing(ties), na.rm = na.rm):
## collapsing to unique 'x' values
okc %>%
filter(height > 25) %>%
ggplot(aes(x=Class,y=height,fill=Class))+
geom_violin(draw_quantiles = c(.25,.5,.75),show.legend = F)+
#geom_boxplot(show.legend = F)+
theme_minimal()+
labs(title="Height Distribution By Major",y="Height (Inches)")
## Warning in regularize.values(x, y, ties, missing(ties), na.rm = na.rm):
## collapsing to unique 'x' values
okc %>%
filter(height > 25) %>%
mutate(diet = str_remove(diet,"strictly"),
diet = str_remove(diet,"mostly"),
diet = str_remove(diet, " "),
diet = str_to_title(diet),
)%>%
ggplot(aes(x=reorder(diet,height),y=height,fill=diet))+
geom_violin( draw_quantiles = c(.25,.5,.75),show.legend = FALSE)+
theme_minimal()+
labs(title="Height Difference Across Diets",x="Diet",y="Height (Inches)")
## Warning in regularize.values(x, y, ties, missing(ties), na.rm = na.rm):
## collapsing to unique 'x' values
okc %>%
na.omit() %>%
filter(height > 25) %>%
mutate(diet = str_remove(diet,"strictly"),
diet = str_remove(diet,"mostly"),
diet = str_remove(diet, " "),
diet = str_to_title(diet),
)%>%
ggplot(aes(x=diet,y=height,fill=interaction(diet,Class) ) ) + #Use fill=interaction(Class,diet) for better colouring - doesn't work here due to NAs
geom_violin( draw_quantiles = c(.25,.5,.75))+
theme_minimal()+
labs(title="Height Difference Majors Across Diets",subtitle ="Left is STEM, Right is Others" ,x="Diet",y="Height (Inches)")+
scale_fill_manual(values=c("darkslategray4","firebrick4","palegreen4","slateblue4","royalblue3","darkgoldenrod4",
"darkslategray1","firebrick2","palegreen1","slateblue2","royalblue1","darkgoldenrod1"))+
theme(legend.position = "right")+
guides(fill=guide_legend(ncol=2))
## Warning in regularize.values(x, y, ties, missing(ties), na.rm = na.rm):
## collapsing to unique 'x' values
okc %>%
na.omit() %>%
filter(height > 25) %>%
mutate(diet = str_remove(diet,"strictly"),
diet = str_remove(diet,"mostly"),
diet = str_remove(diet, " "),
diet = str_to_title(diet),
)%>%
ggplot(aes(x= tidytext::reorder_within(diet,height,Class)
,y=height,fill=diet)) + #Use fill=interaction(Class,diet) for better colouring - doesn't work here due to NAs
#geom_violin( draw_quantiles = c(.25,.5,.75))+
geom_boxplot()+
facet_wrap(.~Class,nrow=2,scales = "free_x")+
theme_bw()+
labs(title="STEM Majors Who Followed A Kosher Diet Were Tallest For That Group,\nWhile Vegans Were The Shortest",x="Diet",y="Height (Inches)")+
scale_fill_manual(values=c("darkslategray4","firebrick4","palegreen4","slateblue4","royalblue3","darkgoldenrod4"))+
theme(legend.position = "none")+
tidytext::scale_x_reordered()
Next we check which combination of diet and Class has the tallest individuals on average.
okc %>%
#na.omit() %>%
filter(height > 25) %>%
mutate(diet = str_remove(diet,"strictly"),
diet = str_remove(diet,"mostly"),
diet = str_remove(diet, " "),
diet = str_to_title(diet),
) %>%
group_by(Class,diet)%>%
summarise(avg_height = mean(height),.groups="keep") %>%
mutate( avg_height = avg_height/0.39370) %>%
arrange(-avg_height) %>%
ungroup() %>%
pivot_wider(names_from = Class,values_from=avg_height)%>%
mutate( avg = (stem+other)/2 )%>%
arrange(-avg)
| diet | stem | other | avg |
|---|---|---|---|
| Halal | 176.8234 | 176.7685 | 176.7960 |
| Kosher | 179.7054 | 172.9669 | 176.3362 |
| Anything | 176.8687 | 173.3305 | 175.0996 |
| NA | 176.5763 | 172.7735 | 174.6749 |
| Other | 175.9794 | 173.1469 | 174.5632 |
| Vegetarian | 176.0799 | 170.9882 | 173.5341 |
| Vegan | 173.6186 | 171.7596 | 172.6891 |
okc %>%
na.omit() %>%
filter(height > 25) %>%
mutate(diet = str_remove(diet,"strictly"),
diet = str_remove(diet,"mostly"),
diet = str_remove(diet, " "),
diet = str_to_title(diet),
)%>%
ggplot(aes(x=diet,y=age,fill=interaction(diet,Class) ) ) + #Use fill=interaction(Class,diet) for better colouring - doesn't work here due to NAs
geom_violin( draw_quantiles = c(.25,.5,.75))+
theme_minimal()+
labs(title="Age by Diet and Class",x="Diet",y="Age")+
scale_fill_manual(values=c("darkslategray4","firebrick4","palegreen4","slateblue4","royalblue3","darkgoldenrod4",
"darkslategray1","firebrick2","palegreen1","slateblue2","royalblue1","darkgoldenrod1"))+
theme(legend.position = "right")+
guides(fill=guide_legend(ncol=2))
## Warning in regularize.values(x, y, ties, missing(ties), na.rm = na.rm):
## collapsing to unique 'x' values
Is there a difference between majors in different locations?
okc %>%
#filter( !is.na(diet)) %>%
mutate(location = fct_lump_prop(location,prop=0.01),
location = str_to_title(location)
) %>%
group_by(location,Class)%>%
count(location)%>%
ungroup(Class)%>%
mutate(prop=n/sum(n)) %>%
ungroup() %>%
ggplot(aes(x= location,y=prop,fill=Class))+
geom_col(position = position_stack() )+
geom_text(aes(label= paste0( round(prop*100,1),"%" ) ), position = position_stack(vjust = 0.5) )+
scale_y_continuous(labels = scales::percent_format())+
labs(title="Major Difference Between Top 99% of Most Populated Locations",x="Location",y="Percent",fill="Class:")+
theme_minimal()+
theme(legend.position = "top",
panel.grid.major.x = element_blank() )+
coord_flip()+
guides(fill = guide_legend(reverse=TRUE))
okc %>%
#filter( !is.na(diet)) %>%
mutate(location = fct_lump_prop(location,prop=0.05),
location = str_to_title(location)
) %>%
group_by(location,Class)%>%
count(location)%>%
ungroup(Class)%>%
mutate(prop=n/sum(n)) %>%
ungroup() %>%
ggplot(aes(x= location,y=prop,fill=Class))+
geom_col(position = position_stack() )+
geom_text(aes(label= paste0( round(prop*100,1),"%" ) ), position = position_stack(vjust = 0.5) )+
scale_y_continuous(labels = scales::percent_format())+
labs(title="Major Difference Between Top 95% of Most Populated Locations",x="Location",y="Percent",fill="Class:")+
theme_minimal()+
theme(legend.position = "top",
panel.grid.major.x = element_blank() )+
coord_flip()+
guides(fill = guide_legend(reverse=TRUE))
Armed with the information gathered from the EDA process, we can create a cleaned dataset to model on.
okc_clean = okc %>%
filter( height>25) %>%
mutate(diet = str_remove(diet,"strictly"),
diet = str_remove(diet,"mostly"),
diet = str_remove(diet, " "),
diet = str_to_title(diet),
diet = factor(diet),
location = fct_lump_prop(location,prop=0.05), #Lump locations with fewer than 5% of users into Other
location = str_to_title(location),
location = factor(location))%>%
select(-date)
skim(okc_clean)
| Name | okc_clean |
|---|---|
| Number of rows | 59847 |
| Number of columns | 5 |
| _______________________ | |
| Column type frequency: | |
| factor | 3 |
| numeric | 2 |
| ________________________ | |
| Group variables | None |
Variable type: factor
| skim_variable | n_missing | complete_rate | ordered | n_unique | top_counts |
|---|---|---|---|---|---|
| diet | 24355 | 0.59 | FALSE | 6 | Any: 27841, Veg: 4976, Oth: 1785, Veg: 698 |
| location | 0 | 1.00 | FALSE | 4 | San: 31061, Oth: 17364, Oak: 7213, Ber: 4209 |
| Class | 0 | 1.00 | FALSE | 2 | oth: 50310, ste: 9537 |
Variable type: numeric
| skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
|---|---|---|---|---|---|---|---|---|---|---|
| age | 0 | 1 | 32.34 | 9.46 | 18 | 26 | 30 | 37 | 110 | ▇▂▁▁▁ |
| height | 0 | 1 | 68.30 | 3.94 | 26 | 66 | 68 | 71 | 95 | ▁▁▆▇▁ |
set.seed(8888)
okc_split = initial_split(okc_clean,strata = Class)
okc_train = training(okc_split)
okc_test = testing(okc_split)
okc_rec = recipe(Class~.,data=okc_train) %>%
themis::step_downsample(Class) %>%
step_unknown(diet) %>%
step_dummy(all_nominal(),-all_outcomes()) %>%
step_zv(all_predictors())
## Registered S3 methods overwritten by 'themis':
## method from
## bake.step_downsample recipes
## bake.step_upsample recipes
## prep.step_downsample recipes
## prep.step_upsample recipes
## tidy.step_downsample recipes
## tidy.step_upsample recipes
tree_rec = recipe(Class~.,data=okc_train) %>%
step_unknown(diet)
lgst_mod = logistic_reg() %>%
set_engine("glm")
tree_mod = decision_tree(cost_complexity = tune(),
min_n = tune(),
tree_depth = tune()) %>%
set_engine("rpart")%>%
set_mode("classification")
lgst_wf = workflow() %>%
add_model(lgst_mod) %>%
add_recipe(okc_rec)
tree_wf = workflow() %>%
add_model(tree_mod)%>%
add_recipe(tree_rec)
#doParallel::registerDoParallel()
start = Sys.time()
set.seed(8888)
tree_res = tree_wf %>%
tune_grid( resamples = vfold_cv(okc_train,strata = Class),
grid = grid_max_entropy(cost_complexity(), min_n(),tree_depth(),size=20 ),
control = control_grid(save_pred = TRUE))
#doParallel:::stopImplicitCluster()
end = Sys.time()
paste("Time Taken:",end-start)
## [1] "Time Taken: 1.93076254924138"
tree_res %>%
collect_metrics()
| cost_complexity | tree_depth | min_n | .metric | .estimator | mean | n | std_err | .config |
|---|---|---|---|---|---|---|---|---|
| 0.0002728 | 8 | 17 | accuracy | binary | 0.8406407 | 10 | 0.0000232 | Model01 |
| 0.0002728 | 8 | 17 | roc_auc | binary | 0.5000000 | 10 | 0.0000000 | Model01 |
| 0.0000001 | 1 | 24 | accuracy | binary | 0.8406407 | 10 | 0.0000232 | Model02 |
| 0.0000001 | 1 | 24 | roc_auc | binary | 0.5000000 | 10 | 0.0000000 | Model02 |
| 0.0000042 | 6 | 3 | accuracy | binary | 0.8405962 | 10 | 0.0000492 | Model03 |
| 0.0000042 | 6 | 3 | roc_auc | binary | 0.6079312 | 10 | 0.0033329 | Model03 |
| 0.0787769 | 12 | 20 | accuracy | binary | 0.8406407 | 10 | 0.0000232 | Model04 |
| 0.0787769 | 12 | 20 | roc_auc | binary | 0.5000000 | 10 | 0.0000000 | Model04 |
| 0.0000000 | 2 | 39 | accuracy | binary | 0.8406407 | 10 | 0.0000232 | Model05 |
| 0.0000000 | 2 | 39 | roc_auc | binary | 0.5000000 | 10 | 0.0000000 | Model05 |
| 0.0425398 | 3 | 20 | accuracy | binary | 0.8406407 | 10 | 0.0000232 | Model06 |
| 0.0425398 | 3 | 20 | roc_auc | binary | 0.5000000 | 10 | 0.0000000 | Model06 |
| 0.0000000 | 13 | 2 | accuracy | binary | 0.8339795 | 10 | 0.0005051 | Model07 |
| 0.0000000 | 13 | 2 | roc_auc | binary | 0.6326370 | 10 | 0.0028302 | Model07 |
| 0.0000028 | 11 | 4 | accuracy | binary | 0.8379674 | 10 | 0.0001622 | Model08 |
| 0.0000028 | 11 | 4 | roc_auc | binary | 0.6370359 | 10 | 0.0037158 | Model08 |
| 0.0001168 | 13 | 39 | accuracy | binary | 0.8406407 | 10 | 0.0000232 | Model09 |
| 0.0001168 | 13 | 39 | roc_auc | binary | 0.5000000 | 10 | 0.0000000 | Model09 |
| 0.0000150 | 8 | 34 | accuracy | binary | 0.8406407 | 10 | 0.0000232 | Model10 |
| 0.0000150 | 8 | 34 | roc_auc | binary | 0.5000000 | 10 | 0.0000000 | Model10 |
| 0.0000000 | 14 | 33 | accuracy | binary | 0.8397497 | 10 | 0.0002963 | Model11 |
| 0.0000000 | 14 | 33 | roc_auc | binary | 0.6263418 | 10 | 0.0040858 | Model11 |
| 0.0000000 | 13 | 16 | accuracy | binary | 0.8381233 | 10 | 0.0001983 | Model12 |
| 0.0000000 | 13 | 16 | roc_auc | binary | 0.6349699 | 10 | 0.0032979 | Model12 |
| 0.0344537 | 7 | 31 | accuracy | binary | 0.8406407 | 10 | 0.0000232 | Model13 |
| 0.0344537 | 7 | 31 | roc_auc | binary | 0.5000000 | 10 | 0.0000000 | Model13 |
| 0.0356874 | 11 | 37 | accuracy | binary | 0.8406407 | 10 | 0.0000232 | Model14 |
| 0.0356874 | 11 | 37 | roc_auc | binary | 0.5000000 | 10 | 0.0000000 | Model14 |
| 0.0000000 | 8 | 35 | accuracy | binary | 0.8406407 | 10 | 0.0000232 | Model15 |
| 0.0000000 | 8 | 35 | roc_auc | binary | 0.5000000 | 10 | 0.0000000 | Model15 |
| 0.0000001 | 8 | 16 | accuracy | binary | 0.8404848 | 10 | 0.0000888 | Model16 |
| 0.0000001 | 8 | 16 | roc_auc | binary | 0.6133944 | 10 | 0.0067255 | Model16 |
| 0.0185553 | 5 | 2 | accuracy | binary | 0.8406407 | 10 | 0.0000232 | Model17 |
| 0.0185553 | 5 | 2 | roc_auc | binary | 0.5000000 | 10 | 0.0000000 | Model17 |
| 0.0000337 | 15 | 15 | accuracy | binary | 0.8380788 | 10 | 0.0002459 | Model18 |
| 0.0000337 | 15 | 15 | roc_auc | binary | 0.6258610 | 10 | 0.0048791 | Model18 |
| 0.0000000 | 1 | 4 | accuracy | binary | 0.8406407 | 10 | 0.0000232 | Model19 |
| 0.0000000 | 1 | 4 | roc_auc | binary | 0.5000000 | 10 | 0.0000000 | Model19 |
| 0.0000019 | 12 | 26 | accuracy | binary | 0.8397496 | 10 | 0.0001654 | Model20 |
| 0.0000019 | 12 | 26 | roc_auc | binary | 0.6187251 | 10 | 0.0138764 | Model20 |
tree_res %>%
collect_predictions() %>%
group_by(id) %>%
roc_curve(truth=Class,.pred_stem) %>%
ggplot(aes(x=1-specificity,y=sensitivity,col=id))+
geom_path(show.legend = FALSE, alpha = 0.6, size = 1.2) +
geom_abline(lty = 2, color = "gray80", size = 1.5) +
theme_minimal()+
coord_equal()
We do not see great results; many of the parameter combinations give an ROC AUC of essentially 0.5.
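To rank the tuning results directly, we can also use show_best() from tune; a quick check along these lines (a sketch using the same tuning object as above):
#Top parameter combinations ranked by resampled ROC AUC
tree_res %>%
  show_best(metric = "roc_auc", n = 5)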
tree_res %>%
select_best(metric="roc_auc")
| cost_complexity | tree_depth | min_n | .config |
|---|---|---|---|
| 2.8e-06 | 11 | 4 | Model08 |
lgst_fit = fit(lgst_wf,data=okc_train)
coefs = lgst_fit %>%
pull_workflow_fit() %>%
tidy()
coefs
| term | estimate | std.error | statistic | p.value |
|---|---|---|---|---|
| (Intercept) | 7.1115601 | 0.3380275 | 21.0384057 | 0.0000000 |
| age | -0.0002797 | 0.0019638 | -0.1424038 | 0.8867610 |
| height | -0.1024688 | 0.0046800 | -21.8949886 | 0.0000000 |
| diet_Halal | 0.0515111 | 0.4466523 | 0.1153270 | 0.9081859 |
| diet_Kosher | 0.8325171 | 0.4272224 | 1.9486740 | 0.0513344 |
| diet_Other | 0.2191746 | 0.1094641 | 2.0022506 | 0.0452578 |
| diet_Vegan | 0.3649844 | 0.1702970 | 2.1432231 | 0.0320952 |
| diet_Vegetarian | 0.1336157 | 0.0652959 | 2.0463123 | 0.0407257 |
| diet_unknown | 0.1185543 | 0.0368114 | 3.2205865 | 0.0012793 |
| location_Oakland | 0.2397544 | 0.0868536 | 2.7604429 | 0.0057723 |
| location_Other | -0.0805340 | 0.0742326 | -1.0848876 | 0.2779714 |
| location_San.Francisco | -0.2217720 | 0.0710128 | -3.1229854 | 0.0017903 |
coefs %>%
filter(term!="(Intercept)")%>%
mutate(Significant = if_else(p.value < 0.05,"TRUE","FALSE") ) %>%
ggplot(aes(y=term,x=estimate,col=Significant))+
geom_point()+
geom_errorbar(aes(xmax=estimate+1.96*std.error,
xmin=estimate-1.96*std.error
))+
geom_vline(xintercept=0,lty=4,col="grey50")+
theme_minimal()+
labs(title="Coefficients Estimates",subtitle="Log Odds Ratio",col="Signifcant\nat 5% level")
Note that with the glm engine, the model estimates the log-odds of the second level of the outcome factor; since the levels of Class are stem then other, the coefficients describe the odds of being other, not stem. A one-inch increase in height therefore decreases the log-odds of being a non-STEM user by about 0.10. The corresponding odds ratio is roughly 0.90, meaning the odds of being other fall by about 10% per inch, or equivalently the odds of being STEM rise by about 11% per inch. This is consistent with the earlier summaries showing that STEM users tend to be taller.
We see that age is not a significant predictor of a person's major.
The baseline of the model is someone from Berkeley with a diet of Anything. From a diet perspective, the positive coefficients for the Vegetarian, Vegan, unknown and Other diets mean these users have higher odds of being other (and so lower odds of being STEM) than the Anything baseline. We do not see evidence that Kosher or Halal differ from Anything.
Location-wise, being in Oakland rather than Berkeley increases the odds of being other by about 27% (lower odds of STEM), while being in San Francisco decreases them by about 20% (higher odds of STEM). Users from the lumped Other locations show no significant difference from Berkeley.
coefs %>%
filter(term !="(Intercept)")%>%
mutate(Odds_Ratio = exp(estimate)) %>%
select(term,Odds_Ratio)
| term | Odds_Ratio |
|---|---|
| age | 0.9997204 |
| height | 0.9026063 |
| diet_Halal | 1.0528608 |
| diet_Kosher | 2.2990986 |
| diet_Other | 1.2450486 |
| diet_Vegan | 1.4404915 |
| diet_Vegetarian | 1.1429536 |
| diet_unknown | 1.1258680 |
| location_Oakland | 1.2709370 |
| location_Other | 0.9226235 |
| location_San.Francisco | 0.8010980 |
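Since these estimates are on the scale of the odds of being other, negating the coefficients before exponentiating reads them in the STEM direction. A minimal sketch, reusing the coefs table from above:
#exp(estimate) is the odds ratio for "other"; exp(-estimate) for "stem"
coefs %>%
  filter(term != "(Intercept)") %>%
  mutate(or_other = exp(estimate),
         or_stem = exp(-estimate),
         pct_change_stem = (or_stem - 1) * 100)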
coefs %>%
filter(term != "(Intercept)") %>%
mutate(Significant = if_else(p.value < 0.05, "TRUE", "FALSE")) %>%
ggplot(aes(
y = term,
x = exp(estimate),
col = Significant
)) +
geom_point() +
geom_errorbar(aes(xmax = exp(estimate + 1.96 * std.error),
xmin = exp(estimate - 1.96 * std.error))
)+
geom_vline(xintercept = 1,
lty = 4,
col = "grey50") +
theme_minimal() +
labs(title = "Odds Ratio",col="Signifcant\nat 5% level")
We will use resampling (10-fold cross-validation) to estimate how the logistic model performs on the training set.
set.seed(8888)
lgst_train_rs = lgst_wf %>%
fit_resamples(
resamples = vfold_cv(okc_train,strata = Class),
metrics = metric_set(accuracy,roc_auc, sens, spec),
control = control_resamples(save_pred = TRUE)
)
lgst_train_rs %>%
collect_metrics()
| .metric | .estimator | mean | n | std_err |
|---|---|---|---|---|
| accuracy | binary | 0.5713578 | 10 | 0.0031953 |
| roc_auc | binary | 0.6221339 | 10 | 0.0021738 |
| sens | binary | 0.6260325 | 10 | 0.0048708 |
| spec | binary | 0.5609937 | 10 | 0.0035646 |
lgst_train_rs %>%
collect_predictions()%>%
group_by(id) %>%
roc_curve(truth=Class,.pred_stem) %>%
ggplot(aes(x=1-specificity,y=sensitivity,col=id))+
geom_path(show.legend = FALSE, alpha = 0.6, size = 1.2) +
geom_abline(lty = 2, color = "gray80", size = 1.5) +
theme_minimal()+
coord_equal()
Below we can see the ROC curves for both models; neither gives good results.
lgst_train_rs %>%
collect_predictions()%>%
group_by(id)%>%
roc_curve(truth=Class,.pred_stem) %>%
mutate(model="Logistic") %>%
bind_rows(
tree_res %>%
collect_predictions()%>%
group_by(id)%>%
roc_curve(truth=Class,.pred_stem) %>%
mutate(model="Decision Tree")
) %>%
ggplot(aes(x=1-specificity,y=sensitivity,col=model))+
geom_path(alpha = 0.7, size = 1) +
geom_abline(lty = 2, color = "gray80", size = 1.5) +
theme_minimal()+
coord_equal()
best_tree = tree_res %>% select_best("roc_auc")
final_tree_wf =
tree_wf %>%
finalize_workflow(best_tree)
final_tree_wf
## == Workflow ================================================================================================================================================================
## Preprocessor: Recipe
## Model: decision_tree()
##
## -- Preprocessor ------------------------------------------------------------------------------------------------------------------------------------------------------------
## 1 Recipe Step
##
## * step_unknown()
##
## -- Model -------------------------------------------------------------------------------------------------------------------------------------------------------------------
## Decision Tree Model Specification (classification)
##
## Main Arguments:
## cost_complexity = 2.81310986908408e-06
## tree_depth = 11
## min_n = 4
##
## Computational engine: rpart
library(vip)
final_tree = final_tree_wf %>% fit(data=okc_train)
pull_final_tree= final_tree %>%
pull_workflow_fit()
pull_final_tree %>%
vip()
It appears from the plot that height is the most important variable for the decision tree.
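If we want the numbers behind the bar chart, vip's vi() should return the same importance scores (a quick check, assuming the same parsnip support that vip() uses above):
#Numeric variable importance scores behind the vip() plot
pull_final_tree %>%
  vip::vi()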
Below we compute the roc_auc and accuracy of both models on the test set.
predict(lgst_fit,new_data = okc_test,type="prob") %>%
bind_cols( okc_test %>% select(Class) ) %>%
roc_auc(truth=Class, .pred_stem) %>%
mutate(model="Logistic") %>%
bind_rows(
predict(final_tree,new_data = okc_test,type="prob") %>%
bind_cols( okc_test %>% select(Class) ) %>%
roc_auc(truth=Class, .pred_stem)%>%
mutate(model="Decision Tree")
) %>%
select(model,.estimate,.metric) %>%
bind_rows(
predict(lgst_fit,new_data = okc_test) %>%
bind_cols( okc_test %>% select(Class) ) %>%
accuracy(truth=Class, .pred_class) %>%
mutate(model="Logistic") %>%
bind_rows(
predict(final_tree,new_data = okc_test) %>%
bind_cols( okc_test %>% select(Class) ) %>%
accuracy(truth=Class, .pred_class) %>%
mutate(model="Decision Tree")
) %>%
select(model,.estimate,.metric)
) %>%
pivot_wider(names_from = .metric,values_from=.estimate)
| model | roc_auc | accuracy |
|---|---|---|
| Logistic | 0.6343331 | 0.5860571 |
| Decision Tree | 0.6589016 | 0.8389145 |
In terms of roc_auc and accuracy, the decision tree appears superior to the logistic model.
predict(lgst_fit,new_data = okc_test,type="prob") %>%
bind_cols( okc_test %>% select(Class) ) %>%
roc_curve(truth=Class, .pred_stem)%>%
mutate(model="Logistic") %>%
bind_rows(
predict(final_tree,new_data = okc_test,type="prob") %>%
bind_cols( okc_test %>% select(Class) ) %>%
roc_curve(truth=Class, .pred_stem)%>%
mutate(model="Decision Tree")
) %>%
ggplot(aes(x=1-specificity,y=sensitivity,col=model))+
geom_path(size=2)+
geom_abline(slope=1,col="grey50",size=1.5,lty=8)+
theme_minimal()+
coord_equal()
predict(lgst_fit,new_data = okc_test) %>%
bind_cols( okc_test %>% select(Class) ) %>%
conf_mat(truth=Class,.pred_class) %>%
autoplot(type="heatmap")
predict(final_tree,new_data = okc_test) %>%
bind_cols( okc_test %>% select(Class) ) %>%
conf_mat(truth=Class,.pred_class) %>%
autoplot(type="heatmap")
However, upon further inspection, we see that the tree is almost always predicting other. The baseline accuracy, from always predicting other, is:
paste0("Base Accuracy: ",round((12546+31)/nrow(okc_test) *100,1),"%")
## [1] "Base Accuracy: 84.1%"
So, in fact, we did worse than always predicting the most common class in the training set, i.e. predicting that someone is not a STEM major.
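The same baseline can also be computed directly from the class distribution of the test set, rather than from the hard-coded confusion-matrix counts:
#Accuracy from always predicting the majority class ("other") in the test set
okc_test %>%
  count(Class) %>%
  mutate(prop = n / sum(n))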
summary(okc_train)
## age diet height location
## Min. : 18.00 Anything :20839 Min. :26.0 Berkeley : 3156
## 1st Qu.: 26.00 Halal : 58 1st Qu.:66.0 Oakland : 5391
## Median : 30.00 Kosher : 90 Median :68.0 Other :13040
## Mean : 32.36 Other : 1319 Mean :68.3 San Francisco:23299
## 3rd Qu.: 37.00 Vegan : 532 3rd Qu.:71.0
## Max. :110.00 Vegetarian: 3745 Max. :95.0
## NA's :18303
## Class
## stem : 7153
## other:37733
##
##
##
##
##
We create a fake dataset to see what the models are predicting.
fake_data_2 = expand_grid(
age = (18:70),
height = (60:80),
location = c("San Francisco", "Oakland", "Other", "Berkeley"),
diet = c("Halal", "Kosher","Anything","Other",NA, "Vegan", "Vegetarian")
)
predict(lgst_fit,new_data=fake_data_2) %>%
rename(Pred = .pred_class) %>%
bind_cols(fake_data_2) %>%
ggplot(aes(x=age,y=height,fill=Pred))+
geom_tile()+
theme_bw()+
facet_grid(diet~location)+
labs(title="Logistic Model Predictions of Various Users\nWith Different Age, Height, Location and Diet")
predict(final_tree,new_data=fake_data_2) %>%
rename(Pred = .pred_class) %>%
bind_cols(fake_data_2) %>%
ggplot(aes(x=age,y=height,fill=Pred))+
geom_tile()+
theme_bw()+
facet_grid(diet~location)+
labs(title="Decision Tree Model Predictions of Various Users\nWith Different Age, Height, Location and Diet")
lgst_cmb_preds= predict(lgst_fit,new_data = okc_test) %>%
bind_cols(okc_test) %>%
mutate( Result = case_when( .pred_class == "stem" & Class=="stem" ~ "TP",
.pred_class == "other" & Class=="other" ~ "TN",
.pred_class == "stem" & Class=="other" ~ "FP",
.pred_class == "other" & Class=="stem" ~ "FN"
)) %>%
rename(Pred = .pred_class)
ggplot()+
geom_point( data= lgst_cmb_preds%>%filter(Result %in% c("FN","FP")) ,mapping=aes(x=age,y=height),col="grey",alpha=0.8)+
geom_point( data= lgst_cmb_preds%>%filter(!Result %in% c("FN","FP")) ,mapping=aes(x=age,y=height,col=Result),alpha=0.8)+
theme_bw()+
facet_grid(diet~location)+
scale_color_manual(values = c("royalblue4","indianred4"))+
labs(title="Logistic Model with True Positives and True Negatives")
ggplot()+
geom_point( data= lgst_cmb_preds%>%filter(!Result %in% c("FN","FP")) ,mapping=aes(x=age,y=height),col="grey",alpha=0.8)+
geom_point( data= lgst_cmb_preds%>%filter(Result %in% c("FN","FP")) ,mapping=aes(x=age,y=height,col=Result),alpha=0.8)+
theme_bw()+
facet_grid(diet~location)+
scale_color_manual(values = c("royalblue1","indianred1"))+
labs(title="Logistic Model with False Positives and False Negatives")
We see that the logistic model classifies almost entirely on height, with what amounts to a hard cutoff.
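A rough back-of-the-envelope check of where that cutoff sits, using the fitted coefficients above and assuming the glm convention noted earlier (the linear predictor is the log-odds of other), at the baseline levels (Berkeley, diet Anything) and ignoring the negligible age term:
#Height at which the baseline linear predictor crosses zero;
#users taller than this are predicted "stem", shorter are predicted "other"
b <- coefs %>%
  filter(term %in% c("(Intercept)", "height")) %>%
  pull(estimate)
-b[1] / b[2]  #roughly 69-70 inches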
ggplot(okc_train,aes(x=age,y=height,col=Class))+
geom_point(alpha=0.8)+
theme_bw()+
facet_grid(diet~location)+
scale_color_manual(values = c("seagreen","mediumpurple3"))+
labs(title="OKC Training Set")
ggplot(okc_test,aes(x=age,y=height,col=Class))+
geom_point(alpha=0.8)+
theme_bw()+
facet_grid(diet~location)+
scale_color_manual(values = c("seagreen","mediumpurple3"))+
labs(title="OKC Test Set")
predict(lgst_fit,new_data=fake_data_2,type="prob") %>%
#rename(Pred = .pred_class) %>%
bind_cols(fake_data_2) %>%
ggplot(aes(x=age,y=height,fill=.pred_other))+
geom_tile()+
theme_bw()+
facet_grid(diet~location)+
scale_fill_gradientn(colours = RColorBrewer::brewer.pal(n = 11, name = 'PRGn'))+
#scale_fill_brewer(palette = RColorBrewer::brewer.pal(n = 10, name = 'PRGn'))+
labs(title="Logistic Model Probabilities of Various Users\nWith Different Age, Height, Location and Diet")
predict(final_tree,new_data=fake_data_2,type="prob") %>%
#rename(Pred = .pred_class) %>%
bind_cols(fake_data_2) %>%
ggplot(aes(x=age,y=height,fill=.pred_stem))+
geom_tile()+
theme_bw()+
facet_grid(diet~location)+
scale_fill_gradientn(colours = RColorBrewer::brewer.pal(n=9,name="Reds"))+
labs(title="Tree Model Probabilities of Various Users Being STEM\nWith Different Age, Height, Location and Diet")
It appears that it is difficult to predict whether a user has a STEM major from these variables. Perhaps something like a random forest or KNN would yield better results, as these can model more complex interactions. Either way, simple approaches do not give good results here.
We also have a large class imbalance; to remedy this, using something like the SMOTE algorithm to create synthetic minority-class examples may help. A rough sketch of how these ideas could be combined is shown below.
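This is only a sketch (not run here), assuming the themis and ranger packages are installed; the steps mirror the recipes used above, and the hyperparameters are placeholders to be tuned:
#SMOTE needs complete, numeric predictors, hence step_unknown() and
#step_dummy() before themis::step_smote()
rf_rec = recipe(Class~.,data=okc_train) %>%
  step_unknown(diet) %>%
  step_dummy(all_nominal(),-all_outcomes()) %>%
  themis::step_smote(Class)

#A random forest that can capture interactions between the predictors
rf_mod = rand_forest(trees = 500, mtry = tune(), min_n = tune()) %>%
  set_engine("ranger") %>%
  set_mode("classification")

rf_wf = workflow() %>%
  add_recipe(rf_rec) %>%
  add_model(rf_mod)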
Another approach would be to combine the two models into an ensemble, possibly adding other models to improve the predictions. The logistic model was much better at finding true positives, while the decision tree was poor in this regard.
It could also be that we simply don't have enough high-quality features to predict such a rare outcome; with additional features we might be in a better position to predict.