Coronary heart disease is a common condition in which the arteries of the heart cannot deliver enough oxygen-rich blood to the heart muscle. Common symptoms include chest pain and shortness of breath. This project examines the relationship between coronary heart disease and potential risk factors.
The data come from a long-term prospective study. The dataset contains 4240 observations and 16 variables, described below:
sex: the gender of the participant, recorded in the dataset as the binary variable “male”.
age: Age at the time of medical examination in years.
education: A categorical variable for the participant's education, with levels: some high school (1), high school/GED (2), some college/vocational school (3), college (4).
currentSmoker: Current cigarette smoking at the time of examinations
cigsPerDay: Number of cigarettes smoked each day
BPMeds: Use of anti-hypertensive medication at the exam
prevalentStroke: Prevalent Stroke (0 = free of disease)
prevalentHyp: Prevalent Hypertensive. Subject was defined as hypertensive if treated
diabetes: Diabetic according to criteria of first exam treated
totChol: Total cholesterol (mg/dL)
sysBP: Systolic Blood Pressure (mmHg)
diaBP: Diastolic blood pressure (mmHg)
BMI: Body Mass Index, weight (kg)/height (m)^2
heartRate: Heart rate (beats/minute)
glucose: Blood glucose level (mg/dL)
TenYearCHD (response variable): The 10-year risk of coronary heart disease (CHD).
The data are read from the following path: “https://raw.githubusercontent.com/GUANTSERN-KUO/STA551/main/Week1/FraminghamHeartStudy.csv”
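A minimal sketch of how the data can be loaded and its missingness profiled is given below; the object name heatdata matches the one used in the summary output that follows.
heatdata <- read.csv("https://raw.githubusercontent.com/GUANTSERN-KUO/STA551/main/Week1/FraminghamHeartStudy.csv")
dim(heatdata)                                        # expect 4240 rows and 16 columns
round(colSums(is.na(heatdata))/nrow(heatdata), 3)    # missing proportion per variable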
There are 16 variables in the heart data: 7 binary variables, 8 numeric variables, and 1 categorical variable (education). TenYearCHD is the response variable of main interest. As the table below shows, the variable with the most missing values is glucose, with 388 missing observations (approximately 9% missing). Missing values also appear in education, cigsPerDay, BPMeds, totChol, and BMI, each with a missing proportion below 3%.
The totChol variable has one potential outlier (totChol = 696), sysBP has an abnormally high value (sysBP = 295), heartRate has a potential outlier (heartRate = 143), and glucose reaches 394. All of these values should be checked further.
summary(heatdata)
## male age education currentSmoker
## Min. :0.0000 Min. :32.00 Min. :1.000 Min. :0.0000
## 1st Qu.:0.0000 1st Qu.:42.00 1st Qu.:1.000 1st Qu.:0.0000
## Median :0.0000 Median :49.00 Median :2.000 Median :0.0000
## Mean :0.4292 Mean :49.58 Mean :1.979 Mean :0.4941
## 3rd Qu.:1.0000 3rd Qu.:56.00 3rd Qu.:3.000 3rd Qu.:1.0000
## Max. :1.0000 Max. :70.00 Max. :4.000 Max. :1.0000
## NA's :105
## cigsPerDay BPMeds prevalentStroke prevalentHyp
## Min. : 0.000 Min. :0.00000 Min. :0.000000 Min. :0.0000
## 1st Qu.: 0.000 1st Qu.:0.00000 1st Qu.:0.000000 1st Qu.:0.0000
## Median : 0.000 Median :0.00000 Median :0.000000 Median :0.0000
## Mean : 9.006 Mean :0.02962 Mean :0.005896 Mean :0.3106
## 3rd Qu.:20.000 3rd Qu.:0.00000 3rd Qu.:0.000000 3rd Qu.:1.0000
## Max. :70.000 Max. :1.00000 Max. :1.000000 Max. :1.0000
## NA's :29 NA's :53
## diabetes totChol sysBP diaBP
## Min. :0.00000 Min. :107.0 Min. : 83.5 Min. : 48.0
## 1st Qu.:0.00000 1st Qu.:206.0 1st Qu.:117.0 1st Qu.: 75.0
## Median :0.00000 Median :234.0 Median :128.0 Median : 82.0
## Mean :0.02571 Mean :236.7 Mean :132.4 Mean : 82.9
## 3rd Qu.:0.00000 3rd Qu.:263.0 3rd Qu.:144.0 3rd Qu.: 90.0
## Max. :1.00000 Max. :696.0 Max. :295.0 Max. :142.5
## NA's :50
## BMI heartRate glucose TenYearCHD
## Min. :15.54 Min. : 44.00 Min. : 40.00 Min. :0.0000
## 1st Qu.:23.07 1st Qu.: 68.00 1st Qu.: 71.00 1st Qu.:0.0000
## Median :25.40 Median : 75.00 Median : 78.00 Median :0.0000
## Mean :25.80 Mean : 75.88 Mean : 81.96 Mean :0.1519
## 3rd Qu.:28.04 3rd Qu.: 83.00 3rd Qu.: 87.00 3rd Qu.:0.0000
## Max. :56.80 Max. :143.00 Max. :394.00 Max. :1.0000
## NA's :19 NA's :1 NA's :388
## glucose value proportion
prop.table(table(heatdata$glucose, exclude=NULL))
##
## 40 43 44 45 47 48
## 0.0004716981 0.0002358491 0.0004716981 0.0009433962 0.0007075472 0.0002358491
## 50 52 53 54 55 56
## 0.0007075472 0.0004716981 0.0011792453 0.0011792453 0.0030660377 0.0011792453
## 57 58 59 60 61 62
## 0.0049528302 0.0047169811 0.0025943396 0.0148584906 0.0051886792 0.0101415094
## 63 64 65 66 67 68
## 0.0148584906 0.0108490566 0.0200471698 0.0143867925 0.0252358491 0.0219339623
## 69 70 71 72 73 74
## 0.0153301887 0.0358490566 0.0181603774 0.0254716981 0.0367924528 0.0332547170
## 75 76 77 78 79 80
## 0.0455188679 0.0299528302 0.0393867925 0.0349056604 0.0226415094 0.0360849057
## 81 82 83 84 85 86
## 0.0139150943 0.0235849057 0.0356132075 0.0252358491 0.0299528302 0.0146226415
## 87 88 89 90 91 92
## 0.0268867925 0.0174528302 0.0080188679 0.0191037736 0.0068396226 0.0089622642
## 93 94 95 96 97 98
## 0.0141509434 0.0089622642 0.0113207547 0.0063679245 0.0077830189 0.0056603774
## 99 100 101 102 103 104
## 0.0049528302 0.0113207547 0.0009433962 0.0044811321 0.0075471698 0.0033018868
## 105 106 107 108 109 110
## 0.0028301887 0.0023584906 0.0035377358 0.0033018868 0.0007075472 0.0021226415
## 111 112 113 114 115 116
## 0.0004716981 0.0025943396 0.0030660377 0.0011792453 0.0035377358 0.0011792453
## 117 118 119 120 121 122
## 0.0023584906 0.0023584906 0.0004716981 0.0021226415 0.0004716981 0.0007075472
## 123 124 125 126 127 129
## 0.0011792453 0.0004716981 0.0002358491 0.0009433962 0.0009433962 0.0002358491
## 130 131 132 135 136 137
## 0.0004716981 0.0002358491 0.0007075472 0.0004716981 0.0004716981 0.0009433962
## 140 142 143 144 145 147
## 0.0009433962 0.0002358491 0.0002358491 0.0002358491 0.0004716981 0.0002358491
## 148 150 155 156 160 163
## 0.0002358491 0.0004716981 0.0002358491 0.0002358491 0.0002358491 0.0002358491
## 166 167 170 172 173 177
## 0.0002358491 0.0002358491 0.0004716981 0.0002358491 0.0004716981 0.0002358491
## 183 186 191 193 202 205
## 0.0002358491 0.0002358491 0.0002358491 0.0002358491 0.0002358491 0.0002358491
## 206 207 210 215 216 223
## 0.0004716981 0.0002358491 0.0002358491 0.0004716981 0.0002358491 0.0002358491
## 225 235 244 248 250 254
## 0.0002358491 0.0002358491 0.0002358491 0.0002358491 0.0002358491 0.0002358491
## 255 256 260 268 270 274
## 0.0002358491 0.0002358491 0.0002358491 0.0002358491 0.0002358491 0.0002358491
## 292 294 297 320 325 332
## 0.0002358491 0.0002358491 0.0002358491 0.0002358491 0.0002358491 0.0002358491
## 348 368 370 386 394 <NA>
## 0.0002358491 0.0002358491 0.0002358491 0.0002358491 0.0004716981 0.0915094340
Several variables have missing values, as indicated above. Since glucose has the most missing values, our strategy is to impute the glucose variable. For the variables with missing rates below 3%, we do not perform imputation at this time.
To impute the glucose variable with a regression-based approach, a predictor strongly linearly correlated with glucose would be needed. In this data, sysBP is the variable most linearly correlated with glucose, but the correlation is only about 0.14, which is too weak to support a linear imputation. We therefore impute the missing glucose values with the mean glucose value rather than using linear imputation.
## cor
## 0.1405728
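A minimal sketch of this correlation check and the mean imputation, assuming the working data frame is still heatdata and the imputed column is named impute.glucose (the name that appears in the correlation matrix later):
cor(heatdata$glucose, heatdata$sysBP, use = "complete.obs")   # about 0.14
heatdata$impute.glucose <- heatdata$glucose                   # copy the original glucose
heatdata$impute.glucose[is.na(heatdata$impute.glucose)] <- mean(heatdata$glucose, na.rm = TRUE)  # mean-impute the NAs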
The missing glucose values are imputed with the mean glucose value. We now visually compare the distributions of the original and the imputed glucose.
As the density curves below show, the distributions of the imputed and the original glucose values are close to each other.
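A sketch of how the two density curves can be overlaid (same assumed variable names as above):
plot(density(heatdata$glucose, na.rm = TRUE), col = "blue", lwd = 2, main = "Glucose: original vs. imputed")
lines(density(heatdata$impute.glucose), col = "red", lwd = 2, lty = 2)
legend("topright", c("original", "imputed"), col = c("blue", "red"), lty = c(1, 2), lwd = 2, bty = "n")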
After completing the imputation, we keep all of the original variables plus the new imputed variable (impute.glucose), and then delete all remaining records with missing values.
Exploratory Data Analysis (EDA) is an approach for analyzing the data to find trends and patterns and to check assumptions. EDA relies on many tools from classical statistics, including graphics such as histograms and scatter plots.
To support the modeling and output analysis, we re-organize the data and convert the current numeric values to categorical (character) values. There are 12 numeric variables in the heart data that need to be converted into character type. Blood pressure, for example, is grouped according to the following reference: (https://bmcmedicine.biomedcentral.com/articles/10.1186/s12916-022-02407-z).
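As one illustrative example (a sketch; the exact cut points follow the cited reference), systolic blood pressure could be grouped with cut() so that the levels match the grp.sysBP dummy names that appear later. The baseline label "< 110" is an assumption here; "[110,120]" and "> 120" match the dummy-variable names shown below.
heatdata02$grp.sysBP <- cut(heatdata02$sysBP,
                            breaks = c(-Inf, 110, 120, Inf),      # hypothetical cut points at 110 and 120 mmHg
                            labels = c("< 110", "[110,120]", "> 120"))
table(heatdata02$grp.sysBP)   # check the resulting group sizes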
## age cigsPerDay totChol sysBP diaBP
## age 1.000000000 -0.19060031 0.26979872 0.38981443 0.20617145
## cigsPerDay -0.190600311 1.00000000 -0.02640377 -0.08794251 -0.05126527
## totChol 0.269798723 -0.02640377 1.00000000 0.21453561 0.17153720
## sysBP 0.389814426 -0.08794251 0.21453561 1.00000000 0.78551069
## diaBP 0.206171445 -0.05126527 0.17153720 0.78551069 1.00000000
## BMI 0.134672699 -0.08783416 0.12319550 0.32963884 0.38327303
## heartRate -0.008337091 0.06802934 0.08837479 0.18706802 0.18184920
## impute.glucose 0.113684393 -0.05146883 0.04782736 0.12975091 0.06130097
## BMI heartRate impute.glucose
## age 0.13467270 -0.008337091 0.11368439
## cigsPerDay -0.08783416 0.068029339 -0.05146883
## totChol 0.12319550 0.088374792 0.04782736
## sysBP 0.32963884 0.187068018 0.12975091
## diaBP 0.38327303 0.181849202 0.06130097
## BMI 1.00000000 0.073339162 0.07985079
## heartRate 0.07333916 1.000000000 0.09220655
## impute.glucose 0.07985079 0.092206546 1.00000000
As the table above shows, there is a strong positive linear relationship between sysBP and diaBP. Apart from that pair, there are no strong linear relationships among the other variables.
## Write the heatdata02 object to heartdata02.csv so that later sections can read the prepared data directly.
write.csv(x = heatdata02, file = "heartdata02.csv", row.names = FALSE)
rm(list=ls())
Before building neural network models, we need to convert variables to numeric type and add the required dummy variables. In addition, the numeric variables are rescaled to a common, unit-free range.
Below is scaling formula:
\[ scaled.var = \frac{orig.var - \min(orig.var)}{\max(orig.var)-\min(orig.var)} \]
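For example, using the BMI range from the summary above (minimum 15.54, maximum 56.80), a BMI of 25.40 scales to (25.40 - 15.54)/(56.80 - 15.54) ≈ 0.239.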
heartdata02 = read.csv("https://raw.githubusercontent.com/GUANTSERN-KUO/STA551/main/Week1/heartdata02.csv")
## only choose character variable and certain numeric variable(cigsPerDay,totChol,BMI and heartRate ).
heartdata03 <- (heartdata02[, c(
"grp.male", "grp.age", "grp.education", "grp.currentSmoker"
, "cigsPerDay" , "grp.BPMeds" ,"grp.prevalentStroke",
"grp.prevalentHyp" , "grp.diabetes", "totChol","grp.sysBP" , "grp.diaBP" ,"BMI","heartRate" ,"grp.impute.glucose" , "grp.TenYearCHD"
)])
heartdata03$cigsPerDay <- as.numeric(heartdata03$cigsPerDay)
heartdata03$totChol <- as.numeric(heartdata03$totChol)
heartdata03$heartRate <- as.numeric(heartdata03$heartRate)
heartdata03$BMI = (heartdata03$BMI-min(heartdata03$BMI))/(max(heartdata03$BMI)-min(heartdata03$BMI))
heartdata03$cigsPerDay = (heartdata03$cigsPerDay-min(heartdata03$cigsPerDay))/(max(heartdata03$cigsPerDay)-min(heartdata03$cigsPerDay))
heartdata03$totChol = (heartdata03$totChol-min(heartdata03$totChol))/(max(heartdata03$totChol)-min(heartdata03$totChol))
heartdata03$heartRate = (heartdata03$heartRate-min(heartdata03$heartRate))/(max(heartdata03$heartRate)-min(heartdata03$heartRate))
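The four scaling statements above could also be written with a small helper function; this is a sketch of an equivalent, less repetitive form, not a change in behavior.
minmax <- function(x) (x - min(x)) / (max(x) - min(x))          # min-max scaling helper
num.vars <- c("cigsPerDay", "totChol", "BMI", "heartRate")      # the four numeric columns kept
heartdata03[num.vars] <- lapply(heartdata03[num.vars], minmax)  # apply the scaling to each column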
We use the R function model.matrix() to build the design matrix with the required dummy variables; it automatically expands factors into sets of dummy variables (and would expand interactions similarly). We then rename the columns to avoid naming issues (spaces and brackets) in the neural network model.
heartdataMtx = model.matrix(~ ., data = heartdata03)
colnames(heartdataMtx)
## [1] "(Intercept)"
## [2] "grp.maleYES"
## [3] "grp.age[45, 64]"
## [4] "grp.age65+"
## [5] "grp.educationhigh school/GED"
## [6] "grp.educationsome college/vocational school"
## [7] "grp.educationSome high school"
## [8] "grp.currentSmokerYES"
## [9] "cigsPerDay"
## [10] "grp.BPMedsYES"
## [11] "grp.prevalentStrokeYES"
## [12] "grp.prevalentHypYES"
## [13] "grp.diabetesYES"
## [14] "totChol"
## [15] "grp.sysBP[110,120]"
## [16] "grp.sysBP> 120"
## [17] "grp.diaBP[80,90]"
## [18] "grp.diaBP> 90"
## [19] "BMI"
## [20] "heartRate"
## [21] "grp.impute.glucose[117,137]"
## [22] "grp.impute.glucose> 137"
## [23] "grp.TenYearCHDYES"
colnames(heartdataMtx)[2:23] <- c(
"maleYES", "age45To64", "age65older", "highschoolGED", "college",
"highschool", "SmokerYES", "cigsPerDay", "BPMedsYES", "StrokeYES",
"HypYES", "diabetesYES", "totChol", "sysBP110To120", "sysBPgt120",
"diaBP80To90", "diaBPgt90", "BMI", "heartRate", "glucose117To137",
"glucosegt137", "TenYearCHDYES")
The model formula is constructed below.
columnNames = colnames(heartdataMtx)
#remove intercept
columnList = paste(columnNames[-c(1,length(columnNames))], collapse = "+")
#TenYearCHDYES for response variable
columnList = paste(c(columnNames[length(columnNames)],"~",columnList), collapse="")
modelFormula = formula(columnList)
modelFormula
## TenYearCHDYES ~ maleYES + age45To64 + age65older + highschoolGED +
## college + highschool + SmokerYES + cigsPerDay + BPMedsYES +
## StrokeYES + HypYES + diabetesYES + totChol + sysBP110To120 +
## sysBPgt120 + diaBP80To90 + diaBPgt90 + BMI + heartRate +
## glucose117To137 + glucosegt137
We randomly split the data: 60% for training (trainDat) and 40% for testing (testDat). We then build the NN model with the neuralnet() function.
n = dim(heartdataMtx)[1]
testID = sample(1:n, round(n*0.6), replace = FALSE) # indices of the 60% of rows used for training
testDat = heartdataMtx[-testID,]  # remaining 40%: testing set
trainDat = heartdataMtx[testID,]  # 60%: training set
NetworkModel = neuralnet(modelFormula,
data = trainDat,
hidden = 1, # single layer NN
rep = 1, # number of replicates in training NN
threshold = 0.02, # threshold for the partial derivatives as stopping criteria.
learningrate = 0.2, # user selected rate
algorithm = "rprop+"
)
kable(NetworkModel$result.matrix)
|                             |             |
|:----------------------------|------------:|
| error | 130.5869053 |
| reached.threshold | 0.0195198 |
| steps | 9074.0000000 |
| Intercept.to.1layhid1 | 4.6132312 |
| maleYES.to.1layhid1 | -0.3733819 |
| age45To64.to.1layhid1 | -2.1942890 |
| age65older.to.1layhid1 | -3.0176613 |
| highschoolGED.to.1layhid1 | 0.5938626 |
| college.to.1layhid1 | 0.4876914 |
| highschool.to.1layhid1 | 0.2826368 |
| SmokerYES.to.1layhid1 | 0.1645475 |
| cigsPerDay.to.1layhid1 | -1.8383077 |
| BPMedsYES.to.1layhid1 | -0.6802983 |
| StrokeYES.to.1layhid1 | -0.7181793 |
| HypYES.to.1layhid1 | -0.9322752 |
| diabetesYES.to.1layhid1 | 0.6886669 |
| totChol.to.1layhid1 | -0.7863313 |
| sysBP110To120.to.1layhid1 | 0.3609910 |
| sysBPgt120.to.1layhid1 | -0.5093815 |
| diaBP80To90.to.1layhid1 | 0.6299655 |
| diaBPgt90.to.1layhid1 | 0.1619388 |
| BMI.to.1layhid1 | 1.1456345 |
| heartRate.to.1layhid1 | -0.0014930 |
| glucose117To137.to.1layhid1 | -0.8040258 |
| glucosegt137.to.1layhid1 | -1.9471163 |
| Intercept.to.TenYearCHDYES | 1.1279462 |
| 1layhid1.to.TenYearCHDYES | -1.0888269 |
plot(NetworkModel, rep="best")
Figure Single-layer backpropagation neural network model for heart disease.
Cross-validation is a method for evaluating a machine learning model and testing its performance. The training set is split into 5 smaller sets; in turn, each of the 5 sets serves as validation data while the remaining 4 serve as training data.
In each fold, a model is trained on train.data and used to generate predicted scores (pred.nn.score) for valid.data. Twenty candidate cut-off points between 0 and 1 are then compared with the predicted scores: if a predicted score is larger than the cut-off point, the predicted status (pred.status) is set to 1. The predicted status is compared with TenYearCHDYES in the validation data (valid.data) to compute the prediction accuracy.
As the figure output (5-fold CV performance) shows, the highest accuracy occurs at a cut-off score of 0.57.
n0 = dim(trainDat)[1]/5
cut.off.score = seq(0,1, length = 22)[-c(1,22)] # candidate cut off prob
pred.accuracy = matrix(0,ncol=20, nrow=5, byrow = T)
# matrix for storing prediction accuracy (5 folds by 20 cut-offs)
##
for (i in 1:5){
valid.id = ((i-1)*n0 + 1):(i*n0)
valid.data = trainDat[valid.id,]
train.data = trainDat[-valid.id,]
####
train.model = neuralnet(modelFormula,
data = train.data,
hidden = 1, # single layer NN
rep = 1,
# number of replicates in training NN
threshold = 0.02,
# threshold for the partial derivatives as stopping criteria.
learningrate = 0.2, # user selected rate
algorithm = "rprop+"
)
pred.nn.score = predict(train.model, valid.data)
for(j in 1:20){
#pred.status = rep(0,length(pred.nn.score))
pred.status = as.numeric(pred.nn.score > cut.off.score[j])
a11 = sum(pred.status == valid.data[,23])
pred.accuracy[i,j] = a11/length(pred.nn.score)
}
}
###
avg.accuracy = apply(pred.accuracy, 2, mean)
max.id = which(avg.accuracy ==max(avg.accuracy ))
### visual representation
tick.label = as.character(round(cut.off.score,2))
plot(1:20, avg.accuracy, type = "b",
xlim=c(1,20),
ylim=c(0.5,1),
axes = FALSE,
xlab = "Cut-off Score",
ylab = "Accuracy",
main = "5-fold CV performance"
)
axis(1, at=1:20, label = tick.label, las = 2)
axis(2)
segments(max.id, 0.5, max.id, avg.accuracy[max.id], col = "red")
text(max.id, avg.accuracy[max.id]+0.03, as.character(round(avg.accuracy[max.id],4)), col = "red", cex = 0.8)
The fitted NetworkModel and the testing data (testDat) are now used with the optimal cut-off score of 0.57 to compute the test accuracy, which is about 0.843 as shown below.
#Test the resulting output
nn.results <- predict(NetworkModel, testDat)
results <- data.frame(actual = testDat[,23], prediction = nn.results > .57)
confMatrix = table(results$prediction, results$actual) # confusion matrix
accuracy=sum(results$actual == results$prediction)/length(results$prediction)
list(confusion.matrix = confMatrix, accuracy = accuracy)
## $confusion.matrix
##
## 0 1
## FALSE 1335 242
## TRUE 9 10
##
## $accuracy
## [1] 0.8427318
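As a check, the accuracy can be read directly from the confusion matrix: (1335 + 10)/(1335 + 242 + 9 + 10) = 1345/1596 ≈ 0.8427.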
The more the ROC curve hugs the top-left corner of the plot, the better the model is at classifying the data into categories.
Sensitivity (true positive rate, recall) is the probability of a positive test result among those who actually have a 10-year risk of coronary heart disease.
Specificity (true negative rate) is the probability of a negative test result among those who do not actually have a 10-year risk of coronary heart disease. In the ROC curve output below, the x-axis is 1 - specificity and the y-axis is sensitivity.
The fitted NetworkModel and the training data (trainDat) are used to generate the predicted scores.
The quantities used in the code below are defined as follows:
Definition of “a”: the subject actually has the 10-year risk of CHD and the prediction agrees (true positive).
Definition of “d”: the subject does not have the 10-year risk of CHD and the prediction agrees (true negative).
Definition of “b”: the subject does not have the 10-year risk of CHD but the prediction says otherwise (false positive).
Definition of “c”: the subject has the 10-year risk of CHD but the prediction says otherwise (false negative).
These quantities yield the sensitivity and specificity formulas shown below.
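In formula form, matching the computations in the code below:
\[ Sensitivity = \frac{a}{a + c}, \qquad Specificity = \frac{d}{b + d} \]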
nn.results = predict(NetworkModel, trainDat) # Keep in mind that trainDat is a matrix!
cut0 = seq(0,1, length = 20)
SenSpe = matrix(0, ncol = length(cut0), nrow = 2, byrow = FALSE)
for (i in 1:length(cut0)){
a = sum(trainDat[,"TenYearCHDYES"] == 1 & (nn.results > cut0[i]))
d = sum(trainDat[,"TenYearCHDYES"] == 0 & (nn.results < cut0[i]))
b = sum(trainDat[,"TenYearCHDYES"] == 0 & (nn.results > cut0[i]))
c = sum(trainDat[,"TenYearCHDYES"] == 1 & (nn.results < cut0[i]))
sen = a/(a + c)
spe = d/(b + d)
SenSpe[,i] = c(sen, spe)
}
# plotting ROC
plot(1-SenSpe[2,], SenSpe[1,], type ="b", xlim=c(0,1), ylim=c(0,1),
xlab = "1 - specificity", ylab = "Sensitivity", lty = 1,
main = "ROC Curve", col = "blue")
abline(0,1, lty = 2, col = "red")
## Calculate AUC
xx = 1-SenSpe[2,]
yy = SenSpe[1,]
width = xx[-length(xx)] - xx[-1]
height = yy[-1]
## A better approx of ROC, need library {pROC}
prediction = as.vector(nn.results)
category = trainDat[,"TenYearCHDYES"] == 1
ROCobj <- roc(category, prediction)
AUC = auc(ROCobj)[1]
##
###
text(0.8, 0.3, paste("AUC = ", round(AUC,4)), col = "purple", cex = 0.9)
legend("bottomright", c("ROC of the model", "Random guessing"), lty=c(1,2),
col = c("blue", "red"), bty = "n", cex = 0.8)
Figure ROC Curve of the neural network model.
NN.model <- list(SenSpe= SenSpe, AUC = round(AUC,3))
As the ROC curve above shows, the neural network clearly outperforms random guessing, since its area under the curve is larger than 0.5; the closer the AUC is to 1, the better the model. In general, if the area under the ROC curve is greater than 0.65, the predictive power of the underlying model is considered acceptable.
knitr:: include_graphics("tree.jpg")
Decision tree algorithms are easy to interpret and to implement in the real world. They can be used for both regression and classification tasks, and a tree consists of a root node, branches, internal nodes, and leaf nodes. This project tests the relationship between the 10-year risk of coronary heart disease (CHD) and potential factors (e.g., age, education, etc.). The outcome is recorded as either Yes or No, so a Classification and Regression Tree (CART) classification model is the appropriate method for predicting the result.
The following diagram indicates the basic structure of a decision tree.
knitr:: include_graphics("node.png")
Growing a decision tree is an iterative process of splitting the feature space into sub-regions based on splitting criteria defined on the feature variables.
The two most frequently used impurity measures in decision tree induction are the Gini index and entropy. This project demonstrates decision tree models built with each.
The Gini index measures node impurity and determines how the features of a dataset should split nodes to form the tree.
Entropy measures the average amount of information required to classify an example at a node.
Information gain is the reduction in impurity (uncertainty) attained by splitting on a feature. Their standard definitions are given below.
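For a node where p_k denotes the proportion of class k, the standard textbook definitions (not taken from the report's output) are:
\[ Gini = 1 - \sum_{k} p_k^{2}, \qquad Entropy = -\sum_{k} p_k \log_2 p_k, \qquad Gain = Impurity_{parent} - \sum_{j} \frac{n_j}{n}\, Impurity_{child_j} \]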
heartdata02 = read.csv("https://raw.githubusercontent.com/GUANTSERN-KUO/STA551/main/Week1/heartdata02.csv")
## only choose character variable and certain numeric variable(cigsPerDay,totChol,BMI and heartRate ).
## Same as section 5 but no Numeric Feature Scaling
heartdata04 <- (heartdata02[, c(
"grp.male", "grp.age", "grp.education", "grp.currentSmoker"
, "cigsPerDay" , "grp.BPMeds" ,"grp.prevalentStroke",
"grp.prevalentHyp" , "grp.diabetes", "totChol","grp.sysBP" , "grp.diaBP" ,"BMI","heartRate" ,"grp.impute.glucose" , "grp.TenYearCHD"
)])
# We use a random split approach
n = dim(heartdata04)[1] # sample size
# caution: using without replacement
train.id = sample(1:n, round(0.7*n), replace = FALSE)
train = heartdata04[train.id, ] # training data
test = heartdata04[-train.id, ] # testing data
# arguments to pass into rpart():
# 1. data set (training /testing);
# 2. Penalty coefficients
# 3. Impurity measure
##
tree.builder = function(in.data, fp, fn, purity){
tree = rpart(grp.TenYearCHD ~ ., # including all features
data = in.data,
na.action = na.rpart, # By default, deleted if the outcome is missing,
# kept if predictors are missing
method = "class", # Classification form factor
model = FALSE,
x = FALSE,
y = TRUE,
parms = list
( # loss matrix. Penalize false positives or negatives more heavily
loss = matrix(c(0, fp, fn, 0), ncol = 2, byrow = TRUE),
split = purity), # "gini" or "information"
## rpart algorithm options (These are defaults)
control = rpart.control(
minsplit = 10, # minimum number of observations required before split
minbucket= 10, # minimum number of observations in any terminal node, default = minsplit/3
cp = 0.01, # complexity parameter for the stopping rule; larger values give smaller trees
xval = 10 # number of cross-validations
)
)
tree # return the fitted tree explicitly
}
Four decision tree models are defined using the above function.
Model 1: gini.tree.1.1 uses the Gini index without penalizing false positives or false negatives.
Model 2: info.tree.1.1 uses entropy without penalizing false positives or false negatives.
Model 3: gini.tree.1.10 uses the Gini index with the cost of a false negative set to 10 times that of a false positive.
Model 4: info.tree.1.10 uses entropy with the cost of a false negative set to 10 times that of a false positive.
## Call the tree model wrapper.
gini.tree.1.1 = tree.builder(in.data = train, fp = 1, fn = 1, purity = "gini")
info.tree.1.1 = tree.builder(in.data = train, fp = 1, fn = 1, purity = "information")
gini.tree.1.10 = tree.builder(in.data = train, fp = 1, fn = 10, purity = "gini")
info.tree.1.10 = tree.builder(in.data = train, fp = 1, fn = 10, purity = "information")
## tree plots
par(mfrow=c(1,2))
rpart.plot(gini.tree.1.1, main = "Tree with Gini index: non-penalization")
## Warning: Cannot retrieve the data used to build the model (so cannot determine roundint and is.binary for the variables).
## To silence this warning:
## Call rpart.plot with roundint=FALSE,
## or rebuild the rpart model with model=TRUE.
rpart.plot(info.tree.1.1, main = "Tree with entropy: non-penalization")
## Warning: Cannot retrieve the data used to build the model (so cannot determine roundint and is.binary for the variables).
## To silence this warning:
## Call rpart.plot with roundint=FALSE,
## or rebuild the rpart model with model=TRUE.
Figure Non-penalized decision tree models using Gini index (left) and entropy (right).
par(mfrow=c(1,2))
rpart.plot(gini.tree.1.10, main = "Tree with Gini index: penalization")
## Warning: Cannot retrieve the data used to build the model (so cannot determine roundint and is.binary for the variables).
## To silence this warning:
## Call rpart.plot with roundint=FALSE,
## or rebuild the rpart model with model=TRUE.
rpart.plot(info.tree.1.10, main = "Tree with entropy: penalization")
## Warning: Cannot retrieve the data used to build the model (so cannot determine roundint and is.binary for the variables).
## To silence this warning:
## Call rpart.plot with roundint=FALSE,
## or rebuild the rpart model with model=TRUE.
Figure Penalized decision tree models using Gini index (left) and entropy (right).
ROC analysis is used to select the best of the four decision tree models created above.
# function returning a sensitivity and specificity matrix
SenSpe = function(in.data, fp, fn, purity){
cutoff = seq(0,1, length = 20) # 20 cut-offs including 0 and 1.
model = tree.builder(in.data, fp, fn, purity)
## Caution: decision tree returns both "success" and "failure" probabilities.
## We need only "success" probability to define sensitivity and specificity!!!
pred = predict(model, newdata = in.data, type = "prob") # two-column matrix.
senspe.mtx = matrix(0, ncol = length(cutoff), nrow= 2, byrow = FALSE)
for (i in 1:length(cutoff)){
pred.out = ifelse(pred[,"YES"] >= cutoff[i], "YES", "NO")
TP = sum(pred.out =="YES" & in.data$grp.TenYearCHD == "YES")
TN = sum(pred.out =="NO" & in.data$grp.TenYearCHD == "NO")
FP = sum(pred.out =="YES" & in.data$grp.TenYearCHD == "NO")
FN = sum(pred.out =="NO" & in.data$grp.TenYearCHD == "YES")
senspe.mtx[1,i] = TP/(TP + FN)
senspe.mtx[2,i] = TN/(TN + FP)
}
## A better approx of ROC, need library {pROC}
prediction = pred[, "YES"]
category = in.data$grp.TenYearCHD == "YES"
ROCobj <- roc(category, prediction)
AUC = auc(ROCobj)
##
list(senspe.mtx= senspe.mtx, AUC = round(AUC,3))
}
Next, we use the above function (SenSpe) to evaluate the four tree models.
giniROC11 = SenSpe(in.data = train, fp=1, fn=1, purity="gini")
infoROC11 = SenSpe(in.data = train, fp=1, fn=1, purity="information")
giniROC110 = SenSpe(in.data = train, fp=1, fn=10, purity="gini")
infoROC110 = SenSpe(in.data = train, fp=1, fn=10, purity="information")
Next, we draw the ROC curves and calculate the areas under the ROC curves.
par(pty="s") # set up square plot through graphic parameter
plot(1-giniROC11$senspe.mtx[2,], giniROC11$senspe.mtx[1,], type = "l", xlim=c(0,1), ylim=c(0,1),
xlab="1 - specificity", ylab="Sensitivity", col = "blue", lwd = 2,
main="ROC Curves of Decision Trees", cex.main = 0.9, col.main = "navy")
abline(0,1, lty = 2, col = "orchid4", lwd = 2)
lines(1-infoROC11$senspe.mtx[2,], infoROC11$senspe.mtx[1,], col = "firebrick2", lwd = 2, lty=2)
lines(1-giniROC110$senspe.mtx[2,], giniROC110$senspe.mtx[1,], col = "olivedrab", lwd = 2)
lines(1-infoROC110$senspe.mtx[2,], infoROC110$senspe.mtx[1,], col = "skyblue", lwd = 2)
legend("bottomright", c(paste("gini.1.1, AUC =", giniROC11$AUC),
paste("info.1.1, AUC =",infoROC11$AUC),
paste("gini.1.10, AUC =",giniROC110$AUC),
paste("info.1.10, AUC =",infoROC110$AUC)),
col=c("blue","firebrick2","olivedrab","skyblue"),
lty=c(1,2,1,1), lwd=rep(2,4), cex = 0.5, bty = "n")
Figure Comparison of ROC curves
The ROC curves above are based on the four tree models. The gini.1.10 model has the highest AUC and is therefore the best of the four.
With the final model (gini.1.10) determined, we now find the optimal cut-off score for reporting the predictive performance of the final model on the test data.
Optm.cutoff = function(in.data, fp, fn, purity){
n0 = dim(in.data)[1]/5
cutoff = seq(0,1, length = 20) # candidate cut off prob
## accuracy for each candidate cut-off
accuracy.mtx = matrix(0, ncol=20, nrow=5) # 5 folds by 20 candidate cut-offs
##
for (k in 1:5){
valid.id = ((k-1)*n0 + 1):(k*n0)
valid.dat = in.data[valid.id,]
train.dat = in.data[-valid.id,]
## tree model
tree.model = tree.builder(train.dat, fp, fn, purity) # fit on the training fold only
## prediction
pred = predict(tree.model, newdata = valid.dat, type = "prob")[,2]
## for-loop
for (i in 1:20){
## predicted probabilities
pc.1 = ifelse(pred > cutoff[i], "YES", "NO")
## accuracy
a1 = mean(pc.1 == valid.dat$grp.TenYearCHD)
accuracy.mtx[k,i] = a1
}
}
avg.acc = apply(accuracy.mtx, 2, mean)
## plots
n = length(avg.acc)
idx = which(avg.acc == max(avg.acc))
tick.label = as.character(round(cutoff,2))
##
plot(1:n, avg.acc, xlab="cut-off score", ylab="average accuracy",
ylim=c(min(avg.acc), 1),
axes = FALSE,
main=paste("5-fold CV optimal cut-off \n ",purity,"(fp, fn) = (", fp, ",", fn,")" , collapse = ""),
cex.main = 0.9,
col.main = "navy")
axis(1, at=1:20, label = tick.label, las = 2)
axis(2)
points(idx, avg.acc[idx], pch=19, col = "red")
segments(idx , min(avg.acc), idx , avg.acc[idx ], col = "red")
text(idx, avg.acc[idx]+0.03, as.character(round(avg.acc[idx],4)), col = "red", cex = 0.8)
}
The figure below indicates the optimal cut-off score for each of the four decision trees:
gini.1.1: 0.53, gini.1.10: 0.68, info.1.1: 0.58, info.1.10: 0.68.
par(mfrow=c(3,2))
Optm.cutoff(in.data = train, fp=1, fn=1, purity="gini")
Optm.cutoff(in.data = train, fp=1, fn=1, purity="information")
Optm.cutoff(in.data = train, fp=1, fn=10, purity="gini")
Optm.cutoff(in.data = train, fp=1, fn=10, purity="information")
Figure: Plot of optimal cut-off determination
The neural network model and the gini.1.10 model (the best decision tree model) are compared in the figure below. The NN (neural network) model is the better of the two, as it has the higher AUC.
par(pty="s") # set up square plot through graphic parameter
plot(1-giniROC110$senspe.mtx[2,], giniROC110$senspe.mtx[1,], type = "l", xlim=c(0,1), ylim=c(0,1),
xlab="1 - specificity", ylab="Sensitivity", col = "blue", lwd = 2,
main="ROC Curves based on different model", cex.main = 0.9, col.main = "navy")
abline(0,1, lty = 2, col = "orchid4", lwd = 2)
lines(1-NN.model$SenSpe[2,], NN.model$SenSpe[1,], col = "red", lwd = 2, lty=2)
legend("bottomright", c(paste("gini.1.10, AUC =",giniROC110$AUC),
paste("NN.model, AUC =",NN.model$AUC)),
col=c("blue","red"),
lty=c(1,2), lwd=rep(2,2), cex = 0.5, bty = "n")
Figure Comparison of ROC curves: best decision tree model vs. neural network model.