Alzheimer’s disease (AD), a slowly progressive brain disorder in older people characterized by gradual memory loss, cognitive deficits, and behavioral changes, is the most common cause of dementia. An estimated 5.4 million Americans had AD in 2016, 96% of whom were over the age of 65. By 2050, the number of Americans aged 65 and older affected by AD may reach 13.8 million, nearly tripling the current number, if no medical breakthroughs are developed to cure or prevent the disease. Health care for individuals with AD is estimated to cost $159 billion to $215 billion per year in the United States and over $600 billion worldwide, imposing a significant economic burden on families and society.
No cure or disease-modifying treatment is currently available for AD. By the time AD is clinically diagnosed, considerable multi-system degeneration has already occurred within the brain. Potential treatments will therefore likely have the greatest impact when provided at the earliest possible stage of the disease.
To enable detection at these early stages, I propose to develop a model that helps clinicians identify early AD from a set of clinical traits collected from patients.
Clinical data for 619 patients were retrieved from the Religious Orders Study and Memory and Aging Project (ROSMAP) study hosted at https://www.synapse.org/#!Synapse:syn3157322. Because the website requires registration, I have also uploaded the dataset to my GitHub repository: https://github.com/grayapply2009/Alzheimers_Disease_Detect_xgboost.
Future analyses would include larger datasets with more clinical features and subjects.
. The dataset consists of 619 subjects aged 67 to 90+.
. Subjects were grouped into definite AD, probable AD, possible AD, and normal (NL).
. Each subject has multiple clinical traits.
| Column Name | Description |
|---|---|
| age_death | age at death |
| msex | sex |
| edu | years of education |
| apoe_genotype | APOE genotype |
| braaksc | Braak Score |
| cts_mmse30_lv | Mini-Mental State Examination (MMSE) score |
| cogdx | cognitive diagnostic category |
| CERAD | Diagnosis based on the criteria of the Consortium to Establish a Registry for Alzheimer's Disease (CERAD) |
Analyses and visualizations are done in R.
library(data.table)
library(xgboost)
library(caret)
## Loading required package: lattice
library(mlr)
## Loading required package: ParamHelpers
##
## Attaching package: 'mlr'
## The following object is masked from 'package:caret':
##
## train
library(dplyr)
##
## Attaching package: 'dplyr'
## The following object is masked from 'package:xgboost':
##
## slice
## The following objects are masked from 'package:stats':
##
## filter, lag
## The following objects are masked from 'package:data.table':
##
## between, first, last
## The following objects are masked from 'package:base':
##
##     intersect, setdiff, setequal, union
#ggplot2 is needed for the plots in the exploration section below
library(ggplot2)
Data Preprocessing
#read data
ros_meta <- fread("C:/Users/guol03/Google Drive/machine_learning_practice/r/xgboost/rosmap_meta.tsv")
#remove undesired features
ros_meta <- ros_meta[, -c(1, 2, 3, 4, 6, 7, 14, 15, 17, 19, 20, 21, 22, 24)]
#make the age feature numeric
ros_meta$age_death <- as.numeric(gsub("\\+", "", ros_meta$age_death))
#one hot encoding for categorical features
ros_feature <- predict(dummyVars(CERAD ~ ., data = ros_meta), newdata = ros_meta)
ros_target <- ros_meta[, 10]
#detect predictors with low variance
nzv <- nearZeroVar(ros_feature)
ros_feature <- ros_feature[, -nzv]
#check missing values
#plot_missing(ros_data)
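#(sketch) a lightweight base R check of missing values per column
colSums(is.na(ros_meta))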
#partition data
trainIndex <- createDataPartition(ros_meta$CERAD, p = .8, list = FALSE, times = 1)
ros_train_feature <- ros_feature[trainIndex, ]
ros_train_label <- ros_target[trainIndex, ][[1]]
ros_train_label_num <- as.numeric(as.factor(ros_train_label))-1
ros_test_feature <- ros_feature[-trainIndex, ]
ros_test_label <- ros_target[-trainIndex, ][[1]]
ros_test_label_num <- as.numeric(as.factor(ros_test_label))-1
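As a quick sanity check (a minimal sketch in base R), the class proportions in the training and test splits can be compared to confirm that createDataPartition preserved the stratification on CERAD:
#compare class proportions between the splits
round(prop.table(table(ros_train_label)), 2)
round(prop.table(table(ros_test_label)), 2)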
Explore the data
diagnosis_prop_tbl <- as.data.frame(prop.table(table(ros_meta$CERAD)))
diagnosis_prop_vec <- diagnosis_prop_tbl$Freq[match(ros_meta$CERAD, diagnosis_prop_tbl$Var1)]
plot_data <- data.frame(sex = ros_meta$msex, Diagnosis = ros_meta$CERAD, Percentage = diagnosis_prop_vec, MMSE = ros_meta$cts_mmse30_lv)
plot1 <- ggplot() + geom_bar(aes(y = Percentage, x = Diagnosis, fill = sex), data = plot_data, stat="identity") + labs(x = "Diagnosis", y = "Percentage", title = "Sex and AD status", fill = "Sex")
ggsave("plot1.pdf", plot = plot1, device = "pdf", width = 8, height = 8, units = "in")
print(plot1)
plot2 <- ggplot(plot_data, aes(x = MMSE, fill = Diagnosis)) + geom_density(alpha = 0.25) + labs(x = "MMSE Score", y = "Density", title = "Mini Mental State Examination", fill = "Diagnosis")
ggsave("plot2.pdf", plot = plot2, device = "pdf", width = 8, height = 8, units = "in")
## Warning: Removed 1 rows containing non-finite values (stat_density).
print(plot2)
## Warning: Removed 1 rows containing non-finite values (stat_density).
. Women are disproportionately affected by AD.
. MMSE scores decline with AD progression.
. Years of education show a weaker relationship with AD progression (a quick boxplot check of these two observations is sketched after this list).
. These clinical features can serve as potential predictors of AD stage in my machine learning model.
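A minimal sketch of how these observations could be checked directly, assuming ggplot2 is loaded and using the column names listed in the trait table above (edu and cts_mmse30_lv; adjust if the file uses different names):
#boxplots of MMSE and years of education by diagnosis group
plot3 <- ggplot(ros_meta, aes(x = CERAD, y = cts_mmse30_lv)) + geom_boxplot() + labs(x = "Diagnosis", y = "MMSE Score")
plot4 <- ggplot(ros_meta, aes(x = CERAD, y = edu)) + geom_boxplot() + labs(x = "Diagnosis", y = "Years of Education")
print(plot3)
print(plot4)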
Build an XGBoost model with default parameters
#make xgb.DMatrix
dtrain <- xgb.DMatrix(data =ros_train_feature,label = ros_train_label_num)
dtest <- xgb.DMatrix(data = ros_test_feature,label = ros_test_label_num)
#default parameters
params <- list(booster = "gbtree",
               objective = "multi:softprob",
               eval_metric = "mlogloss",
               num_class = length(unique(ros_target[[1]])),
               eta = 0.3, gamma = 0, max_depth = 6,
               min_child_weight = 1, subsample = 1, colsample_bytree = 1)
#compute optimal nround using CV
xgbcv <- xgb.cv( params = params, data = dtrain, nrounds = 100, nfold = 5, showsd = T, stratified = T, print_every_n = 20, early_stopping_rounds = 20, maximize = F)
## [1] train-mlogloss:1.181663+0.011924 test-mlogloss:1.300559+0.017909
## Multiple eval metrics are present. Will use test_mlogloss for early stopping.
## Will train until test_mlogloss hasn't improved in 20 rounds.
##
## [21] train-mlogloss:0.388918+0.016136 test-mlogloss:1.318680+0.074746
## Stopping. Best iteration:
## [6] train-mlogloss:0.739780+0.026263 test-mlogloss:1.220178+0.061069
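The nrounds value used below is based on the best iteration reported by cross-validation; as a small sketch, it can also be read programmatically from the xgb.cv result rather than hard-coded:
#read the CV-selected best round directly from the xgb.cv object
best_nrounds <- xgbcv$best_iteration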
#first default model - training
xgb1 <- xgb.train(params = params, data = dtrain, nrounds = 7)
xgbpred <- predict(xgb1,dtest)
#evaluate the default model
#reshape the softprob output (one probability per class per subject) into a matrix,
#then take the most probable class for each subject
test_prediction <- matrix(xgbpred, nrow = 4, ncol = length(xgbpred)/4) %>%
  t() %>%
  data.frame() %>%
  mutate(label = ros_test_label_num + 1, max_prob = max.col(., "last"))
confusionMatrix(factor(test_prediction$max_prob), factor(test_prediction$label), mode = "everything")
## Confusion Matrix and Statistics
##
## Reference
## Prediction 1 2 3 4
## 1 21 3 2 14
## 2 3 19 5 8
## 3 0 1 1 3
## 4 11 9 5 17
##
## Overall Statistics
##
## Accuracy : 0.4754
## 95% CI : (0.3843, 0.5678)
## No Information Rate : 0.3443
## P-Value [Acc > NIR] : 0.00189
##
## Kappa : 0.2589
## Mcnemar's Test P-Value : 0.47118
##
## Statistics by Class:
##
## Class: 1 Class: 2 Class: 3 Class: 4
## Sensitivity 0.6000 0.5938 0.076923 0.4048
## Specificity 0.7816 0.8222 0.963303 0.6875
## Pos Pred Value 0.5250 0.5429 0.200000 0.4048
## Neg Pred Value 0.8293 0.8506 0.897436 0.6875
## Precision 0.5250 0.5429 0.200000 0.4048
## Recall 0.6000 0.5938 0.076923 0.4048
## F1 0.5600 0.5672 0.111111 0.4048
## Prevalence 0.2869 0.2623 0.106557 0.3443
## Detection Rate 0.1721 0.1557 0.008197 0.1393
## Detection Prevalence 0.3279 0.2869 0.040984 0.3443
## Balanced Accuracy 0.6908 0.7080 0.520113 0.5461
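Before moving to tuning, it can be informative to see which clinical traits the default model relies on most. This is a minimal sketch using xgboost's built-in importance functions and is not part of the original analysis:
#rank features by gain in the default model and plot the top ones
importance <- xgb.importance(feature_names = colnames(ros_train_feature), model = xgb1)
xgb.plot.importance(importance, top_n = 10)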
Create task for model parameter tuning
#create tasks
ros_train <- cbind(as.data.frame(ros_train_feature), as.factor(ros_train_label))
colnames(ros_train)[ncol(ros_train)] <- "Target"
ros_test <- cbind(as.data.frame(ros_test_feature), as.factor(ros_test_label))
colnames(ros_test)[ncol(ros_test)] <- "Target"
traintask <- makeClassifTask (data = ros_train,target = "Target")
testtask <- makeClassifTask (data = ros_test,target = "Target")
#do one hot encoding
traintask <- createDummyFeatures (obj = traintask)
testtask <- createDummyFeatures (obj = testtask)
Determine tuning strategy
#create learner
lrn <- makeLearner("classif.xgboost",predict.type = "response")
lrn$par.vals <- list(objective = "multi:softprob", "eval_metric" = "mlogloss", "num_class" = length(unique(ros_target[[1]])), nrounds=100L, eta=0.1)
#set parameter space
params <- makeParamSet(makeDiscreteParam("booster", values = c("gbtree","gblinear")),
                       makeIntegerParam("max_depth", lower = 3L, upper = 10L),
                       makeNumericParam("min_child_weight", lower = 1L, upper = 10L),
                       makeNumericParam("subsample", lower = 0.5, upper = 1),
                       makeNumericParam("colsample_bytree", lower = 0.5, upper = 1))
#set resampling strategy
rdesc <- makeResampleDesc("CV",stratify = T,iters=5L)
#search strategy
ctrl <- makeTuneControlRandom(maxit = 10L)
Tune model parameters with cross-validation
#set parallel backend
library(parallel)
library(parallelMap)
parallelStartSocket(cpus = detectCores())
## Starting parallelization in mode=socket with cpus=8.
#parameter tuning
mytune <- tuneParams(learner = lrn, task = traintask, resampling = rdesc, measures = acc, par.set = params, control = ctrl, show.info = T)
## [Tune] Started tuning learner classif.xgboost for parameter set:
## Type len Def Constr Req Tunable Trafo
## booster discrete - - gbtree,gblinear - TRUE -
## max_depth integer - - 3 to 10 - TRUE -
## min_child_weight numeric - - 1 to 10 - TRUE -
## subsample numeric - - 0.5 to 1 - TRUE -
## colsample_bytree numeric - - 0.5 to 1 - TRUE -
## With control class: TuneControlRandom
## Imputation value: -0
## Exporting objects to slaves for mode socket: .mlr.slave.options
## Mapping in parallel: mode = socket; cpus = 8; elements = 10.
## [Tune] Result: booster=gbtree; max_depth=5; min_child_weight=8.64; subsample=0.548; colsample_bytree=0.913 : acc.test.mean=0.5068081
mytune$y
## acc.test.mean
## 0.5068081
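The tuned hyperparameter values passed to the final model below can be inspected directly from the tuning result:
#inspect the hyperparameters selected by random search
mytune$x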
Predict with tuned model
#set hyperparameters
lrn_tune <- setHyperPars(lrn,par.vals = mytune$x)
#train model
xgmodel <- mlr::train(learner = lrn_tune,task = traintask)
## [1]   train-mlogloss:1.344396
## [2]   train-mlogloss:1.310749
## ... (training mlogloss decreased steadily over 100 rounds) ...
## [99]  train-mlogloss:0.803542
## [100] train-mlogloss:0.802099
#predict model
xgpred <- predict(xgmodel,testtask)
confusionMatrix(xgpred$data$response,xgpred$data$truth)
## Confusion Matrix and Statistics
##
## Reference
## Prediction DefiniteAD NL PossibleAD ProbableAD
## DefiniteAD 23 0 1 14
## NL 1 22 6 11
## PossibleAD 0 0 0 0
## ProbableAD 11 10 6 17
##
## Overall Statistics
##
## Accuracy : 0.5082
## 95% CI : (0.4162, 0.5998)
## No Information Rate : 0.3443
## P-Value [Acc > NIR] : 0.0001434
##
## Kappa : 0.2979
## Mcnemar's Test P-Value : 0.0253999
##
## Statistics by Class:
##
## Class: DefiniteAD Class: NL Class: PossibleAD
## Sensitivity 0.6571 0.6875 0.0000
## Specificity 0.8276 0.8000 1.0000
## Pos Pred Value 0.6053 0.5500 NaN
## Neg Pred Value 0.8571 0.8780 0.8934
## Prevalence 0.2869 0.2623 0.1066
## Detection Rate 0.1885 0.1803 0.0000
## Detection Prevalence 0.3115 0.3279 0.0000
## Balanced Accuracy 0.7424 0.7438 0.5000
## Class: ProbableAD
## Sensitivity 0.4048
## Specificity 0.6625
## Pos Pred Value 0.3864
## Neg Pred Value 0.6795
## Prevalence 0.3443
## Detection Rate 0.1393
## Detection Prevalence 0.3607
## Balanced Accuracy 0.5336
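At this point the socket cluster started for tuning is no longer needed; parallelMap's parallelStop() releases it (a small housekeeping step not in the original code):
#shut down the parallel workers started with parallelStartSocket()
parallelStop()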
. The XGBoost model with default settings gave a baseline test accuracy of about 0.48.
. Cross-validation during parameter tuning gave a mean accuracy of about 0.51.
. Compared to the default model, the final tuned XGBoost model slightly improved the test accuracy (0.48 to 0.51).
Diagnosing Alzheimer’s disease is not an easy task. There is no single test that can determine whether a patient has the disease, especially in its early stages. Instead, a variety of factors must be examined, and machine learning could help pick up patterns that would otherwise be easily missed. Spotting the first indications of AD years before obvious symptoms appear could help pinpoint the people most likely to benefit from experimental drugs and allow family members to plan for eventual care. This project proposes detecting AD at early stages using machine learning algorithms. Such a tool could help physicians detect early signs of AD and adjust care accordingly, and could also help pharmaceutical companies identify the patients most likely to benefit from experimental drugs.
In this proposal, I explored the dataset to understand the distribution of AD stages across age and sex groups and identified several clinical traits that are strongly associated with AD. I also built a preliminary machine learning model (XGBoost) to predict AD stages from these clinical traits. Although parameter tuning improved the prediction accuracy, it remained modest (about 51%), likely because of the relatively small sample size, the limited number of features, and the four-class prediction task.
If selected as a Fellow, the future directions of this project would be:
1. Collecting a larger dataset with more clinical features and subjects.
2. Comparing my preliminary XGBoost model to other machine learning algorithms and establishing an optimized model that can be generalized to other diseases (a minimal sketch of such a comparison is shown below).
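A minimal sketch of how the comparison in (2) could be set up within the existing mlr workflow, assuming the randomForest package is installed (the learner and settings are illustrative only, not part of the current analysis):
#benchmark the tuned xgboost learner against a random forest on the same task and resampling scheme
lrn_rf <- makeLearner("classif.randomForest", predict.type = "response")
bmr <- benchmark(learners = list(lrn_tune, lrn_rf), tasks = traintask, resamplings = rdesc, measures = acc)
bmr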