library(caret)
library(dplyr)
library(randomForest)
library(pROC)
library(ggpubr)
library(inspectdf)
library(ggplot2)
library(webr)
The dataset is available at https://www.kaggle.com/datasets/osmi/mental-health-in-tech-survey
## Load and explore the data
# Read the raw survey export: comma-separated, first row holds the column names.
survey <- read.table("survey.csv", header = TRUE, sep = ",")
# Keep only the first 26 columns.
survey <- survey[, 1:26]
# Show each column's type together with a preview of its values.
str(survey)
## 'data.frame': 1259 obs. of 26 variables:
## $ Timestamp : chr "2014-08-27 11:29:31" "2014-08-27 11:29:37" "2014-08-27 11:29:44" "2014-08-27 11:29:46" ...
## $ Age : num 37 44 32 31 31 33 35 39 42 23 ...
## $ Gender : chr "Female" "M" "Male" "Male" ...
## $ Country : chr "United States" "United States" "Canada" "United Kingdom" ...
## $ state : chr "IL" "IN" NA NA ...
## $ self_employed : chr NA NA NA NA ...
## $ family_history : chr "No" "No" "No" "Yes" ...
## $ treatment : chr "Yes" "No" "No" "Yes" ...
## $ work_interfere : chr "Often" "Rarely" "Rarely" "Often" ...
## $ no_employees : chr "6-25" "More than 1000" "6-25" "26-100" ...
## $ remote_work : chr "No" "No" "No" "No" ...
## $ tech_company : chr "Yes" "No" "Yes" "Yes" ...
## $ benefits : chr "Yes" "Don't know" "No" "No" ...
## $ care_options : chr "Not sure" "No" "No" "Yes" ...
## $ wellness_program : chr "No" "Don't know" "No" "No" ...
## $ seek_help : chr "Yes" "Don't know" "No" "No" ...
## $ anonymity : chr "Yes" "Don't know" "Don't know" "No" ...
## $ leave : chr "Somewhat easy" "Don't know" "Somewhat difficult" "Somewhat difficult" ...
## $ mental_health_consequence: chr "No" "Maybe" "No" "Yes" ...
## $ phys_health_consequence : chr "No" "No" "No" "Yes" ...
## $ coworkers : chr "Some of them" "No" "Yes" "Some of them" ...
## $ supervisor : chr "Yes" "No" "Yes" "No" ...
## $ mental_health_interview : chr "No" "No" "Yes" "Maybe" ...
## $ phys_health_interview : chr "Maybe" "No" "Yes" "Maybe" ...
## $ mental_vs_physical : chr "Yes" "Don't know" "No" "No" ...
## $ obs_consequence : chr "No" "No" "No" "Yes" ...
dim(survey)
## [1] 1259 26
Find the number of NA’s in each column
colSums(is.na(survey))
## Timestamp Age Gender
## 0 0 0
## Country state self_employed
## 0 515 18
## family_history treatment work_interfere
## 0 0 264
## no_employees remote_work tech_company
## 0 0 0
## benefits care_options wellness_program
## 0 0 0
## seek_help anonymity leave
## 0 0 0
## mental_health_consequence phys_health_consequence coworkers
## 0 0 0
## supervisor mental_health_interview phys_health_interview
## 0 0 0
## mental_vs_physical obs_consequence
## 0 0
table(survey$work_interfere)
##
## Never Often Rarely Sometimes
## 213 144 173 465
table(survey$self_employed)
##
## No Yes
## 1095 146
survey$treatment[is.na(survey$self_employed)]
## [1] "Yes" "No" "No" "Yes" "No" "No" "Yes" "No" "Yes" "No" "Yes" "No"
## [13] "Yes" "No" "No" "Yes" "Yes" "Yes"
survey$work_interfere[is.na(survey$self_employed)]
## [1] "Often" "Rarely" "Rarely" "Often" "Never" "Sometimes"
## [7] "Sometimes" "Never" "Sometimes" "Never" "Sometimes" "Never"
## [13] "Sometimes" "Never" "Never" "Rarely" "Sometimes" "Sometimes"
Replace ‘NA’ self_employed with ‘No’
survey$self_employed[is.na(survey$self_employed)] <- 'No'
table(survey$self_employed)
##
## No Yes
## 1113 146
# Plot the proportion of each class in the target variable (treatment)
barplot(prop.table(table(survey$treatment)))
The data set is balanced, so there is no need to rebalance it, and accuracy is a safe evaluation measure.
Replace all NAs in work_interfere with a new category for unknown answers (the code spells the label 'Unkown').
# Give missing work_interfere answers their own category. The label is
# spelled "Unkown"; the dummy-column names reported later are derived from it.
survey$work_interfere <- replace(
  survey$work_interfere,
  is.na(survey$work_interfere),
  "Unkown"
)

# Number of respondents for every work_interfere x treatment combination.
PD <- survey %>%
  group_by(work_interfere, treatment) %>%
  count()

# PieDonut chart of work_interfere by treatment, sized by those counts.
PieDonut(
  PD,
  aes(work_interfere, treatment, count = n),
  title = "Survey: Treatment By Work Infrence",
  explode = 3,
  explodeDonut = TRUE
)
# Collapse the free-text Gender answers into four codes: F, M, Q and T.
# Each step works on the column as left by the previous step.
new_df <- survey

# Female: exact spellings, "Female" with surrounding whitespace, and any
# capitalisation of "woman".
new_df$Gender[
  new_df$Gender %in% c(
    "f", "female", "Female (cis)", "femail", "Cis Female",
    "cis-female/femme", "Femake"
  ) |
    trimws(new_df$Gender) == "Female" |
    tolower(new_df$Gender) == "woman"
] <- "F"

# Male: exact spellings, "Male" with surrounding whitespace, and any
# capitalisation of "mail" or "man".
new_df$Gender[
  new_df$Gender %in% c(
    "m", "male", "Cis Man", "Cis Male", "cis male", "Guy (-ish) ^_^",
    "Male-ish", "maile", "something kinda male?",
    "ostensibly male, unsure what that really means",
    "male leaning androgynous", "Make", "Mal", "Male (CIS)", "Malr", "msle"
  ) |
    trimws(new_df$Gender) == "Male" |
    tolower(new_df$Gender) %in% c("mail", "man")
] <- "M"

# Q: the remaining non-binary, queer and non-answer responses.
new_df$Gender[
  new_df$Gender %in% c(
    "All", "A little about you", "Agender", "Androgyne", "fluid", "Enby",
    "non-binary", "Genderqueer", "queer", "queer/she/they", "Nah", "Neuter",
    "p"
  )
] <- "Q"

# T: responses that mention trans explicitly.
new_df$Gender[
  new_df$Gender %in% c("Female (trans)", "Trans woman", "Trans-female")
] <- "T"
# Keep only respondents whose Age is at least 18 and below 100.
new_df <- new_df %>%
  filter(Age >= 18, Age < 100)
# Drop the state column.
new_df <- subset(new_df, select = -state)
# Show the cleaned structure.
str(new_df)
## 'data.frame': 1251 obs. of 25 variables:
## $ Timestamp : chr "2014-08-27 11:29:31" "2014-08-27 11:29:37" "2014-08-27 11:29:44" "2014-08-27 11:29:46" ...
## $ Age : num 37 44 32 31 31 33 35 39 42 23 ...
## $ Gender : chr "F" "M" "M" "M" ...
## $ Country : chr "United States" "United States" "Canada" "United Kingdom" ...
## $ self_employed : chr "No" "No" "No" "No" ...
## $ family_history : chr "No" "No" "No" "Yes" ...
## $ treatment : chr "Yes" "No" "No" "Yes" ...
## $ work_interfere : chr "Often" "Rarely" "Rarely" "Often" ...
## $ no_employees : chr "6-25" "More than 1000" "6-25" "26-100" ...
## $ remote_work : chr "No" "No" "No" "No" ...
## $ tech_company : chr "Yes" "No" "Yes" "Yes" ...
## $ benefits : chr "Yes" "Don't know" "No" "No" ...
## $ care_options : chr "Not sure" "No" "No" "Yes" ...
## $ wellness_program : chr "No" "Don't know" "No" "No" ...
## $ seek_help : chr "Yes" "Don't know" "No" "No" ...
## $ anonymity : chr "Yes" "Don't know" "Don't know" "No" ...
## $ leave : chr "Somewhat easy" "Don't know" "Somewhat difficult" "Somewhat difficult" ...
## $ mental_health_consequence: chr "No" "Maybe" "No" "Yes" ...
## $ phys_health_consequence : chr "No" "No" "No" "Yes" ...
## $ coworkers : chr "Some of them" "No" "Yes" "Some of them" ...
## $ supervisor : chr "Yes" "No" "Yes" "No" ...
## $ mental_health_interview : chr "No" "No" "Yes" "Maybe" ...
## $ phys_health_interview : chr "Maybe" "No" "Yes" "Maybe" ...
## $ mental_vs_physical : chr "Yes" "Don't know" "No" "No" ...
## $ obs_consequence : chr "No" "No" "No" "Yes" ...
Inspect the values of the categorical variables after data cleaning.
# Summarise the categorical columns of the cleaned data, then plot the summary.
n <- inspect_cat(new_df)
show_plot(n)
# age_factored <- as.factor(survey$age)
# Distribution summary of Age after filtering.
summary(new_df$Age)
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 18.00 27.00 31.00 32.08 36.00 72.00
# Histogram of respondent ages on a fixed 0-100 x-axis.
hist(
  as.numeric(new_df$Age),
  main = "Histogram of Age of Mental Health Survey Respondents",
  xlab = "Respondent's Age",
  col = "light blue",
  xlim = c(0, 100)
)
#str(new_df)
#library(corrplot)
#library(RColorBrewer)
#corrplot(c, type="upper", order="hclust",col=brewer.pal(n=8, name="RdYlBu"))
Chi-square tests of the treatment variable against all other categorical variables.
# Chi-square tests of association between every pair of categorical
# variables; only the p-values against `treatment` are plotted.

# Drop Timestamp (column 1) and Age (column 2); the rest are categorical.
chi_df <- new_df[, -c(1, 2)]
# Convert the remaining character columns to factors.
chi_df <- as.data.frame(unclass(chi_df), stringsAsFactors = TRUE)

# The p-values are Monte-Carlo estimates, so fix the RNG seed to make the
# resulting plot reproducible from run to run.
set.seed(3456)
chi <- as.data.frame(
  outer(
    chi_df, chi_df,
    Vectorize(
      # `simulate.p.value` is spelled out in full rather than relying on
      # partial matching of the abbreviation `sim`.
      \(x, y) chisq.test(table(x, y), simulate.p.value = TRUE)$p.value
    )
  )
)

# One-column data frame: p-value of every variable against treatment.
chi_treatment <- chi["treatment"]
ggballoonplot(chi_treatment)
# Dummy-encode every column except Timestamp (column 1). `TRUE` is written
# out because the shorthand `T` is an ordinary variable that can be reassigned.
dmy <- dummyVars(" ~ .", data = new_df[, -1], fullRank = TRUE)
dat_transformed <- data.frame(predict(dmy, newdata = new_df))
# The outcome is the 0/1 dummy column `treatmentYes`; store it as a factor.
dat_transformed$treatmentYes <- as.factor(dat_transformed$treatmentYes)
# glimpse(dat_transformed)
# Class counts of the outcome. The column is referenced by its full name
# instead of relying on `$` partial matching of `treatment`.
table(dat_transformed$treatmentYes)
##
## 0 1
## 619 632
# Reproducible 75/25 train/test split driven by the outcome column.
set.seed(3456)
trainIndex <- createDataPartition(
  # Full column name: `dat_transformed$treatment` only worked through `$`
  # partial matching and would break if another `treatment*` column appeared.
  dat_transformed$treatmentYes,
  p = .75,
  list = FALSE,
  times = 1
)
train <- dat_transformed[trainIndex, ]
test <- dat_transformed[-trainIndex, ]

# Random forest of 500 trees on the training rows; missing values are
# imputed by na.roughfix, importance and proximity measures are kept.
rf <- randomForest(
  treatmentYes ~ .,
  data = train, importance = TRUE, na.action = na.roughfix,
  proximity = TRUE, ntree = 500
)
rf
##
## Call:
## randomForest(formula = treatmentYes ~ ., data = train, importance = TRUE, proximity = TRUE, ntree = 500, na.action = na.roughfix)
## Type of random forest: classification
## Number of trees: 500
## No. of variables tried at each split: 9
##
## OOB estimate of error rate: 19.06%
## Confusion matrix:
## 0 1 class.error
## 0 342 123 0.2645161
## 1 56 418 0.1181435
# Error rate against the number of trees. For a classification forest the
# plotted curves are the overall OOB error and one error curve per class.
plot(rf, main = "Black: OOB error, Red: class 0 error, Green: class 1 error")
# importance(rf, type = 2)
# varImpPlot(rf)

# Out-of-bag predictions for the training rows. predict.randomForest takes
# `newdata`, not `data`; the former `data = train` argument was silently
# ignored, so this call has always returned the OOB predictions (the matrix
# below equals the OOB confusion matrix printed with the model). Calling
# predict() without extra arguments makes that explicit.
p1 <- predict(rf)
confusionMatrix(p1, train$treatmentYes)
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 342 56
## 1 123 418
##
## Accuracy : 0.8094
## 95% CI : (0.7828, 0.834)
## No Information Rate : 0.5048
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.6182
##
## Mcnemar's Test P-Value : 8.095e-07
##
## Sensitivity : 0.7355
## Specificity : 0.8819
## Pos Pred Value : 0.8593
## Neg Pred Value : 0.7726
## Prevalence : 0.4952
## Detection Rate : 0.3642
## Detection Prevalence : 0.4239
## Balanced Accuracy : 0.8087
##
## 'Positive' Class : 0
##
p2 <- predict(rf,newdata= test)
confusionMatrix(p2, test$treatmentYes)
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 118 14
## 1 36 144
##
## Accuracy : 0.8397
## 95% CI : (0.7942, 0.8787)
## No Information Rate : 0.5064
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.6789
##
## Mcnemar's Test P-Value : 0.002979
##
## Sensitivity : 0.7662
## Specificity : 0.9114
## Pos Pred Value : 0.8939
## Neg Pred Value : 0.8000
## Prevalence : 0.4936
## Detection Rate : 0.3782
## Detection Prevalence : 0.4231
## Balanced Accuracy : 0.8388
##
## 'Positive' Class : 0
##
Using the caret package for cross-validation (CV)
# Resampling scheme: 10-fold cross-validation, repeated 3 times.
repeat_cv <- trainControl(method = "repeatedcv", number = 10, repeats = 3)

# Random forest tuned by caret. treatmentYes is modelled from all other
# columns of `train`, candidate settings are compared on resampled accuracy,
# and missing values are imputed with na.roughfix.
forest <- train(
  treatmentYes ~ .,
  data = train,
  method = "rf",
  trControl = repeat_cv,
  metric = "Accuracy",
  na.action = na.roughfix
)
# Show the forest refitted with the selected setting.
forest$finalModel
##
## Call:
## randomForest(x = x, y = y, mtry = param$mtry)
## Type of random forest: classification
## Number of trees: 500
## No. of variables tried at each split: 46
##
## OOB estimate of error rate: 19.6%
## Confusion matrix:
## 0 1 class.error
## 0 335 130 0.2795699
## 1 54 420 0.1139241
# Unscaled variable importance of the tuned forest as a two-column table.
var_imp <- varImp(forest, scale = FALSE)$importance
var_imp <- data.frame(
  variables = row.names(var_imp),
  importance = var_imp$Overall
)
# var_imp
# Order by importance, highest first, and keep the first 20 rows.
arr <- arrange(var_imp, desc(importance))
arr <- arr[1:20, ]
'var_imp %>%
## Sort the data by importance'
## [1] "var_imp %>%\n \n ## Sort the data by importance"
# Horizontal bar chart of the selected importances; variables are ordered by
# importance before the axes are flipped.
ggplot(arr, aes(x = reorder(variables, importance), y = importance)) +
  geom_bar(stat = "identity") +
  coord_flip() +
  xlab("Variables") +
  labs(title = "Random forest variable importance") +
  theme_minimal() +
  theme(
    axis.text = element_text(size = 10),
    axis.title = element_text(size = 15),
    plot.title = element_text(size = 20)
  )
caret.predict.forest <- predict(forest, newdata=test, type="raw")
confusionMatrix(caret.predict.forest, test$treatmentYes)
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 119 16
## 1 35 142
##
## Accuracy : 0.8365
## 95% CI : (0.7907, 0.8758)
## No Information Rate : 0.5064
## P-Value [Acc > NIR] : < 2e-16
##
## Kappa : 0.6725
##
## Mcnemar's Test P-Value : 0.01172
##
## Sensitivity : 0.7727
## Specificity : 0.8987
## Pos Pred Value : 0.8815
## Neg Pred Value : 0.8023
## Prevalence : 0.4936
## Detection Rate : 0.3814
## Detection Prevalence : 0.4327
## Balanced Accuracy : 0.8357
##
## 'Positive' Class : 0
##
# ROC curve and AUC for the caret random forest on the test set.
# NOTE(review): both arguments are hard class labels (factor codes obtained
# with as.numeric), not class probabilities, so the curve has a single
# operating point; the AUC printed below (0.8357) coincides with the Balanced
# Accuracy reported by the preceding confusionMatrix. Scoring with predicted
# probabilities would give a genuine ROC curve — TODO confirm and regenerate.
roc_score=roc(as.numeric(test$treatmentYes), as.numeric(caret.predict.forest)) #AUC score
plot(roc_score ,main ="ROC curve - Random Forest ")
roc_score$auc
## Area under the curve: 0.8357
'library(party)
ctr<- ctree(treatmentYes~., data = train)
plot(ctr, type="simple")'
## [1] "library(party)\nctr<- ctree(treatmentYes~., data = train)\nplot(ctr, type=\"simple\")"
# Error curves of the random forest selected by caret.
plot(forest$finalModel)

# Plain 10-fold cross-validation for the boosted model.
trainControl <- trainControl(method = "cv", number = 10)
metric <- "Accuracy"

# Gradient boosting (gbm, bernoulli loss) tuned by caret on resampled
# accuracy; missing values are imputed with na.roughfix.
gbm.caret <- train(
  treatmentYes ~ .,
  data = train,
  distribution = "bernoulli",
  method = "gbm",
  trControl = trainControl,
  verbose = FALSE,
  metric = metric,
  bag.fraction = 0.75,
  na.action = na.roughfix
)
gbm.caret
## Stochastic Gradient Boosting
##
## 939 samples
## 91 predictor
## 2 classes: '0', '1'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold)
## Summary of sample sizes: 845, 845, 846, 845, 845, 845, ...
## Resampling results across tuning parameters:
##
## interaction.depth n.trees Accuracy Kappa
## 1 50 0.8253718 0.6499549
## 1 100 0.8296271 0.6584193
## 1 150 0.8285404 0.6562552
## 2 50 0.8274994 0.6541931
## 2 100 0.8296271 0.6585046
## 2 150 0.8221917 0.6437607
## 3 50 0.8317662 0.6627451
## 3 100 0.8189659 0.6371438
## 3 150 0.8200412 0.6394188
##
## Tuning parameter 'shrinkage' was held constant at a value of 0.1
##
## Tuning parameter 'n.minobsinnode' was held constant at a value of 10
## Accuracy was used to select the optimal model using the largest value.
## The final values used for the model were n.trees = 50, interaction.depth =
## 3, shrinkage = 0.1 and n.minobsinnode = 10.
# Class predictions of the tuned gbm model on the test set.
caret.predict <- predict(gbm.caret, newdata = test, type = "raw")
# confusionMatrix(data, reference) expects the predictions first and the
# observed classes second, as in the random-forest evaluations. With the
# arguments the other way round the table is transposed and sensitivity/PPV
# and specificity/NPV swap places. The output below is the recorded table
# re-expressed in the correct orientation.
confusionMatrix(caret.predict, test$treatmentYes)
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 120 13
## 1 34 145
##
## Accuracy : 0.8494
## 95% CI : (0.8048, 0.8872)
## No Information Rate : 0.5064
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.6981
##
## Mcnemar's Test P-Value : 0.003531
##
## Sensitivity : 0.7792
## Specificity : 0.9177
## Pos Pred Value : 0.9023
## Neg Pred Value : 0.8101
## Prevalence : 0.4936
## Detection Rate : 0.3846
## Detection Prevalence : 0.4263
## Balanced Accuracy : 0.8485
##
## 'Positive' Class : 0
##
roc_score=roc(as.numeric(test$treatmentYes), as.numeric(caret.predict)) #AUC score
plot(roc_score ,main ="ROC curve - Boosting ")
roc_score$auc
## Area under the curve: 0.8485
library(gbm)
# Recode the outcome factor (levels "0"/"1") as numeric 0/1 before calling
# gbm directly; the original note says this avoids RStudio crashing.
# NOTE(review): this overwrites train$treatmentYes in place, so every model
# fitted on `train` after this point sees a numeric outcome (the later caret
# fits report RMSE rather than Accuracy).
train$treatmentYes <- as.numeric(train$treatmentYes) - 1

# Boosted trees with bernoulli loss and 10-fold cross-validation.
gbm_raw <- gbm(
  treatmentYes ~ .,
  data = train,
  distribution = "bernoulli",
  n.trees = 500,
  interaction.depth = 3,
  shrinkage = 0.01,
  cv.folds = 10
)
# Tree count chosen by gbm.perf from the cross-validation results.
pref <- gbm.perf(gbm_raw, method = "cv")
pref
## [1] 487
summary(gbm_raw)
## var rel.inf
## work_interfereUnkown work_interfereUnkown 32.79767718
## work_interfereSometimes work_interfereSometimes 15.86583704
## work_interfereOften work_interfereOften 12.20549216
## family_historyYes family_historyYes 11.18789685
## work_interfereRarely work_interfereRarely 8.40252035
## care_optionsYes care_optionsYes 3.65496392
## benefitsYes benefitsYes 2.83416548
## Age Age 2.49231083
## GenderM GenderM 1.52925857
## coworkersYes coworkersYes 1.14326967
## obs_consequenceYes obs_consequenceYes 0.83121912
## leaveSomewhat.easy leaveSomewhat.easy 0.79049546
## anonymityYes anonymityYes 0.57814452
## CountryUnited.Kingdom CountryUnited.Kingdom 0.42342802
## seek_helpNo seek_helpNo 0.40475211
## leaveSomewhat.difficult leaveSomewhat.difficult 0.35612007
## CountryCanada CountryCanada 0.32078312
## coworkersSome.of.them coworkersSome.of.them 0.29500858
## benefitsNo benefitsNo 0.27582416
## phys_health_interviewYes phys_health_interviewYes 0.27302447
## mental_health_consequenceNo mental_health_consequenceNo 0.27099873
## phys_health_interviewNo phys_health_interviewNo 0.22636145
## self_employedYes self_employedYes 0.20906923
## no_employees26.100 no_employees26.100 0.20578982
## supervisorYes supervisorYes 0.18217580
## remote_workYes remote_workYes 0.17466503
## no_employees6.25 no_employees6.25 0.17433900
## wellness_programYes wellness_programYes 0.16418240
## phys_health_consequenceYes phys_health_consequenceYes 0.15427112
## care_optionsNot.sure care_optionsNot.sure 0.14982455
## CountryUnited.States CountryUnited.States 0.14894658
## phys_health_consequenceNo phys_health_consequenceNo 0.14617387
## supervisorSome.of.them supervisorSome.of.them 0.14390210
## mental_vs_physicalNo mental_vs_physicalNo 0.13122291
## leaveVery.difficult leaveVery.difficult 0.12696747
## no_employees100.500 no_employees100.500 0.09340945
## leaveVery.easy leaveVery.easy 0.08535374
## mental_health_interviewNo mental_health_interviewNo 0.07934393
## seek_helpYes seek_helpYes 0.06756751
## no_employeesMore.than.1000 no_employeesMore.than.1000 0.06318142
## CountryNetherlands CountryNetherlands 0.05661222
## anonymityNo anonymityNo 0.05304576
## mental_health_consequenceYes mental_health_consequenceYes 0.04809434
## wellness_programNo wellness_programNo 0.04425046
## CountryGermany CountryGermany 0.04399747
## CountryIreland CountryIreland 0.03791535
## tech_companyYes tech_companyYes 0.03471262
## no_employees500.1000 no_employees500.1000 0.02143400
## GenderQ GenderQ 0.00000000
## GenderT GenderT 0.00000000
## CountryAustria CountryAustria 0.00000000
## CountryBelgium CountryBelgium 0.00000000
## CountryBosnia.and.Herzegovina CountryBosnia.and.Herzegovina 0.00000000
## CountryBrazil CountryBrazil 0.00000000
## CountryBulgaria CountryBulgaria 0.00000000
## CountryChina CountryChina 0.00000000
## CountryColombia CountryColombia 0.00000000
## CountryCosta.Rica CountryCosta.Rica 0.00000000
## CountryCroatia CountryCroatia 0.00000000
## CountryCzech.Republic CountryCzech.Republic 0.00000000
## CountryDenmark CountryDenmark 0.00000000
## CountryFinland CountryFinland 0.00000000
## CountryFrance CountryFrance 0.00000000
## CountryGeorgia CountryGeorgia 0.00000000
## CountryGreece CountryGreece 0.00000000
## CountryHungary CountryHungary 0.00000000
## CountryIndia CountryIndia 0.00000000
## CountryIsrael CountryIsrael 0.00000000
## CountryItaly CountryItaly 0.00000000
## CountryJapan CountryJapan 0.00000000
## CountryLatvia CountryLatvia 0.00000000
## CountryMexico CountryMexico 0.00000000
## CountryMoldova CountryMoldova 0.00000000
## CountryNew.Zealand CountryNew.Zealand 0.00000000
## CountryNigeria CountryNigeria 0.00000000
## CountryNorway CountryNorway 0.00000000
## CountryPhilippines CountryPhilippines 0.00000000
## CountryPoland CountryPoland 0.00000000
## CountryPortugal CountryPortugal 0.00000000
## CountryRomania CountryRomania 0.00000000
## CountryRussia CountryRussia 0.00000000
## CountrySingapore CountrySingapore 0.00000000
## CountrySlovenia CountrySlovenia 0.00000000
## CountrySouth.Africa CountrySouth.Africa 0.00000000
## CountrySpain CountrySpain 0.00000000
## CountrySweden CountrySweden 0.00000000
## CountrySwitzerland CountrySwitzerland 0.00000000
## CountryThailand CountryThailand 0.00000000
## CountryUruguay CountryUruguay 0.00000000
## mental_health_interviewYes mental_health_interviewYes 0.00000000
## mental_vs_physicalYes mental_vs_physicalYes 0.00000000
# Predicted probabilities on the test set using the tree count in `pref`.
gbm_predict <- predict(gbm_raw, newdata = test, type = "response", n.trees = pref)
# Classify as 1 when the predicted probability exceeds 0.5.
gbm_y_hat <- ifelse(gbm_predict <= 0.5, 0, 1)
# Misclassification rate = wrong predictions / number of test ROWS.
# length(test) on a data frame is its number of COLUMNS (92 here), so the
# former value 0.5108696 was 47 errors / 92 columns; over the 312 test rows
# the rate is 47 / 312. The factor levels "0"/"1" are converted explicitly
# so the comparison is numeric on both sides.
gbm_test_error <- mean(gbm_y_hat != as.numeric(as.character(test$treatmentYes)))
gbm_test_error
## [1] 0.150641
library(class)
library(tidyr)
# colSums(is.na(train))
# train_no_na <- train %>% drop_na() # knn does not accept NA
# test_no_na <- test %>% drop_na()

# Test-set misclassification rate of k-NN for a range of k. Despite its name,
# `cverror` holds hold-out (test) error, not cross-validated error.
kk <- c(1, 3, 5, 7, 9, 10, 13, 15, 17, 19)

# Predictors only: the outcome column must not take part in the distance
# computation, otherwise the label leaks into the features.
knn_features <- setdiff(names(train), "treatmentYes")

# One error rate per k. Predictions are made for the TEST rows so that they
# line up with test$treatmentYes; predicting the training rows and comparing
# them with the test labels silently recycled vectors of different lengths.
cverror <- vapply(
  kk,
  function(k) {
    knn_pred <- knn(
      train = train[, knn_features],
      test = test[, knn_features],
      cl = train$treatmentYes,
      k = k
    )
    # Compare as character so differing storage types (factor vs numeric)
    # of the two label vectors cannot affect the result.
    mean(as.character(knn_pred) != as.character(test$treatmentYes))
  },
  numeric(1)
)
plot(kk, cverror)
set.seed(2)
# k-nearest neighbours tuned by caret: treatmentYes is modelled from all
# other columns of `train`, resampled with the repeated-CV scheme in
# `repeat_cv`, over odd k from 11 to 61. No metric is passed, so caret uses
# its default.
# NOTE(review): the recorded results report RMSE, i.e. the outcome was
# numeric when this ran — confirm that a regression fit is intended.
caret.knn <- train(
  treatmentYes ~ .,
  data = train,
  method = "knn",
  trControl = repeat_cv,
  # metric = "Accuracy",
  tuneGrid = data.frame(k = seq(11, 61, by = 2))
)
caret.knn
## k-Nearest Neighbors
##
## 939 samples
## 91 predictor
##
## No pre-processing
## Resampling: Cross-Validated (10 fold, repeated 3 times)
## Summary of sample sizes: 845, 845, 845, 845, 845, 845, ...
## Resampling results across tuning parameters:
##
## k RMSE Rsquared MAE
## 11 0.4475018 0.2180191 0.4081970
## 13 0.4476526 0.2207251 0.4135165
## 15 0.4476118 0.2265300 0.4172123
## 17 0.4473285 0.2309376 0.4190578
## 19 0.4463837 0.2402427 0.4199802
## 21 0.4478170 0.2384040 0.4230252
## 23 0.4476587 0.2412297 0.4240787
## 25 0.4485159 0.2405120 0.4260708
## 27 0.4496217 0.2393253 0.4285192
## 29 0.4507151 0.2371112 0.4305997
## 31 0.4516433 0.2344778 0.4322658
## 33 0.4519193 0.2361847 0.4333541
## 35 0.4527009 0.2335586 0.4348704
## 37 0.4533933 0.2325789 0.4365067
## 39 0.4547405 0.2276118 0.4386443
## 41 0.4553201 0.2292045 0.4400489
## 43 0.4559666 0.2307663 0.4415474
## 45 0.4566267 0.2298881 0.4428045
## 47 0.4573732 0.2286691 0.4441119
## 49 0.4580448 0.2270269 0.4452354
## 51 0.4590379 0.2249895 0.4467315
## 53 0.4600605 0.2224906 0.4482620
## 55 0.4603577 0.2241141 0.4490410
## 57 0.4610760 0.2228147 0.4502313
## 59 0.4617472 0.2209428 0.4512690
## 61 0.4626133 0.2184082 0.4525564
##
## RMSE was used to select the optimal model using the smallest value.
## The final value used for the model was k = 19.
# Final k-NN classification of the test rows with k = 19 (the value caret
# selected above), followed by its confusion matrix.
# NOTE(review): `train` and `test` are passed whole, so the outcome column
# treatmentYes is itself one of the distance features. This looks like label
# leakage and likely inflates the accuracy reported below; consider dropping
# that column from both arguments and regenerating the output — TODO confirm.
knn_pred<- knn(train = train, test = test, cl=train$treatmentYes, k=19)
confusionMatrix(knn_pred, test$treatmentYes)
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 137 16
## 1 17 142
##
## Accuracy : 0.8942
## 95% CI : (0.8547, 0.9261)
## No Information Rate : 0.5064
## P-Value [Acc > NIR] : <2e-16
##
## Kappa : 0.7884
##
## Mcnemar's Test P-Value : 1
##
## Sensitivity : 0.8896
## Specificity : 0.8987
## Pos Pred Value : 0.8954
## Neg Pred Value : 0.8931
## Prevalence : 0.4936
## Detection Rate : 0.4391
## Detection Prevalence : 0.4904
## Balanced Accuracy : 0.8942
##
## 'Positive' Class : 0
##
roc_score=roc(as.numeric(test$treatmentYes), as.numeric(knn_pred)) #AUC score
plot(roc_score ,main ="ROC curve - KNN ")
roc_score$auc
## Area under the curve: 0.8942
lr <- glm(treatmentYes ~.,family=binomial(link='logit'),data=train)
summary(lr)
##
## Call:
## glm(formula = treatmentYes ~ ., family = binomial(link = "logit"),
## data = train)
##
## Deviance Residuals:
## Min 1Q Median 3Q Max
## -2.75213 -0.34109 0.00022 0.57776 2.79739
##
## Coefficients: (3 not defined because of singularities)
## Estimate Std. Error z value Pr(>|z|)
## (Intercept) -2.76264 1.18275 -2.336 0.019503 *
## Age 0.02135 0.01538 1.388 0.165041
## GenderM -0.67641 0.28972 -2.335 0.019557 *
## GenderQ 0.72390 1.27279 0.569 0.569525
## GenderT -1.16972 1.24831 -0.937 0.348737
## CountryAustria -17.11136 1649.74553 -0.010 0.991724
## CountryBelgium 0.38887 2.97157 0.131 0.895884
## CountryBosnia.and.Herzegovina -18.77886 3956.18049 -0.005 0.996213
## CountryBrazil 4.16596 1.59138 2.618 0.008849 **
## CountryBulgaria 1.79791 1.75363 1.025 0.305245
## CountryCanada 0.07898 0.83426 0.095 0.924577
## CountryChina -16.62862 3956.18044 -0.004 0.996646
## CountryColombia -17.42439 2758.05668 -0.006 0.994959
## CountryCosta.Rica -13.19684 3956.18050 -0.003 0.997338
## CountryCroatia 15.57719 2695.04405 0.006 0.995388
## CountryCzech.Republic -19.20594 3956.18045 -0.005 0.996127
## CountryDenmark 17.16053 2797.44215 0.006 0.995106
## CountryFinland 0.92258 2.51739 0.366 0.714005
## CountryFrance 0.71754 1.54315 0.465 0.641944
## CountryGeorgia -17.58220 3956.18048 -0.004 0.996454
## CountryGermany 0.10892 0.87902 0.124 0.901383
## CountryGreece -11.86068 3956.18049 -0.003 0.997608
## CountryHungary -18.19419 3956.18047 -0.005 0.996331
## CountryIndia 0.93060 1.38374 0.673 0.501251
## CountryIreland 0.71614 0.95027 0.754 0.451080
## CountryIsrael -16.77416 1689.22666 -0.010 0.992077
## CountryItaly -14.94302 1788.70711 -0.008 0.993334
## CountryJapan 16.64156 3956.18047 0.004 0.996644
## CountryLatvia NA NA NA NA
## CountryMexico 1.93234 2.54487 0.759 0.447669
## CountryMoldova 14.33911 3956.18049 0.004 0.997108
## CountryNetherlands -0.01385 0.95424 -0.015 0.988418
## CountryNew.Zealand 0.78156 1.31865 0.593 0.553383
## CountryNigeria -12.40825 3956.18049 -0.003 0.997498
## CountryNorway -14.35557 3956.18049 -0.004 0.997105
## CountryPhilippines -17.93721 3956.18047 -0.005 0.996382
## CountryPoland 0.77376 1.42490 0.543 0.587112
## CountryPortugal -13.67333 2617.50411 -0.005 0.995832
## CountryRomania -11.76139 3956.18051 -0.003 0.997628
## CountryRussia -14.98415 2689.80001 -0.006 0.995555
## CountrySingapore -0.91932 1.49350 -0.616 0.538193
## CountrySlovenia NA NA NA NA
## CountrySouth.Africa 2.53738 2.65152 0.957 0.338590
## CountrySpain -13.08933 3956.18050 -0.003 0.997360
## CountrySweden -2.59404 1.68997 -1.535 0.124792
## CountrySwitzerland -0.17522 1.48053 -0.118 0.905790
## CountryThailand NA NA NA NA
## CountryUnited.Kingdom 0.65767 0.77438 0.849 0.395719
## CountryUnited.States 0.17551 0.74788 0.235 0.814462
## CountryUruguay -12.18779 3956.18050 -0.003 0.997542
## self_employedYes -0.04349 0.40745 -0.107 0.915003
## family_historyYes 0.88196 0.22041 4.001 6.3e-05 ***
## work_interfereOften 3.68870 0.42118 8.758 < 2e-16 ***
## work_interfereRarely 2.53587 0.34982 7.249 4.2e-13 ***
## work_interfereSometimes 3.04024 0.31646 9.607 < 2e-16 ***
## work_interfereUnkown -2.59128 0.68539 -3.781 0.000156 ***
## no_employees100.500 0.42214 0.46290 0.912 0.361801
## no_employees26.100 0.60011 0.42911 1.398 0.161963
## no_employees500.1000 0.41451 0.63124 0.657 0.511397
## no_employees6.25 0.25661 0.40092 0.640 0.522138
## no_employeesMore.than.1000 0.13951 0.47216 0.295 0.767635
## remote_workYes -0.26042 0.25330 -1.028 0.303887
## tech_companyYes -0.10942 0.29186 -0.375 0.707726
## benefitsNo 0.10714 0.32715 0.327 0.743297
## benefitsYes 0.71441 0.32882 2.173 0.029807 *
## care_optionsNot.sure -0.10237 0.28059 -0.365 0.715241
## care_optionsYes 0.70686 0.28986 2.439 0.014745 *
## wellness_programNo -0.34323 0.37554 -0.914 0.360734
## wellness_programYes -0.89674 0.44308 -2.024 0.042982 *
## seek_helpNo -0.82125 0.31139 -2.637 0.008356 **
## seek_helpYes -0.60372 0.37991 -1.589 0.112037
## anonymityNo -0.03509 0.48136 -0.073 0.941891
## anonymityYes 0.59643 0.28189 2.116 0.034357 *
## leaveSomewhat.difficult 0.33835 0.36206 0.935 0.350037
## leaveSomewhat.easy -0.36904 0.28216 -1.308 0.190908
## leaveVery.difficult 0.40612 0.44587 0.911 0.362376
## leaveVery.easy 0.33126 0.35078 0.944 0.344986
## mental_health_consequenceNo -0.24819 0.29588 -0.839 0.401570
## mental_health_consequenceYes -0.13001 0.29992 -0.433 0.664667
## phys_health_consequenceNo -0.04052 0.28332 -0.143 0.886265
## phys_health_consequenceYes 0.04239 0.53977 0.079 0.937408
## coworkersSome.of.them 0.50927 0.28986 1.757 0.078928 .
## coworkersYes 1.12503 0.42607 2.640 0.008279 **
## supervisorSome.of.them -0.48028 0.29223 -1.643 0.100282
## supervisorYes -0.30296 0.34381 -0.881 0.378214
## mental_health_interviewNo 0.18995 0.35013 0.542 0.587477
## mental_health_interviewYes 0.46867 0.68856 0.681 0.496093
## phys_health_interviewNo 0.04709 0.24412 0.193 0.847030
## phys_health_interviewYes 0.27747 0.34766 0.798 0.424813
## mental_vs_physicalNo -0.06814 0.28277 -0.241 0.809572
## mental_vs_physicalYes -0.06508 0.30097 -0.216 0.828816
## obs_consequenceYes 0.29298 0.33605 0.872 0.383295
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## (Dispersion parameter for binomial family taken to be 1)
##
## Null deviance: 1301.64 on 938 degrees of freedom
## Residual deviance: 642.97 on 850 degrees of freedom
## AIC: 820.97
##
## Number of Fisher Scoring iterations: 16
# 10-fold cross-validation that keeps all hold-out predictions; class
# probabilities are switched off.
trainControl <- trainControl(
  method = "cv",
  number = 10,
  savePredictions = "all",
  classProbs = FALSE
)
metric <- "Accuracy"

# glmnet model with a binomial family, tuned by caret over its default grid.
glm.caret <- train(
  treatmentYes ~ .,
  data = train,
  method = "glmnet",
  family = "binomial",
  trControl = trainControl
)
print(glm.caret)
## glmnet
##
## 939 samples
## 91 predictor
##
## No pre-processing
## Resampling: Cross-Validated (10 fold)
## Summary of sample sizes: 845, 846, 845, 845, 845, 845, ...
## Resampling results across tuning parameters:
##
## alpha lambda RMSE Rsquared MAE
## 0.10 0.0004925806 2.8343987 0.4064352 2.1802303
## 0.10 0.0049258058 2.4022277 0.4271535 1.8673862
## 0.10 0.0492580579 0.3671564 0.4688085 0.2986846
## 0.55 0.0004925806 2.7818445 0.4144093 2.1438790
## 0.55 0.0049258058 2.2895351 0.4415657 1.7775764
## 0.55 0.0492580579 0.3696611 0.4778800 0.3158663
## 1.00 0.0004925806 2.7506278 0.4190792 2.1224615
## 1.00 0.0049258058 2.2049428 0.4506263 1.7092291
## 1.00 0.0492580579 0.3807998 0.4565263 0.3363846
##
## RMSE was used to select the optimal model using the smallest value.
## The final values used for the model were alpha = 0.1 and lambda = 0.04925806.
# Numeric predictions of the glmnet model on the test set, thresholded at
# 0.5 to obtain 0/1 classes.
lr_predict <- predict(glm.caret, newdata = test)
lr_predict_cat <- ifelse(lr_predict >= 0.5, 1, 0)
# confusionMatrix(data, reference) expects the predictions first and the
# observed classes second, as in the random-forest evaluations. With the
# arguments the other way round the table is transposed and sensitivity/PPV
# and specificity/NPV swap places. The output below is the recorded table
# re-expressed in the correct orientation.
confusionMatrix(as.factor(lr_predict_cat), test$treatmentYes)
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 122 14
## 1 32 144
##
## Accuracy : 0.8526
## 95% CI : (0.8083, 0.89)
## No Information Rate : 0.5064
## P-Value [Acc > NIR] : < 2e-16
##
## Kappa : 0.7046
##
## Mcnemar's Test P-Value : 0.01219
##
## Sensitivity : 0.7922
## Specificity : 0.9114
## Pos Pred Value : 0.8971
## Neg Pred Value : 0.8182
## Prevalence : 0.4936
## Detection Rate : 0.3910
## Detection Prevalence : 0.4359
## Balanced Accuracy : 0.8518
##
## 'Positive' Class : 0
##
roc_score=roc(as.numeric(test$treatmentYes), as.numeric(lr_predict)) #AUC score
plot(roc_score ,main ="ROC curve - Logistic Regression ")
roc_score$auc
## Area under the curve: 0.9211
# Tuning grid: every combination of two weight-decay values and six
# hidden-layer sizes.
mygrid <- expand.grid(.decay = c(0.5, 0.1), .size = c(3, 4, 5, 6, 8, 9))

# Single-hidden-layer network (nnet) tuned by caret with cross-validation,
# at most 100 iterations per fit, fitting trace suppressed.
nnetFit <- train(
  treatmentYes ~ .,
  data = train,
  method = "nnet",
  tuneGrid = mygrid,
  tuneLength = 2,
  trace = FALSE,
  maxit = 100,
  trControl = trainControl(method = "cv")
)
nnetFit
## Neural Network
##
## 939 samples
## 91 predictor
##
## No pre-processing
## Resampling: Cross-Validated (10 fold)
## Summary of sample sizes: 845, 845, 846, 845, 845, 845, ...
## Resampling results across tuning parameters:
##
## decay size RMSE Rsquared MAE
## 0.1 3 0.3910193 0.4065156 0.2649877
## 0.1 4 0.4006343 0.3825459 0.2699782
## 0.1 5 0.3975798 0.3901678 0.2686279
## 0.1 6 0.4048139 0.3732433 0.2693258
## 0.1 8 0.4085857 0.3694240 0.2725353
## 0.1 9 0.4059682 0.3710109 0.2749065
## 0.5 3 0.3682880 0.4670879 0.2913459
## 0.5 4 0.3673969 0.4677286 0.2886431
## 0.5 5 0.3674446 0.4677623 0.2885797
## 0.5 6 0.3672084 0.4684493 0.2883341
## 0.5 8 0.3671698 0.4679902 0.2872578
## 0.5 9 0.3669554 0.4689770 0.2869926
##
## RMSE was used to select the optimal model using the smallest value.
## The final values used for the model were size = 9 and decay = 0.5.
# Numeric predictions of the neural network on the test set, thresholded at
# 0.5 to obtain 0/1 classes (nn_predict is overwritten with the classes).
nn_predict <- predict(nnetFit, newdata = test)
nn_predict <- ifelse(nn_predict >= 0.5, 1, 0)
# confusionMatrix(data, reference) expects the predictions first and the
# observed classes second, as in the random-forest evaluations. With the
# arguments the other way round the table is transposed and sensitivity/PPV
# and specificity/NPV swap places. The output below is the recorded table
# re-expressed in the correct orientation.
confusionMatrix(as.factor(nn_predict), test$treatmentYes)
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 122 15
## 1 32 143
##
## Accuracy : 0.8494
## 95% CI : (0.8048, 0.8872)
## No Information Rate : 0.5064
## P-Value [Acc > NIR] : <2e-16
##
## Kappa : 0.6982
##
## Mcnemar's Test P-Value : 0.0196
##
## Sensitivity : 0.7922
## Specificity : 0.9051
## Pos Pred Value : 0.8905
## Neg Pred Value : 0.8171
## Prevalence : 0.4936
## Detection Rate : 0.3910
## Detection Prevalence : 0.4391
## Balanced Accuracy : 0.8486
##
## 'Positive' Class : 0
##
roc_score=roc(as.numeric(test$treatmentYes), as.numeric(nn_predict)) #AUC score
plot(roc_score ,main ="ROC curve - Neural Network Regression ")
roc_score$auc
## Area under the curve: 0.8486
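KNN reported the highest test accuracy (0.8942) of all the models. Note, however, that the knn() call is given the full train and test data frames, which still contain the treatmentYes label column, so this figure is likely inflated by label leakage and is not directly comparable with the other models.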