title: "Wine Quality Prediction"
author: "Xinyue_XU"
date: "2023-02-18"
output:
  html_document: default
  pdf_document: default
Load the dataset
# Remove all objects currently stored in the workspace. This is useful when you want to start with a fresh workspace or free up memory.
rm(list=ls())
wine <- read.csv('winequality-red.csv', sep=';', header = TRUE)
Explore the dataset
# The dataset contains 1599 observations of 12 variables. The dim() function is a quick way to determine the size of a data frame.
dim(wine)
## [1] 1599 12
library(psych)
library('dplyr')
##
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
##
## filter, lag
## The following objects are masked from 'package:base':
##
## intersect, setdiff, setequal, union
# The describe() function shows summary statistics for each variable; here we keep the variable index, the number of non-missing observations, the mean, the minimum, and the maximum.
describe(wine) %>% select(c(1:3, 8:9))
## vars n mean min max
## fixed.acidity 1 1599 8.32 4.60 15.90
## volatile.acidity 2 1599 0.53 0.12 1.58
## citric.acid 3 1599 0.27 0.00 1.00
## residual.sugar 4 1599 2.54 0.90 15.50
## chlorides 5 1599 0.09 0.01 0.61
## free.sulfur.dioxide 6 1599 15.87 1.00 72.00
## total.sulfur.dioxide 7 1599 46.47 6.00 289.00
## density 8 1599 1.00 0.99 1.00
## pH 9 1599 3.31 2.74 4.01
## sulphates 10 1599 0.66 0.33 2.00
## alcohol 11 1599 10.42 8.40 14.90
## quality 12 1599 5.64 3.00 8.00
# The head() function displays the first six rows of the data frame, a quick way to preview its structure and content.
head(wine)
## fixed.acidity volatile.acidity citric.acid residual.sugar chlorides
## 1 7.4 0.70 0.00 1.9 0.076
## 2 7.8 0.88 0.00 2.6 0.098
## 3 7.8 0.76 0.04 2.3 0.092
## 4 11.2 0.28 0.56 1.9 0.075
## 5 7.4 0.70 0.00 1.9 0.076
## 6 7.4 0.66 0.00 1.8 0.075
## free.sulfur.dioxide total.sulfur.dioxide density pH sulphates alcohol
## 1 11 34 0.9978 3.51 0.56 9.4
## 2 25 67 0.9968 3.20 0.68 9.8
## 3 15 54 0.9970 3.26 0.65 9.8
## 4 17 60 0.9980 3.16 0.58 9.8
## 5 11 34 0.9978 3.51 0.56 9.4
## 6 13 40 0.9978 3.51 0.56 9.4
## quality
## 1 5
## 2 5
## 3 5
## 4 6
## 5 5
## 6 5
# The str() function shows the data type of each variable.
str(wine)
## 'data.frame': 1599 obs. of 12 variables:
## $ fixed.acidity : num 7.4 7.8 7.8 11.2 7.4 7.4 7.9 7.3 7.8 7.5 ...
## $ volatile.acidity : num 0.7 0.88 0.76 0.28 0.7 0.66 0.6 0.65 0.58 0.5 ...
## $ citric.acid : num 0 0 0.04 0.56 0 0 0.06 0 0.02 0.36 ...
## $ residual.sugar : num 1.9 2.6 2.3 1.9 1.9 1.8 1.6 1.2 2 6.1 ...
## $ chlorides : num 0.076 0.098 0.092 0.075 0.076 0.075 0.069 0.065 0.073 0.071 ...
## $ free.sulfur.dioxide : num 11 25 15 17 11 13 15 15 9 17 ...
## $ total.sulfur.dioxide: num 34 67 54 60 34 40 59 21 18 102 ...
## $ density : num 0.998 0.997 0.997 0.998 0.998 ...
## $ pH : num 3.51 3.2 3.26 3.16 3.51 3.51 3.3 3.39 3.36 3.35 ...
## $ sulphates : num 0.56 0.68 0.65 0.58 0.56 0.56 0.46 0.47 0.57 0.8 ...
## $ alcohol : num 9.4 9.8 9.8 9.8 9.4 9.4 9.4 10 9.5 10.5 ...
## $ quality : int 5 5 5 6 5 5 5 7 7 5 ...
Clean and transform the dataset
# Convert the quality variable into a binary target variable with high quality as 1 and low quality as 0
# Wines with quality scores from 7 to 8 are considered high quality and assigned a value of 1. Wines with quality scores from 3 to 6 are considered low quality and assigned a value of 0
wine$quality <- ifelse(wine$quality >= 7, 1, 0)
# Check the distribution of the target variable to see whether the dataset is balanced
table(wine$quality)
##
## 0 1
## 1382 217
# The table() function shows that there are only 217 high-quality wines compared to 1382 low-quality wines, indicating that the dataset is imbalanced.
barplot(prop.table(table(wine$quality)),
col = rainbow(2),
ylim = c(0, 1),
main = "Class Distribution")
# The plot shows that about 86% of the observations fall in one class and the remaining 14% in the other
# Change the quality variable from a numeric variable to a factor
wine$quality<- as.factor(wine$quality)
class(wine$quality)
## [1] "factor"
Balance the dataset
# There are several techniques to address class imbalance: undersampling the majority class or oversampling the minority class. These techniques create a more balanced training set for the predictive model, which can help prevent the model from being biased towards the majority class and improve its predictive accuracy. One possible approach is sketched below, although the rest of this analysis continues with the original data.
# Logistic regression is generally more sensitive to class imbalance than decision trees or random forests.
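A minimal sketch of random oversampling with base R follows; the 1:1 target ratio is an assumption, and in practice the resampling would be applied only to the training set after the split so that the test set keeps its original class distribution.
set.seed(9)
wine_minority <- wine[wine$quality == "1", ]   # 217 high-quality wines
wine_majority <- wine[wine$quality == "0", ]   # 1382 low-quality wines
# sample the minority class with replacement until the two classes are the same size
wine_oversampled <- rbind(wine_majority,
                          wine_minority[sample(nrow(wine_minority), nrow(wine_majority), replace = TRUE), ])
table(wine_oversampled$quality)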
Check for correlations between independent variables
# High correlations between predictors lead to the following issues:
# Multicollinearity: it becomes difficult to determine the effect of each individual variable on the outcome.
# Redundancy: the model can overfit and become more complex than necessary, making it harder to interpret and less accurate.
# Highly correlated variables should therefore be removed or combined (feature engineering) before building a predictive model.
wine_correlation = subset(wine, select = -c(quality))
library(corrplot)
## Warning: package 'corrplot' was built under R version 4.2.2
## corrplot 0.92 loaded
col <- colorRampPalette(c("red", "white", "blue"))
corrplot(cor(wine_correlation), tl.cex=.7, order="hclust", col=col(50))
# Check the exact correlation coefficient between pH and fixed.acidity, which appear highly correlated on the correlation heatmap
cor(wine$pH, wine$fixed.acidity)
## [1] -0.6829782
# A correlation coefficient of -0.68 between pH and fixed.acidity indicates a moderately strong negative correlation between the two variables. It may be worth testing whether including both variables in the predictive model leads to issues, for example with the variance inflation factor check sketched below.
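One quick check, using only base R, is the variance inflation factor (VIF) for pH, computed from the R-squared of regressing pH on the other chemical variables; the rule of thumb that VIFs above roughly 5-10 signal problematic multicollinearity is an assumption carried over from standard practice rather than from this report.
# VIF for pH: 1 / (1 - R^2), where R^2 comes from regressing pH on the other predictors
r2_pH <- summary(lm(pH ~ ., data = wine_correlation))$r.squared
1 / (1 - r2_pH)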
Split the dataset into training data and test data
set.seed(9)
sample <- sample(c(TRUE, FALSE), nrow(wine), replace=TRUE, prob=c(0.7,0.3))
train <- wine[sample, ]
test <- wine[!sample, ]
Build a logistic regression model
# Fit a logistic regression model with all variables
library(MASS)
##
## Attaching package: 'MASS'
## The following object is masked from 'package:dplyr':
##
## select
glm <- glm(quality ~ ., family = 'binomial', data = train)
summary(glm)
##
## Call:
## glm(formula = quality ~ ., family = "binomial", data = train)
##
## Deviance Residuals:
## Min 1Q Median 3Q Max
## -2.0341 -0.4241 -0.2110 -0.1156 2.9230
##
## Coefficients:
## Estimate Std. Error z value Pr(>|z|)
## (Intercept) 4.177e+02 1.278e+02 3.268 0.001081 **
## fixed.acidity 4.189e-01 1.480e-01 2.830 0.004654 **
## volatile.acidity -2.621e+00 9.472e-01 -2.767 0.005650 **
## citric.acid 5.761e-01 9.831e-01 0.586 0.557897
## residual.sugar 3.726e-01 9.042e-02 4.121 3.78e-05 ***
## chlorides -6.276e+00 3.609e+00 -1.739 0.082034 .
## free.sulfur.dioxide 1.259e-04 1.527e-02 0.008 0.993418
## total.sulfur.dioxide -1.590e-02 5.871e-03 -2.709 0.006753 **
## density -4.370e+02 1.307e+02 -3.344 0.000825 ***
## pH 1.220e+00 1.193e+00 1.023 0.306340
## sulphates 4.066e+00 6.328e-01 6.425 1.32e-10 ***
## alcohol 6.195e-01 1.517e-01 4.085 4.41e-05 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## (Dispersion parameter for binomial family taken to be 1)
##
## Null deviance: 914.76 on 1118 degrees of freedom
## Residual deviance: 607.74 on 1107 degrees of freedom
## AIC: 631.74
##
## Number of Fisher Scoring iterations: 6
# Backward selection: remove the least significant variable one step at a time until all remaining features are statistically significant with p-value <= 0.05.
# The reason for removing the least significant variable one step at a time is to avoid dropping a variable that is significant only in combination with another variable, which could result in an underfitted model with poor predictive power.
# Overfitting occurs when a model is too complex and captures the noise in the data instead of the underlying pattern. Backward selection helps to avoid overfitting by removing features that do not contribute significantly to the model, reducing its complexity - the goal is a simpler model that is not overfitted to the data. An automated variant is sketched below.
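A minimal sketch of automating this with stepAIC() from the MASS package loaded above; note the approximation that stepAIC() eliminates variables by AIC rather than by p-value, so it may retain a slightly different set of predictors than the manual procedure described here.
# Automated backward elimination by AIC (an approximation of p-value-based selection)
glm_backward <- stepAIC(glm, direction = "backward", trace = FALSE)
summary(glm_backward)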
test_yhat_glm <- predict(glm, test, type='response')
# The confusion table for the logistic regression model with a 50% cutoff on the test set
# Calculate the true positive rate (TPR) and true negative rate (TNR)
table(test$quality, test_yhat_glm>0.5)
##
## FALSE TRUE
## 0 401 21
## 1 43 15
TPR <- function(y, yhat) {sum(y == 1 & yhat == 1) / sum(y == 1)}
TNR <- function(y, yhat) {sum(y == 0 & yhat == 0) / sum(y == 0)}
TPR(test$quality, (test_yhat_glm>0.5))
## [1] 0.2586207
TNR(test$quality, (test_yhat_glm>0.5))
## [1] 0.950237
library('pROC')
## Type 'citation("pROC")' for a citation.
##
## Attaching package: 'pROC'
## The following objects are masked from 'package:stats':
##
## cov, smooth, var
glm.roc <- roc(test$quality, test_yhat_glm, direction = '<')
## Setting levels: control = 0, case = 1
glm.roc
##
## Call:
## roc.default(response = test$quality, predictor = test_yhat_glm, direction = "<")
##
## Data: test_yhat_glm in 422 controls (test$quality 0) < 58 cases (test$quality 1).
## Area under the curve: 0.8706
# Backward selection can be computationally expensive, especially with a large number of variables. It can be better to use other feature selection methods, such as lasso or ridge regression.
Apply lasso to the logistic model for feature selection
# Generate the model matrices
train_lasso <- model.matrix(quality ~ ., train)
test_lasso <- model.matrix(quality ~ ., test)
dim(train_lasso)
## [1] 1119 12
train_lasso_x <- train_lasso[, -1]
test_lasso_x <- test_lasso[, -1]
# Run the lasso regression
library(glmnet)
## Loading required package: Matrix
## Loaded glmnet 4.1-4
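The lasso fit itself does not appear above; a minimal sketch of how it could be run with cv.glmnet() follows, where alpha = 1 selects the lasso penalty and 10-fold cross-validation chooses lambda. The use of lambda.1se rather than lambda.min, and the seed, are assumptions.
# Cross-validated lasso logistic regression; coefficients shrunk exactly to zero
# at the chosen lambda are effectively dropped from the model
set.seed(9)
cv_lasso <- cv.glmnet(train_lasso_x, train$quality, family = "binomial", alpha = 1)
coef(cv_lasso, s = "lambda.1se")
# predicted probabilities on the test set, for comparison with the other models
yhat_lasso <- predict(cv_lasso, newx = test_lasso_x, s = "lambda.1se", type = "response")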
Build a classification tree
library(rpart)
library(rpart.plot)
# First, we build a large initial classification tree. We ensure the tree is large by using a small value (0.001) for 'cp', the complexity parameter.
# The complexity parameter is a tuning parameter that controls the size of a decision tree and can be used to prune the tree to improve its predictive accuracy on new data.
# build the initial tree
formula <- formula(quality~.)
tree <- rpart(formula, data=train, cp =.001, method = 'class')
# view the results of the model
plot(tree,uniform=T,compress=T,margin=.05,branch=0.3)
text(tree, cex=.7, col="navy",use.n=TRUE)
printcp(tree)
##
## Classification tree:
## rpart(formula = formula, data = train, method = "class", cp = 0.001)
##
## Variables actually used in tree construction:
## [1] alcohol chlorides citric.acid
## [4] density fixed.acidity free.sulfur.dioxide
## [7] pH residual.sugar sulphates
## [10] total.sulfur.dioxide volatile.acidity
##
## Root node error: 159/1119 = 0.14209
##
## n= 1119
##
## CP nsplit rel error xerror xstd
## 1 0.1037736 0 1.00000 1.00000 0.073455
## 2 0.0283019 2 0.79245 0.88050 0.069605
## 3 0.0230608 4 0.73585 0.86164 0.068961
## 4 0.0220126 7 0.66667 0.84906 0.068526
## 5 0.0188679 9 0.62264 0.84906 0.068526
## 6 0.0125786 12 0.56604 0.81132 0.067189
## 7 0.0035939 13 0.55346 0.77358 0.065807
## 8 0.0031447 20 0.52830 0.77987 0.066040
## 9 0.0010000 22 0.52201 0.77987 0.066040
plotcp(tree)
# Prune the tree using the best value of cp
# The value of cp with the smallest cross-validated error ('xerror' in the printcp() output above) is the optimal value
# best cp value
best <- tree$cptable[which.min(tree$cptable[,"xerror"]),"CP"]
best
## [1] 0.00359389
#produce a pruned tree based on the best cp value
tree_prune <- prune(tree, cp=best)
plot(tree_prune,uniform=T,compress=T,margin=.05,branch=0.3)
text(tree_prune, cex=.7, col="navy",use.n=TRUE)
yhat.tree <- predict(tree_prune, test, type = 'prob')[,2]
table(test$quality, yhat.tree>0.5)
##
## FALSE TRUE
## 0 402 20
## 1 35 23
TPR(test$quality, (yhat.tree>0.5))
## [1] 0.3965517
TNR(test$quality, (yhat.tree>0.5))
## [1] 0.9526066
# A TPR of 0.4 and a TNR of 0.95 suggest that the pruned tree is much better at predicting negative cases than positive cases, a consequence of the class imbalance in the data. One way to shift this trade-off is sketched below.
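A minimal sketch of trading some TNR for a higher TPR by lowering the 50% cutoff; the value 0.15, roughly the share of high-quality wines in the data, is an assumption rather than a tuned threshold.
# Re-evaluate the pruned tree at a lower cutoff
cutoff <- 0.15
table(test$quality, yhat.tree > cutoff)
TPR(test$quality, (yhat.tree > cutoff))
TNR(test$quality, (yhat.tree > cutoff))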
library('pROC')
tree_prune.roc <- roc(test$quality, yhat.tree, direction = '<')
## Setting levels: control = 0, case = 1
tree_prune.roc
##
## Call:
## roc.default(response = test$quality, predictor = yhat.tree, direction = "<")
##
## Data: yhat.tree in 422 controls (test$quality 0) < 58 cases (test$quality 1).
## Area under the curve: 0.7995
Build a random forest
library(randomForest)
## Warning: package 'randomForest' was built under R version 4.2.2
## randomForest 4.7-1.1
## Type rfNews() to see new features/changes/bug fixes.
##
## Attaching package: 'randomForest'
## The following object is masked from 'package:dplyr':
##
## combine
## The following object is masked from 'package:psych':
##
## outlier
# Convert the features to matrix format
xvars <- names(train)[1:11]
x_train <- as.matrix(train[, xvars])
x_test <- as.matrix(test[, xvars])
y_train <- train$quality
# Set the parameters
# Set the number of variables tried at each split to sqrt(p), where p is the number of predictors in x_train
mtry <- round(ncol(x_train)^.5)
mtry
## [1] 3
# Therefore, 3 variables are tried at each split
# set the number of bootstrap samples
ntree <- 500
# We take 500 bootstrapped samples from the training data
# build the random forest
set.seed(9)
rforest <- randomForest(x = x_train, y = y_train, data = train, ntree = ntree, mtry = mtry, importance = TRUE, na.action = na.omit)
summary(rforest)
## Length Class Mode
## call 8 -none- call
## type 1 -none- character
## predicted 1119 factor numeric
## err.rate 1500 -none- numeric
## confusion 6 -none- numeric
## votes 2238 matrix numeric
## oob.times 1119 -none- numeric
## classes 2 -none- character
## importance 44 -none- numeric
## importanceSD 33 -none- numeric
## localImportance 0 -none- NULL
## proximity 0 -none- NULL
## ntree 1 -none- numeric
## mtry 1 -none- numeric
## forest 14 -none- list
## y 1119 factor numeric
## test 0 -none- NULL
## inbag 0 -none- NULL
head(rforest$importance)
## 0 1 MeanDecreaseAccuracy
## fixed.acidity 0.008077915 0.03929650 0.012499454
## volatile.acidity 0.006226011 0.10399440 0.020232770
## citric.acid 0.006593443 0.06851180 0.015327231
## residual.sugar 0.003525264 0.02117277 0.006036433
## chlorides 0.009597397 0.02642499 0.011951925
## free.sulfur.dioxide 0.009323818 0.01859818 0.010646635
## MeanDecreaseGini
## fixed.acidity 20.64530
## volatile.acidity 30.59013
## citric.acid 24.91204
## residual.sugar 17.81895
## chlorides 19.81105
## free.sulfur.dioxide 15.03669
# Variable importance ranking
# Measures each feature's contribution to the predictions; some features matter more than others in determining the final prediction. Here we rank by the fourth importance column, MeanDecreaseGini.
imp_rforest <- rforest$importance[, 4]
order <- order(imp_rforest, decreasing = TRUE)
imp_rforest <- imp_rforest[order]
par(mar=c(2, 8, 4, 2) + 0.1)
barplot(imp_rforest, col='lavender', horiz=TRUE, las=1, cex.names=.8)
title("Random Forest Variable Importance Plot")
# variable importance in random forest can help understand the relationships between the features and the outcome.
pred.rforest <- predict(rforest, x_test)
table(test$quality, pred.rforest)
## pred.rforest
## 0 1
## 0 411 11
## 1 34 24
yhat.rforest <- predict(rforest, test, type='prob')[,2]
TPR(test$quality, (yhat.rforest>0.5))
## [1] 0.4137931
TNR(test$quality, (yhat.rforest>0.5))
## [1] 0.9739336
rforest.roc <-roc(test$quality, yhat.rforest, direction ='<')
## Setting levels: control = 0, case = 1
rforest.roc
##
## Call:
## roc.default(response = test$quality, predictor = yhat.rforest, direction = "<")
##
## Data: yhat.rforest in 422 controls (test$quality 0) < 58 cases (test$quality 1).
## Area under the curve: 0.9287
plot(rforest.roc, lwd=3)
Model comparison using ROC
plot(rforest.roc, lwd=3)
lines(tree_prune.roc,col ='blue', lwd=3)
lines(glm.roc, col ='orange', lwd=3)
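A legend makes the three curves easier to tell apart; a minimal sketch using pROC's auc() and the colors chosen above follows. Based on the AUCs reported earlier, the random forest performs best (0.93), followed by the logistic regression (0.87) and the pruned classification tree (0.80).
# Label each curve with its AUC
legend("bottomright",
       legend = c(sprintf("Random forest (AUC = %.2f)", auc(rforest.roc)),
                  sprintf("Logistic regression (AUC = %.2f)", auc(glm.roc)),
                  sprintf("Pruned tree (AUC = %.2f)", auc(tree_prune.roc))),
       col = c("black", "orange", "blue"), lwd = 3)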