For this analysis, we will use a simulated dataset containing hypothetical medical expenses for patients in the United States. This data was created for this book using demographic statistics from the US Census Bureau.
Using the str() function, we see that the dataset includes 1,338 examples of beneficiaries currently enrolled in the insurance plan, with six features describing characteristics of the patient and a response variable, expenses, recording the total medical expenses charged to the plan for the calendar year.
## Exploring and preparing the data ----
insurance <- read.csv("http://www.sci.csueastbay.edu/~esuess/classes/Statistics_6620/Presentations/ml10/insurance.csv", stringsAsFactors = TRUE)
str(insurance)
'data.frame': 1338 obs. of 7 variables:
$ age : int 19 18 28 33 32 31 46 37 37 60 ...
$ sex : Factor w/ 2 levels "female","male": 1 2 2 2 2 1 1 1 2 1 ...
$ bmi : num 27.9 33.8 33 22.7 28.9 25.7 33.4 27.7 29.8 25.8 ...
$ children: int 0 1 3 0 0 0 1 3 2 0 ...
$ smoker : Factor w/ 2 levels "no","yes": 2 1 1 1 1 1 1 1 1 1 ...
$ region : Factor w/ 4 levels "northeast","northwest",..: 4 3 3 2 2 3 3 2 1 2 ...
$ expenses: num 16885 1726 4449 21984 3867 ...
# summarize the expenses variable
summary(insurance$expenses)
Min. 1st Qu. Median Mean 3rd Qu. Max.
1122 4740 9382 13270 16640 63770
# Histogram
hist(insurance$expenses)
## Convert all factors to numeric
indx <- sapply(insurance, is.factor)                    # identify the factor columns
insurance[indx] <- lapply(insurance[indx], as.numeric)  # recode factor levels as 1, 2, ...
str(insurance)
'data.frame': 1338 obs. of 7 variables:
$ age : int 19 18 28 33 32 31 46 37 37 60 ...
$ sex : num 1 2 2 2 2 1 1 1 2 1 ...
$ bmi : num 27.9 33.8 33 22.7 28.9 25.7 33.4 27.7 29.8 25.8 ...
$ children: int 0 1 3 0 0 0 1 3 2 0 ...
$ smoker : num 2 1 1 1 1 1 1 1 1 1 ...
$ region : num 4 3 3 2 2 3 3 2 1 2 ...
$ expenses: num 16885 1726 4449 21984 3867 ...
Before we start the analysis, we must check the assumptions of linear regression.
# Check the assumption of normality
ins_model <- lm(expenses ~ ., data = insurance)
vit.data <- residuals(ins_model)
shapiro.test(vit.data)
Shapiro-Wilk normality test
data: vit.data
W = 0.89906, p-value < 2.2e-16
We see that the dependent variable is not normally distributed, so we need to transform it; here we use a Tukey ladder-of-powers transformation.
## Transformation of Dependent Variable
#install.packages("rcompanion")
library(rcompanion)
insurance$expenses <- transformTukey(insurance$expenses,plotit=TRUE)
lambda W Shapiro.p.value
461 1.5 0.977 8.472e-14
transformTukey() loops over candidate lambda values, by default choosing the one that maximizes the Shapiro-Wilk W statistic, and applies:
if (lambda > 0){TRANS = x ^ lambda}
if (lambda == 0){TRANS = log(x)}
if (lambda < 0){TRANS = -1 * x ^ lambda}
insurance$expenses <- round(insurance$expenses, 2)
Having made the response much closer to normal, we next want to check that there is little or no multicollinearity in the data, which we can do with a correlation matrix of the numeric variables. Lastly, we need to check the linearity of the relationship between the independent variables and the dependent variable, which we will do using scatterplots.
# exploring relationships among features: correlation matrix
cor(insurance[c("age", "bmi", "children", "expenses")])
age bmi children expenses
age 1.0000000 0.10934101 0.04246900 0.5108107
bmi 0.1093410 1.00000000 0.01264471 0.1305575
children 0.0424690 0.01264471 1.00000000 0.1593635
expenses 0.5108107 0.13055753 0.15936347 1.0000000
# visualizing relationships among features: scatterplot matrix
pairs(insurance[c("age", "bmi", "children", "expenses")])
# more informative scatterplot matrix
#install.packages("psych")
library(psych)
pairs.panels(insurance[c("age", "bmi", "children", "expenses")])
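As an additional multicollinearity check, variance inflation factors can be computed for the fitted model. Below is a brief sketch using the car package (the same vif() function used later for the Boston data); it assumes the car package is installed.
# variance inflation factors for the full insurance model
# rule of thumb: VIFs above roughly 5-10 suggest problematic multicollinearity
library(car)
vif(lm(expenses ~ ., data = insurance))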
From the graphs above, we can see that the model's assumptions are reasonably met, so we can move on to the analysis portion of the report.
## Step 3: Training a model on the data ----
ins_model <- lm(expenses ~ ., data = insurance) # this is equivalent to above
# see the estimated beta coefficients
ins_model
Call:
lm(formula = expenses ~ ., data = insurance)
Coefficients:
(Intercept) age sex bmi children smoker region age2
1.142e+00 1.559e-03 -2.249e-03 2.224e-04 2.935e-03 4.874e-02 -1.339e-03 -6.373e-06
bmi30
2.806e-03
## Step 4: Evaluating model performance ----
# see more detail about the estimated beta coefficients
summary(ins_model)
Call:
lm(formula = expenses ~ ., data = insurance)
Residuals:
Min 1Q Median 3Q Max
-0.032944 -0.007224 -0.002093 0.003362 0.071520
Coefficients:
Estimate Std. Error t value Pr(>|t|)
(Intercept) 1.142e+00 4.533e-03 251.980 < 2e-16 ***
age 1.559e-03 1.896e-04 8.224 4.65e-16 ***
sex -2.249e-03 7.737e-04 -2.907 0.003708 **
bmi 2.224e-04 1.062e-04 2.095 0.036376 *
children 2.935e-03 3.352e-04 8.754 < 2e-16 ***
smoker 4.874e-02 9.573e-04 50.919 < 2e-16 ***
region -1.339e-03 3.532e-04 -3.791 0.000157 ***
age2 -6.373e-06 2.365e-06 -2.695 0.007127 **
bmi30 2.806e-03 1.287e-03 2.180 0.029441 *
---
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
Residual standard error: 0.01408 on 1329 degrees of freedom
Multiple R-squared: 0.761, Adjusted R-squared: 0.7596
F-statistic: 529.1 on 8 and 1329 DF, p-value: < 2.2e-16
## Step 5: Improving model performance ----
# add a higher-order "age" term
insurance$age2 <- insurance$age^2
# add an indicator for BMI >= 30
insurance$bmi30 <- ifelse(insurance$bmi >= 30, 1, 0)
# create final model
ins_model2 <- lm(expenses ~ age + age2 + children + bmi + sex +
bmi30*smoker + region, data = insurance)
summary(ins_model2)
Call:
lm(formula = expenses ~ age + age2 + children + bmi + sex + bmi30 *
smoker + region, data = insurance)
Residuals:
Min 1Q Median 3Q Max
-0.029997 -0.006949 -0.002132 0.002483 0.073918
Coefficients:
Estimate Std. Error t value Pr(>|t|)
(Intercept) 1.157e+00 4.516e-03 256.139 < 2e-16 ***
age 1.554e-03 1.810e-04 8.585 < 2e-16 ***
age2 -6.233e-06 2.258e-06 -2.760 0.005855 **
children 2.985e-03 3.201e-04 9.325 < 2e-16 ***
bmi 1.902e-04 1.014e-04 1.876 0.060883 .
sex -2.597e-03 7.394e-04 -3.512 0.000459 ***
bmi30 -2.193e-02 2.497e-03 -8.783 < 2e-16 ***
smoker 3.777e-02 1.329e-03 28.431 < 2e-16 ***
region -1.433e-03 3.374e-04 -4.247 2.31e-05 ***
bmi30:smoker 2.081e-02 1.829e-03 11.380 < 2e-16 ***
---
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
Residual standard error: 0.01345 on 1328 degrees of freedom
Multiple R-squared: 0.7823, Adjusted R-squared: 0.7808
F-statistic: 530.2 on 9 and 1328 DF, p-value: < 2.2e-16
# Compare the two models
anova(ins_model, ins_model2)
Analysis of Variance Table
Model 1: expenses ~ age + sex + bmi + children + smoker + region + age2 +
bmi30
Model 2: expenses ~ age + age2 + children + bmi + sex + bmi30 * smoker +
region
Res.Df RSS Df Sum of Sq F Pr(>F)
1 1329 0.26364
2 1328 0.24021 1 0.023426 129.51 < 2.2e-16 ***
---
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
The MASS library contains the Boston data set, which records medv (median house value) for 506 neighborhoods around Boston. We will seek to predict medv using 13 predictors such as rm (average number of rooms per house), age (average age of houses), and lstat (percent of households with low socioeconomic status).
library(MASS)
data(Boston)
names(Boston)
[1] "crim" "zn" "indus" "chas" "nox" "rm" "age" "dis" "rad" "tax"
[11] "ptratio" "black" "lstat" "medv"
To find out more about the data set, we can type ?Boston. We will start by using the lm() function to fit a simple linear regression model, with medv as the response and lstat as the predictor. The basic lm() syntax is lm(y ~ x, data), where y is the response, x is the predictor, and data is the data set in which these two variables are kept.
If we type lm.fit, some basic information about the model is output. For more detailed information, we use summary(lm.fit). This gives us p-values and standard errors for the coefficients, as well as the R2 statistic and F-statistic for the model.
attach(Boston)
lm.fit=lm(medv ~ lstat)
lm.fit
Call:
lm(formula = medv ~ lstat)
Coefficients:
(Intercept) lstat
34.55 -0.95
summary(lm.fit)
Call:
lm(formula = medv ~ lstat)
Residuals:
Min 1Q Median 3Q Max
-15.168 -3.990 -1.318 2.034 24.500
Coefficients:
Estimate Std. Error t value Pr(>|t|)
(Intercept) 34.55384 0.56263 61.41 <2e-16 ***
lstat -0.95005 0.03873 -24.53 <2e-16 ***
---
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
Residual standard error: 6.216 on 504 degrees of freedom
Multiple R-squared: 0.5441, Adjusted R-squared: 0.5432
F-statistic: 601.6 on 1 and 504 DF, p-value: < 2.2e-16
We can use the names() function in order to find out what other pieces of information are stored in lm.fit. Although we can extract these quantities by name (e.g., lm.fit$coefficients), it is safer to use extractor functions like coef() to access them.
In order to obtain a confidence interval for the coefficient estimates, we can use the confint() command.
The predict() function can be used to produce confidence intervals and prediction intervals for the prediction of medv for a given value of lstat.
names(lm.fit)
[1] "coefficients" "residuals" "effects" "rank" "fitted.values" "assign"
[7] "qr" "df.residual" "xlevels" "call" "terms" "model"
coef(lm.fit)
(Intercept) lstat
34.5538409 -0.9500494
confint(lm.fit)
2.5 % 97.5 %
(Intercept) 33.448457 35.6592247
lstat -1.026148 -0.8739505
predict(lm.fit,data.frame(lstat=c(5,10,15)), interval ="confidence")
fit lwr upr
1 29.80359 29.00741 30.59978
2 25.05335 24.47413 25.63256
3 20.30310 19.73159 20.87461
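The prediction intervals mentioned in the next paragraph can be obtained with the same call, switching the interval argument (output omitted here):
# prediction intervals for medv at lstat = 5, 10, and 15
predict(lm.fit, data.frame(lstat = c(5, 10, 15)), interval = "prediction")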
For instance, the 95 % confidence interval associated with a lstat value of 10 is (24.47, 25.63), and the 95 % prediction interval is (12.828, 37.28). As expected, the confidence and prediction intervals are centered around the same point (a predicted value of 25.05 for medv when lstat equals 10), but the latter are substantially wider. We will now plot medv and lstat along with the least squares regression line using the plot() and abline() functions.
plot(lstat ,medv)
abline(lm.fit)
abline(lm.fit,lwd=3)
abline(lm.fit,lwd=3,col="red")
There is some evidence for non-linearity in the relationship between lstat and medv. We will explore this issue later in this lab.
The abline() function can be used to draw any line, not just the least squares regression line. To draw a line with intercept a and slope b, we type abline(a,b). Below we experiment with some additional settings for plotting lines and points. The lwd=3 command causes the width of the regression line to be increased by a factor of 3; this works for the plot() and lines() functions also. We can also use the pch option to create different plotting symbols.
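For instance, a line with an explicitly chosen intercept and slope (illustrative values, not estimated from the data) can be added to a plot like this:
plot(lstat, medv)
abline(a = 40, b = -2, lty = 2)   # arbitrary line: intercept 40, slope -2 (illustrative values)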
par(mfrow=c(2,2))
plot(lstat,medv,col="red")
plot(lstat,medv,pch=20)
plot(lstat,medv,pch="+")
plot(1:20,1:20,pch=1:20)
Next we examine some diagnostic plots, several of which were discussed in Section 3.3.3. Four diagnostic plots are automatically produced by applying the plot() function directly to the output from lm(). In general, this command will produce one plot at a time, and hitting Enter will generate the next plot. However, it is often convenient to view all four plots together. We can achieve this by using the par() function, which tells R to split the display screen into separate panels so that multiple plots can be viewed simultaneously. For example, par(mfrow=c(2,2)) divides the plotting region into a 2 × 2 grid of panels.
par(mfrow=c(2,2))
plot(lm.fit)
Alternatively, we can compute the residuals from a linear regression fit using the residuals() function. The function rstudent() will return the studentized residuals, and we can use this function to plot the residuals against the fitted values.
par(mfrow=c(2,2))
plot(predict(lm.fit), residuals(lm.fit))
plot(predict(lm.fit), rstudent(lm.fit))
plot(hatvalues (lm.fit))
#which.max(hatvalues (lm.fit))
On the basis of the residual plots, there is some evidence of non-linearity. Leverage statistics can be computed for any number of predictors using the hatvalues() function.
In order to fit a multiple linear regression model using least squares, we again use the lm() function. The syntax lm(y ~ x1 + x2 + x3) is used to fit a model with three predictors, x1, x2, and x3. The summary() function now outputs the regression coefficients for all the predictors.
lm.fit = lm(medv ~ lstat+age,data=Boston)
summary(lm.fit)
Call:
lm(formula = medv ~ lstat + age, data = Boston)
Residuals:
Min 1Q Median 3Q Max
-15.981 -3.978 -1.283 1.968 23.158
Coefficients:
Estimate Std. Error t value Pr(>|t|)
(Intercept) 33.22276 0.73085 45.458 < 2e-16 ***
lstat -1.03207 0.04819 -21.416 < 2e-16 ***
age 0.03454 0.01223 2.826 0.00491 **
---
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
Residual standard error: 6.173 on 503 degrees of freedom
Multiple R-squared: 0.5513, Adjusted R-squared: 0.5495
F-statistic: 309 on 2 and 503 DF, p-value: < 2.2e-16
## Use all 13 predictors contained in the dataset
lm.fit=lm(medv ~ .,data=Boston)
summary(lm.fit)
Call:
lm(formula = medv ~ ., data = Boston)
Residuals:
Min 1Q Median 3Q Max
-15.595 -2.730 -0.518 1.777 26.199
Coefficients:
Estimate Std. Error t value Pr(>|t|)
(Intercept) 3.646e+01 5.103e+00 7.144 3.28e-12 ***
crim -1.080e-01 3.286e-02 -3.287 0.001087 **
zn 4.642e-02 1.373e-02 3.382 0.000778 ***
indus 2.056e-02 6.150e-02 0.334 0.738288
chas 2.687e+00 8.616e-01 3.118 0.001925 **
nox -1.777e+01 3.820e+00 -4.651 4.25e-06 ***
rm 3.810e+00 4.179e-01 9.116 < 2e-16 ***
age 6.922e-04 1.321e-02 0.052 0.958229
dis -1.476e+00 1.995e-01 -7.398 6.01e-13 ***
rad 3.060e-01 6.635e-02 4.613 5.07e-06 ***
tax -1.233e-02 3.760e-03 -3.280 0.001112 **
ptratio -9.527e-01 1.308e-01 -7.283 1.31e-12 ***
black 9.312e-03 2.686e-03 3.467 0.000573 ***
lstat -5.248e-01 5.072e-02 -10.347 < 2e-16 ***
---
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
Residual standard error: 4.745 on 492 degrees of freedom
Multiple R-squared: 0.7406, Adjusted R-squared: 0.7338
F-statistic: 108.1 on 13 and 492 DF, p-value: < 2.2e-16
We can access the individual components of a summary object by name (type ?summary.lm to see what is available). Hence summary(lm.fit)$r.sq gives us the R2, and summary(lm.fit)$sigma gives us the RSE. The vif() function, part of the car package, can be used to compute variance inflation factors. Most VIFs are low to moderate for this data. The car package is not part of the base R installation, so it must be installed the first time you use it via install.packages().
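For instance, a quick illustration of extracting these components from the summary object:
summary(lm.fit)$r.sq    # R-squared (partial matching for r.squared)
summary(lm.fit)$sigma   # residual standard error (RSE)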
library(car)
vif(lm.fit)
crim zn indus chas nox rm age dis rad tax ptratio black
1.792192 2.298758 3.991596 1.073995 4.393720 1.933744 3.100826 3.955945 7.484496 9.008554 1.799084 1.348521
lstat
2.941491
What if we would like to perform a regression using all of the variables but one? For example, in the above regression output, age has a high p-value. So we may wish to run a regression excluding this predictor. The following syntax results in a regression using all predictors except age.
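The direct formula syntax drops the predictor with a minus sign:
lm.fit1 = lm(medv ~ . - age, data = Boston)
summary(lm.fit1)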
# Alternatively, the update() function can be used.
lm.fit1=update(lm.fit, ~ .-age)
We will use a regression tree algorithm to rate wines, using the wine quality dataset from the UCI Machine Learning Repository. For this analysis we will use only the red wine data.
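The report does not show the step that loads the data, so here is a minimal sketch. It assumes the red-wine CSV from the UCI repository, which is semicolon-delimited; the file name, its location, and the check.names setting (which preserves column names containing spaces, such as "fixed acidity") are assumptions, and the exact file used in this report may differ slightly.
# read the red wine data (semicolon-delimited); check.names = FALSE keeps
# column names with spaces, e.g. "fixed acidity", as used below
wine <- read.csv("winequality-red.csv", sep = ";", check.names = FALSE)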
summary(wine)
fixed acidity volatile acidity citric acid residual sugar chlorides free sulfur dioxide
Min. : 4.600 Min. :0.1200 Min. :0.0000 Min. : 0.900 Min. :0.01200 Min. : 1.00
1st Qu.: 7.100 1st Qu.:0.3900 1st Qu.:0.0900 1st Qu.: 1.900 1st Qu.:0.07000 1st Qu.: 7.00
Median : 7.900 Median :0.5200 Median :0.2600 Median : 2.200 Median :0.07900 Median :14.00
Mean : 8.322 Mean :0.5277 Mean :0.2713 Mean : 2.537 Mean :0.08746 Mean :15.83
3rd Qu.: 9.200 3rd Qu.:0.6400 3rd Qu.:0.4200 3rd Qu.: 2.600 3rd Qu.:0.09000 3rd Qu.:21.00
Max. :15.900 Max. :1.5800 Max. :1.0000 Max. :15.500 Max. :0.61100 Max. :72.00
total sulfur dioxide density pH sulphates alcohol quality
Min. : 6.00 Min. :0.9901 Min. :2.740 Min. :0.3300 Min. : 8.40 Min. :3.000
1st Qu.: 22.00 1st Qu.:0.9956 1st Qu.:3.210 1st Qu.:0.5500 1st Qu.: 9.50 1st Qu.:5.000
Median : 38.00 Median :0.9968 Median :3.310 Median :0.6200 Median :10.20 Median :6.000
Mean : 46.43 Mean :0.9967 Mean :3.311 Mean :0.6584 Mean :10.42 Mean :5.637
3rd Qu.: 62.00 3rd Qu.:0.9978 3rd Qu.:3.400 3rd Qu.:0.7300 3rd Qu.:11.10 3rd Qu.:6.000
Max. :289.00 Max. :1.0037 Max. :4.010 Max. :2.0000 Max. :14.90 Max. :8.000
From the summary() output above we see that there are 11 independent variables and one response variable, quality. The quality ratings range from 3 to 8, and the average rating is about 5.6. From the histogram, the data appear roughly normally distributed.
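That histogram can be reproduced with a one-liner (assuming the data are loaded as in the sketch earlier):
# distribution of the quality ratings
hist(wine$quality)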
Now we will divide the dataset into training and testing sets.
length(wine$`fixed acidity`)
[1] 1597
wine_train <- wine[1:1300, ]
wine_test <- wine[1301:1597, ]
Now that we have partitioned the data, we can begin training a regression tree model. For this analysis, we will be using the rpart package.
For the model, we will specify quality as the response variable and use all other independent variables in the model.
## Step 3: Training a model on the data ----
# regression tree using rpart
library(rpart)
m.rpart <- rpart(quality ~ ., data = wine_train)
# get basic information about the tree
m.rpart
n= 1300
node), split, n, deviance, yval
* denotes terminal node
1) root 1300 834.10080 5.643846
2) alcohol< 11.45 1065 560.12580 5.489202
4) sulphates< 0.645 639 260.27540 5.292645
8) volatile acidity>=0.925 29 21.17241 4.551724 *
9) volatile acidity< 0.925 610 222.42620 5.327869
18) alcohol< 9.975 357 96.66106 5.207283 *
19) alcohol>=9.975 253 113.24900 5.498024
38) free sulfur dioxide< 7.5 78 38.71795 5.205128 *
39) free sulfur dioxide>=7.5 175 64.85714 5.628571 *
5) sulphates>=0.645 426 238.13150 5.784038
10) alcohol< 9.95 170 78.35294 5.470588
20) free sulfur dioxide>=22.5 43 14.97674 5.023256 *
21) free sulfur dioxide< 22.5 127 51.85827 5.622047
42) fixed acidity< 11.35 112 35.96429 5.517857 *
43) fixed acidity>=11.35 15 5.60000 6.400000 *
11) alcohol>=9.95 256 131.98440 5.992188
22) volatile acidity>=0.405 142 60.23239 5.781690 *
23) volatile acidity< 0.405 114 57.62281 6.254386 *
3) alcohol>=11.45 235 133.08090 6.344681
6) sulphates< 0.635 103 55.96117 6.019417
12) pH>=3.265 70 31.78571 5.785714 *
13) pH< 3.265 33 12.24242 6.515152 *
7) sulphates>=0.635 132 57.71970 6.598485 *
# get more detailed information about the tree
summary(m.rpart)
Call:
rpart(formula = quality ~ ., data = wine_train)
n= 1300
CP nsplit rel error xerror xstd
1 0.16891736 0 1.0000000 1.0017698 0.04214281
2 0.07399458 1 0.8310826 0.8473264 0.03867538
3 0.03332228 2 0.7570881 0.7899171 0.03677830
4 0.02325857 3 0.7237658 0.7665608 0.03643889
5 0.01999373 4 0.7005072 0.7684519 0.03670371
6 0.01693941 5 0.6805135 0.7537495 0.03590915
7 0.01500556 6 0.6635741 0.7488627 0.03562909
8 0.01430646 7 0.6485685 0.7495697 0.03559558
9 0.01380880 8 0.6342621 0.7479410 0.03540105
10 0.01234141 9 0.6204533 0.7370872 0.03501255
11 0.01159802 10 0.6081118 0.7274122 0.03468548
12 0.01000000 11 0.5965138 0.7263393 0.03466762
Variable importance
alcohol sulphates volatile acidity density fixed acidity citric acid
34 16 10 9 8 6
pH free sulfur dioxide chlorides total sulfur dioxide residual sugar
5 4 4 2 1
Node number 1: 1300 observations, complexity param=0.1689174
mean=5.643846, MSE=0.641616
left son=2 (1065 obs) right son=3 (235 obs)
Primary splits:
alcohol < 11.45 to the left, improve=0.16891740, (0 missing)
sulphates < 0.645 to the left, improve=0.12396980, (0 missing)
volatile acidity < 0.425 to the right, improve=0.10567410, (0 missing)
citric acid < 0.295 to the left, improve=0.06995000, (0 missing)
density < 0.99537 to the right, improve=0.06632475, (0 missing)
Surrogate splits:
density < 0.994185 to the right, agree=0.874, adj=0.302, (0 split)
fixed acidity < 5.5 to the right, agree=0.833, adj=0.077, (0 split)
chlorides < 0.0525 to the right, agree=0.832, adj=0.068, (0 split)
pH < 3.695 to the left, agree=0.826, adj=0.038, (0 split)
volatile acidity < 0.14 to the right, agree=0.822, adj=0.013, (0 split)
Node number 2: 1065 observations, complexity param=0.07399458
mean=5.489202, MSE=0.5259397
left son=4 (639 obs) right son=5 (426 obs)
Primary splits:
sulphates < 0.645 to the left, improve=0.11018760, (0 missing)
volatile acidity < 0.405 to the right, improve=0.09137471, (0 missing)
alcohol < 9.975 to the left, improve=0.08589463, (0 missing)
citric acid < 0.295 to the left, improve=0.04404844, (0 missing)
total sulfur dioxide < 83.5 to the right, improve=0.03991942, (0 missing)
Surrogate splits:
citric acid < 0.395 to the left, agree=0.682, adj=0.204, (0 split)
volatile acidity < 0.4175 to the right, agree=0.681, adj=0.202, (0 split)
fixed acidity < 10.35 to the left, agree=0.660, adj=0.150, (0 split)
pH < 3.075 to the right, agree=0.636, adj=0.089, (0 split)
alcohol < 10.525 to the left, agree=0.636, adj=0.089, (0 split)
Node number 3: 235 observations, complexity param=0.02325857
mean=6.344681, MSE=0.5663015
left son=6 (103 obs) right son=7 (132 obs)
Primary splits:
sulphates < 0.635 to the left, improve=0.14577600, (0 missing)
citric acid < 0.325 to the left, improve=0.09163546, (0 missing)
fixed acidity < 7.75 to the left, improve=0.08874421, (0 missing)
pH < 3.375 to the right, improve=0.06101447, (0 missing)
volatile acidity < 0.425 to the right, improve=0.05817961, (0 missing)
Surrogate splits:
fixed acidity < 7.45 to the left, agree=0.698, adj=0.311, (0 split)
citric acid < 0.285 to the left, agree=0.685, adj=0.282, (0 split)
density < 0.99411 to the left, agree=0.668, adj=0.243, (0 split)
volatile acidity < 0.5925 to the right, agree=0.630, adj=0.155, (0 split)
pH < 3.415 to the right, agree=0.630, adj=0.155, (0 split)
Node number 4: 639 observations, complexity param=0.01999373
mean=5.292645, MSE=0.4073168
left son=8 (29 obs) right son=9 (610 obs)
Primary splits:
volatile acidity < 0.925 to the right, improve=0.06407361, (0 missing)
sulphates < 0.575 to the left, improve=0.05749397, (0 missing)
alcohol < 9.975 to the left, improve=0.03427930, (0 missing)
pH < 3.425 to the right, improve=0.02090904, (0 missing)
total sulfur dioxide < 98.5 to the right, improve=0.01636707, (0 missing)
Surrogate splits:
total sulfur dioxide < 149.5 to the right, agree=0.956, adj=0.034, (0 split)
Node number 5: 426 observations, complexity param=0.03332228
mean=5.784038, MSE=0.558994
left son=10 (170 obs) right son=11 (256 obs)
Primary splits:
alcohol < 9.95 to the left, improve=0.11671760, (0 missing)
volatile acidity < 0.405 to the right, improve=0.10992940, (0 missing)
chlorides < 0.0965 to the right, improve=0.09323926, (0 missing)
total sulfur dioxide < 50.5 to the right, improve=0.09215330, (0 missing)
density < 0.996225 to the right, improve=0.05193817, (0 missing)
Surrogate splits:
chlorides < 0.1045 to the right, agree=0.678, adj=0.194, (0 split)
sulphates < 0.975 to the right, agree=0.646, adj=0.112, (0 split)
volatile acidity < 0.565 to the right, agree=0.636, adj=0.088, (0 split)
pH < 3.045 to the left, agree=0.631, adj=0.076, (0 split)
residual sugar < 1.85 to the left, agree=0.629, adj=0.071, (0 split)
Node number 6: 103 observations, complexity param=0.01430646
mean=6.019417, MSE=0.5433123
left son=12 (70 obs) right son=13 (33 obs)
Primary splits:
pH < 3.265 to the right, improve=0.2132376, (0 missing)
citric acid < 0.445 to the left, improve=0.1525071, (0 missing)
volatile acidity < 0.495 to the right, improve=0.1354948, (0 missing)
free sulfur dioxide < 31.5 to the left, improve=0.1332219, (0 missing)
fixed acidity < 6.55 to the left, improve=0.1044414, (0 missing)
Surrogate splits:
citric acid < 0.335 to the left, agree=0.874, adj=0.606, (0 split)
fixed acidity < 7.8 to the left, agree=0.864, adj=0.576, (0 split)
volatile acidity < 0.385 to the right, agree=0.806, adj=0.394, (0 split)
chlorides < 0.0995 to the left, agree=0.748, adj=0.212, (0 split)
free sulfur dioxide < 34 to the left, agree=0.748, adj=0.212, (0 split)
Node number 7: 132 observations
mean=6.598485, MSE=0.4372704
Node number 8: 29 observations
mean=4.551724, MSE=0.7300832
Node number 9: 610 observations, complexity param=0.01500556
mean=5.327869, MSE=0.3646332
left son=18 (357 obs) right son=19 (253 obs)
Primary splits:
alcohol < 9.975 to the left, improve=0.05627103, (0 missing)
sulphates < 0.575 to the left, improve=0.05147264, (0 missing)
volatile acidity < 0.6525 to the right, improve=0.03203148, (0 missing)
total sulfur dioxide < 98.5 to the right, improve=0.02433496, (0 missing)
density < 0.99569 to the right, improve=0.01833855, (0 missing)
Surrogate splits:
density < 0.995805 to the right, agree=0.679, adj=0.225, (0 split)
total sulfur dioxide < 37.5 to the right, agree=0.639, adj=0.130, (0 split)
fixed acidity < 6.95 to the right, agree=0.623, adj=0.091, (0 split)
chlorides < 0.0685 to the right, agree=0.621, adj=0.087, (0 split)
sulphates < 0.595 to the left, agree=0.615, adj=0.071, (0 split)
Node number 10: 170 observations, complexity param=0.0138088
mean=5.470588, MSE=0.4608997
left son=20 (43 obs) right son=21 (127 obs)
Primary splits:
free sulfur dioxide < 22.5 to the right, improve=0.14700060, (0 missing)
fixed acidity < 11.8 to the left, improve=0.14369840, (0 missing)
volatile acidity < 0.3175 to the right, improve=0.12410440, (0 missing)
total sulfur dioxide < 46.5 to the right, improve=0.12406210, (0 missing)
chlorides < 0.0955 to the right, improve=0.07724758, (0 missing)
Surrogate splits:
total sulfur dioxide < 66.5 to the right, agree=0.865, adj=0.465, (0 split)
residual sugar < 3.25 to the right, agree=0.800, adj=0.209, (0 split)
density < 1.0009 to the right, agree=0.771, adj=0.093, (0 split)
volatile acidity < 0.855 to the right, agree=0.765, adj=0.070, (0 split)
sulphates < 1.6 to the right, agree=0.753, adj=0.023, (0 split)
Node number 11: 256 observations, complexity param=0.01693941
mean=5.992188, MSE=0.515564
left son=22 (142 obs) right son=23 (114 obs)
Primary splits:
volatile acidity < 0.405 to the right, improve=0.10705190, (0 missing)
total sulfur dioxide < 54.5 to the right, improve=0.09371862, (0 missing)
residual sugar < 3.8 to the right, improve=0.04513882, (0 missing)
chlorides < 0.0975 to the right, improve=0.04398857, (0 missing)
pH < 3.48 to the right, improve=0.03639320, (0 missing)
Surrogate splits:
citric acid < 0.305 to the left, agree=0.719, adj=0.368, (0 split)
sulphates < 0.765 to the left, agree=0.621, adj=0.149, (0 split)
chlorides < 0.0675 to the right, agree=0.617, adj=0.140, (0 split)
residual sugar < 1.85 to the right, agree=0.602, adj=0.105, (0 split)
fixed acidity < 7.55 to the left, agree=0.590, adj=0.079, (0 split)
Node number 12: 70 observations
mean=5.785714, MSE=0.4540816
Node number 13: 33 observations
mean=6.515152, MSE=0.3709826
Node number 18: 357 observations
mean=5.207283, MSE=0.2707593
Node number 19: 253 observations, complexity param=0.01159802
mean=5.498024, MSE=0.4476246
left son=38 (78 obs) right son=39 (175 obs)
Primary splits:
free sulfur dioxide < 7.5 to the left, improve=0.08542167, (0 missing)
total sulfur dioxide < 14.5 to the left, improve=0.05140778, (0 missing)
sulphates < 0.585 to the left, improve=0.04450115, (0 missing)
volatile acidity < 0.655 to the right, improve=0.02790290, (0 missing)
pH < 3.405 to the right, improve=0.02470002, (0 missing)
Surrogate splits:
total sulfur dioxide < 16.5 to the left, agree=0.881, adj=0.615, (0 split)
alcohol < 11.35 to the right, agree=0.723, adj=0.103, (0 split)
pH < 3.55 to the right, agree=0.711, adj=0.064, (0 split)
sulphates < 0.45 to the left, agree=0.711, adj=0.064, (0 split)
chlorides < 0.1445 to the right, agree=0.708, adj=0.051, (0 split)
Node number 20: 43 observations
mean=5.023256, MSE=0.3482964
Node number 21: 127 observations, complexity param=0.01234141
mean=5.622047, MSE=0.4083328
left son=42 (112 obs) right son=43 (15 obs)
Primary splits:
fixed acidity < 11.35 to the left, improve=0.19850220, (0 missing)
density < 0.99716 to the left, improve=0.16814750, (0 missing)
volatile acidity < 0.3175 to the right, improve=0.13160050, (0 missing)
alcohol < 9.85 to the left, improve=0.10955790, (0 missing)
pH < 2.99 to the right, improve=0.09450066, (0 missing)
Surrogate splits:
volatile acidity < 0.235 to the right, agree=0.898, adj=0.133, (0 split)
density < 0.99965 to the left, agree=0.898, adj=0.133, (0 split)
pH < 2.89 to the right, agree=0.898, adj=0.133, (0 split)
citric acid < 0.71 to the left, agree=0.890, adj=0.067, (0 split)
Node number 22: 142 observations
mean=5.78169, MSE=0.4241718
Node number 23: 114 observations
mean=6.254386, MSE=0.5054632
Node number 38: 78 observations
mean=5.205128, MSE=0.496384
Node number 39: 175 observations
mean=5.628571, MSE=0.3706122
Node number 42: 112 observations
mean=5.517857, MSE=0.3211097
Node number 43: 15 observations
mean=6.4, MSE=0.3733333
# use the rpart.plot package to create a visualization
par(mfrow=c(2,1))
#install.packages("rpart.plot")
library(rpart.plot)
# a basic decision tree diagram
rpart.plot(m.rpart, digits = 3)
# a few adjustments to the diagram
rpart.plot(m.rpart, digits = 4, fallen.leaves = TRUE, type = 3, extra = 101)
Now we want to evaluate the model. We will use the predict() function on the test dataset to check the accuracy of the model's predictions.
## Step 4: Evaluate model performance ----
# generate predictions for the testing dataset
p.rpart <- predict(m.rpart, wine_test)
# compare the distribution of predicted values vs. actual values
summary(p.rpart)
Min. 1st Qu. Median Mean 3rd Qu. Max.
4.552 5.207 5.518 5.602 5.782 6.598
summary(wine_test$quality)
Min. 1st Qu. Median Mean 3rd Qu. Max.
3.000 5.000 6.000 5.606 6.000 8.000
Looking at the summary statistics for the predicted and observed values, we see that they are similar, but the much narrower range of the predictions suggests that the model is not identifying the extreme cases (the lowest- and highest-quality wines).
One way to gauge the model's performance is to check the correlation between the predicted and actual values.
# compare the correlation
cor(p.rpart, wine_test$quality)
[1] 0.604005
A correlation of 0.60 is acceptable, but we can do better.
Another way to measure the model’s performance is using the mean absolute errors. The formula is \(MAE = \frac{\sum_{i=1}^{n} \mid e_i \mid}{n}\), where n indicates the number of predictions and \(e_i\) indicates the error for prediction i.
# function to calculate the mean absolute error
MAE <- function(actual, predicted) {
mean(abs(actual - predicted))
}
# mean absolute error between predicted and actual values
MAE(p.rpart, wine_test$quality)
[1] 0.5464075
# mean absolute error between actual values and mean value
mean(wine_train$quality)   # mean quality in the training data
[1] 5.643846
MAE(5.87, wine_test$quality)
[1] 0.7000337
We see that the MAE for our predictions is 0.5464075, which means that, on average, the difference between our prediction and the true value is about 0.55 quality points. We also see that the mean quality of wine in the training data is 5.643846. Using the expected value of wine quality for every sample gives an MAE of 0.7000337, so our model predicts quality better than simply using the expected value. That is good, but let's see if we can do better.
One way to improve the model is to use the M5' algorithm from the RWeka package. I am currently trying to get the package to work on my computer; in the meantime, I will use a boosted regression tree (BRT) approach.
## Step 5: Improving model performance ----
#Train the BRT Model
#install.packages("gbm")
library(gbm)
gbm.model <- gbm(formula = quality ~ ., data = wine_train, n.trees = 1000, shrinkage = 0.01, bag.fraction = 0.9, cv.folds = 10, n.minobsinnode = 20)
Distribution not specified, assuming gaussian ...
gbm.model
gbm(formula = quality ~ ., data = wine_train, n.trees = 1000,
n.minobsinnode = 20, shrinkage = 0.01, bag.fraction = 0.9,
cv.folds = 10)
A gradient boosted model with gaussian loss function.
1000 iterations were performed.
The best cross-validation iteration was 1000.
There were 11 predictors of which 10 had non-zero influence.
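Rather than assuming all 1,000 trees should be used for prediction, the cross-validated error curve can be checked with gbm.perf(); a short sketch:
# estimate the optimal number of trees from the cross-validation error curve
best.iter <- gbm.perf(gbm.model, method = "cv")
best.iter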
Now we want to see how much of an improvement we made.
gbmTrainPredictions = predict(object = gbm.model,
newdata = wine_test,
n.trees = 1000,
type = "response")
#summary statistics about the predictions
summary(gbmTrainPredictions)
Min. 1st Qu. Median Mean 3rd Qu. Max.
4.723 5.307 5.506 5.603 5.883 6.744
# compare the correlation
cor(gbmTrainPredictions, wine_test$quality)
[1] 0.6487095
#mean absolute error of predicted and true values
#(uses a custom function defined above)
MAE(wine_test$quality, gbmTrainPredictions)
[1] 0.5168549
Comparing the correlations, we see an increase to 0.6487095 from the regression tree's 0.604005. The boosted model also reduced the MAE (0.5168549 versus 0.5464075).
Below is the corresponding code using the RWeka package (not run here).
## Step 5: Improving model performance ----
# train an M5' model tree
# install.packages("rJava", type = "source")
# install.packages("RWeka")
library(RWeka)
m.m5p <- M5P(quality ~ ., data = wine_train)
# display the tree
m.m5p
# get a summary of the model's performance
summary(m.m5p)
# generate predictions for the model
p.m5p <- predict(m.m5p, wine_test)
# summary statistics about the predictions
summary(p.m5p)
# correlation between the predicted and true values
cor(p.m5p, wine_test$quality)
# mean absolute error of predicted and true values
# (uses a custom function defined above)
MAE(wine_test$quality, p.m5p)