Regression Methods
Part 1: Linear Regression
Understanding regression
getwd()
## [1] "/cloud/project"
Load Data
launch <- read.csv("challenger.csv")
Estimate beta manually
b <- cov(launch$temperature, launch$distress_ct) / var(launch$temperature)
b
## [1] -0.04753968
Estimate alpha manually
a <- mean(launch$distress_ct) - b * mean(launch$temperature)
a
## [1] 3.698413
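The two manual formulas (slope = covariance / variance, intercept = mean of y minus slope times mean of x) are the ordinary least-squares solution. The sketch below checks this on R's built-in `mtcars` data, which stands in for `challenger.csv` so that it runs anywhere. The choice of `mpg ~ wt` is purely illustrative.

```r
# Stand-in data: the built-in mtcars dataset (not the Challenger data),
# regressing fuel economy (mpg) on car weight (wt).
x <- mtcars$wt
y <- mtcars$mpg

# Slope: covariance of x and y divided by the variance of x
b_manual <- cov(x, y) / var(x)

# Intercept: the fitted line passes through the point of means
a_manual <- mean(y) - b_manual * mean(x)

# lm() should recover the same two numbers
fit <- lm(y ~ x)
coef(fit)
c(a_manual, b_manual)
```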
Calculate the correlation of launch data
r <- cov(launch$temperature, launch$distress_ct) /
(sd(launch$temperature) * sd(launch$distress_ct))
r
## [1] -0.5111264
cor(launch$temperature, launch$distress_ct)
## [1] -0.5111264
Computing the slope using correlation
r * (sd(launch$distress_ct) / sd(launch$temperature))
## [1] -0.04753968
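The covariance-based slope and the correlation-based slope agree because the correlation is the covariance scaled by both standard deviations. Substituting one definition into the other gives:

```latex
r = \frac{\operatorname{Cov}(x, y)}{s_x s_y}
\quad\Longrightarrow\quad
b = \frac{\operatorname{Cov}(x, y)}{s_x^2} = \frac{r\, s_x s_y}{s_x^2} = r\,\frac{s_y}{s_x},
\qquad
a = \bar{y} - b\,\bar{x}
```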
Confirming the regression line using the lm function (not in text)
model <- lm(distress_ct ~ temperature, data = launch)
model
##
## Call:
## lm(formula = distress_ct ~ temperature, data = launch)
##
## Coefficients:
## (Intercept) temperature
## 3.69841 -0.04754
Summary of the model
summary(model)
##
## Call:
## lm(formula = distress_ct ~ temperature, data = launch)
##
## Residuals:
## Min 1Q Median 3Q Max
## -0.5608 -0.3944 -0.0854 0.1056 1.8671
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 3.69841 1.21951 3.033 0.00633 **
## temperature -0.04754 0.01744 -2.725 0.01268 *
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 0.5774 on 21 degrees of freedom
## Multiple R-squared: 0.2613, Adjusted R-squared: 0.2261
## F-statistic: 7.426 on 1 and 21 DF, p-value: 0.01268
For temperature, the absolute t-value (2.725) is greater than 2 and the p-value (0.0127) is less than 0.05, so temperature is statistically significant at the 5% level and we reject the null hypothesis that its coefficient is zero.
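The t value and p-value in the table can be reproduced from the estimate and its standard error alone. The sketch below copies the rounded numbers from the `summary(model)` output above; it also shows that, in a simple (one-predictor) regression, the Multiple R-squared equals the squared correlation computed earlier.

```r
# Values copied from the summary(model) output above
estimate <- -0.04754
std_err  <- 0.01744
df_resid <- 21          # 23 observations minus 2 estimated coefficients

# t statistic: estimate divided by its standard error
t_value <- estimate / std_err

# Two-sided p-value from the t distribution with 21 degrees of freedom
p_value <- 2 * pt(-abs(t_value), df = df_resid)

# In simple regression, R-squared equals the squared correlation
r <- -0.5111264
r_squared <- r^2

round(c(t = t_value, p = p_value, R2 = r_squared), 4)
```

Small differences in the last digit come from using the rounded estimate and standard error.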
Creating a simple multiple regression function
reg <- function(y, x) {
  x <- as.matrix(x)
  x <- cbind(Intercept = 1, x)              # add a column of 1s for the intercept
  b <- solve(t(x) %*% x) %*% t(x) %*% y     # normal equations: (X'X)^-1 X'y
  colnames(b) <- "estimate"
  print(b)
}
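`reg()` solves the normal equations by explicitly inverting `t(x) %*% x`. A common, numerically steadier alternative is to solve the least-squares problem through a QR decomposition, which is also how `lm()` works internally. The sketch below is a hypothetical variant named `reg_qr()` (not from the text), checked against `lm()` on the built-in `mtcars` data as a stand-in for `launch`.

```r
# Hypothetical helper reg_qr(): same least-squares estimate as reg(),
# but solved through a QR decomposition instead of an explicit inverse.
reg_qr <- function(y, x) {
  x <- as.matrix(x)
  x <- cbind(Intercept = 1, x)      # prepend a column of 1s for the intercept
  b <- qr.solve(x, y)               # least-squares solution of x %*% b = y
  matrix(b, dimnames = list(colnames(x), "estimate"))
}

# Check against lm() on the built-in mtcars data (stand-in for launch)
est <- reg_qr(y = mtcars$mpg, x = mtcars[c("wt", "hp")])
est
coef(lm(mpg ~ wt + hp, data = mtcars))
```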
Examine the launch data
str(launch)
## 'data.frame': 23 obs. of 4 variables:
## $ distress_ct : int 0 1 0 0 0 0 0 0 1 1 ...
## $ temperature : int 66 70 69 68 67 72 73 70 57 63 ...
## $ field_check_pressure: int 50 50 50 50 50 50 100 100 200 200 ...
## $ flight_num : int 1 2 3 4 5 6 7 8 9 10 ...
Test regression model with simple linear regression
reg(y = launch$distress_ct, x = launch[2])
## estimate
## Intercept 3.69841270
## temperature -0.04753968
Use regression model with multiple regression
reg(y = launch$distress_ct, x = launch[2:4])
## estimate
## Intercept 3.527093383
## temperature -0.051385940
## field_check_pressure 0.001757009
## flight_num 0.014292843
Confirming the multiple regression result using the lm function (not in text)
model <- lm(distress_ct ~ temperature + field_check_pressure + flight_num, data = launch)
model
##
## Call:
## lm(formula = distress_ct ~ temperature + field_check_pressure +
## flight_num, data = launch)
##
## Coefficients:
## (Intercept) temperature field_check_pressure
## 3.527093 -0.051386 0.001757
## flight_num
## 0.014293
summary(model)
##
## Call:
## lm(formula = distress_ct ~ temperature + field_check_pressure +
## flight_num, data = launch)
##
## Residuals:
## Min 1Q Median 3Q Max
## -0.65003 -0.24414 -0.11219 0.01279 1.67530
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 3.527093 1.307024 2.699 0.0142 *
## temperature -0.051386 0.018341 -2.802 0.0114 *
## field_check_pressure 0.001757 0.003402 0.517 0.6115
## flight_num 0.014293 0.035138 0.407 0.6887
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 0.565 on 19 degrees of freedom
## Multiple R-squared: 0.36, Adjusted R-squared: 0.259
## F-statistic: 3.563 on 3 and 19 DF, p-value: 0.03371
Temperature (-0.051): as temperature goes up, predicted distress goes down. Specifically, for every 1-degree increase in temperature, holding field check pressure and flight number constant, the predicted distress count drops by about 0.05. Only temperature is statistically significant in this model (p = 0.0114). Pressure (p = 0.61) and flight number (p = 0.69) do not help predict distress in a meaningful way; their estimated effects are indistinguishable from random noise in this sample.
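"Holding the other predictors constant" can be made concrete by plugging two launches that differ only in temperature into the fitted equation. The launch values below (60 vs. 70 degrees, pressure 200, flight 10) are hypothetical; the coefficients are copied from the `reg()` output above.

```r
# Coefficients copied from the reg()/lm() output above
b <- c(Intercept = 3.527093383, temperature = -0.051385940,
       field_check_pressure = 0.001757009, flight_num = 0.014292843)

# Two hypothetical launches that differ only in temperature (60 vs. 70 degrees)
launch_cold <- c(1, 60, 200, 10)
launch_warm <- c(1, 70, 200, 10)

pred_cold <- sum(b * launch_cold)
pred_warm <- sum(b * launch_warm)

# The difference is 10 times the temperature coefficient, about -0.51
pred_warm - pred_cold
```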
Recreating Challenger2 for Activity 7
launch <- read.csv("challenger2.csv")
# estimate beta manually
b <- cov(launch$temperature, launch$distress_ct) / var(launch$temperature)
b
## [1] -0.03364796
# estimate alpha manually
a <- mean(launch$distress_ct) - b * mean(launch$temperature)
a
## [1] 2.814585
# calculate the correlation of launch data
r <- cov(launch$temperature, launch$distress_ct) /
(sd(launch$temperature) * sd(launch$distress_ct))
r
## [1] -0.3359996
cor(launch$temperature, launch$distress_ct)
## [1] -0.3359996
# computing the slope using correlation
r * (sd(launch$distress_ct) / sd(launch$temperature))
## [1] -0.03364796
# confirming the regression line using the lm function (not in text)
model <- lm(distress_ct ~ temperature, data = launch)
model
##
## Call:
## lm(formula = distress_ct ~ temperature, data = launch)
##
## Coefficients:
## (Intercept) temperature
## 2.81458 -0.03365
summary(model)
##
## Call:
## lm(formula = distress_ct ~ temperature, data = launch)
##
## Residuals:
## Min 1Q Median 3Q Max
## -1.0649 -0.4929 -0.2573 0.3052 1.7090
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 2.81458 1.24629 2.258 0.0322 *
## temperature -0.03365 0.01815 -1.854 0.0747 .
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 0.7076 on 27 degrees of freedom
## Multiple R-squared: 0.1129, Adjusted R-squared: 0.08004
## F-statistic: 3.436 on 1 and 27 DF, p-value: 0.07474
# creating a simple multiple regression function
reg <- function(y, x) {
x <- as.matrix(x)
x <- cbind(Intercept = 1, x)
b <- solve(t(x) %*% x) %*% t(x) %*% y
colnames(b) <- "estimate"
print(b)
}
# examine the launch data
str(launch)
## 'data.frame': 29 obs. of 4 variables:
## $ distress_ct : int 0 1 0 0 0 0 0 0 1 1 ...
## $ temperature : int 66 70 69 68 67 72 73 70 57 63 ...
## $ field_check_pressure: int 50 50 50 50 50 50 100 100 200 200 ...
## $ flight_num : int 1 2 3 4 5 6 7 8 9 10 ...
# test regression model with simple linear regression
reg(y = launch$distress_ct, x = launch[2])
## estimate
## Intercept 2.81458456
## temperature -0.03364796
# use regression model with multiple regression
reg(y = launch$distress_ct, x = launch[2:4])
## estimate
## Intercept 2.239817e+00
## temperature -3.124185e-02
## field_check_pressure -2.586765e-05
## flight_num 2.762455e-02
# confirming the multiple regression result using the lm function (not in text)
model <- lm(distress_ct ~ temperature + field_check_pressure + flight_num, data = launch)
model
##
## Call:
## lm(formula = distress_ct ~ temperature + field_check_pressure +
## flight_num, data = launch)
##
## Coefficients:
## (Intercept) temperature field_check_pressure
## 2.240e+00 -3.124e-02 -2.587e-05
## flight_num
## 2.762e-02
summary(model)
##
## Call:
## lm(formula = distress_ct ~ temperature + field_check_pressure +
## flight_num, data = launch)
##
## Residuals:
## Min 1Q Median 3Q Max
## -1.2744 -0.3335 -0.1657 0.2975 1.5284
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 2.240e+00 1.267e+00 1.767 0.0894 .
## temperature -3.124e-02 1.787e-02 -1.748 0.0927 .
## field_check_pressure -2.587e-05 2.383e-03 -0.011 0.9914
## flight_num 2.762e-02 1.798e-02 1.537 0.1369
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 0.6926 on 25 degrees of freedom
## Multiple R-squared: 0.2132, Adjusted R-squared: 0.1188
## F-statistic: 2.259 on 3 and 25 DF, p-value: 0.1063
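F-statistic: the overall p-value (0.1063) is greater than 0.05, so we cannot reject the null hypothesis that all three slope coefficients are zero. With the 29-launch challenger2 data, this model does not show a strong, reliable relationship between these factors and distress incidents; even temperature is only marginally significant (p = 0.0927).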
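The overall F-statistic, its p-value, and the adjusted R-squared all follow from the Multiple R-squared, the sample size, and the number of predictors. This sketch recomputes them from the rounded values in the summary above, so the last digit may differ slightly.

```r
# Values from the summary(model) output above
r2 <- 0.2132   # Multiple R-squared
n  <- 29       # observations in challenger2.csv
k  <- 3        # predictors: temperature, field_check_pressure, flight_num

# Overall F statistic: explained vs. unexplained variance per degree of freedom
f_stat <- (r2 / k) / ((1 - r2) / (n - k - 1))

# Upper-tail p-value of the F distribution with (3, 25) degrees of freedom
p_value <- pf(f_stat, df1 = k, df2 = n - k - 1, lower.tail = FALSE)

# Adjusted R-squared penalizes the extra predictors
adj_r2 <- 1 - (1 - r2) * (n - 1) / (n - k - 1)

round(c(F = f_stat, p = p_value, adj_R2 = adj_r2), 4)
```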
Predicting Medical Expenses
## Exploring and preparing the data ----
insurance <- read.csv("insurance.csv", stringsAsFactors = TRUE)
str(insurance)
## 'data.frame': 1338 obs. of 7 variables:
## $ age : int 19 18 28 33 32 31 46 37 37 60 ...
## $ sex : Factor w/ 2 levels "female","male": 1 2 2 2 2 1 1 1 2 1 ...
## $ bmi : num 27.9 33.8 33 22.7 28.9 25.7 33.4 27.7 29.8 25.8 ...
## $ children: int 0 1 3 0 0 0 1 3 2 0 ...
## $ smoker : Factor w/ 2 levels "no","yes": 2 1 1 1 1 1 1 1 1 1 ...
## $ region : Factor w/ 4 levels "northeast","northwest",..: 4 3 3 2 2 3 3 2 1 2 ...
## $ expenses: num 16885 1726 4449 21984 3867 ...
summary(insurance$expenses)
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 1122 4740 9382 13270 16640 63770
Histogram of insurance charges
hist(insurance$expenses)
Table of region
table(insurance$region)
##
## northeast northwest southeast southwest
## 324 325 364 325
Exploring relationships among features: correlation matrix
cor(insurance[c("age", "bmi", "children", "expenses")])
## age bmi children expenses
## age 1.0000000 0.10934101 0.04246900 0.29900819
## bmi 0.1093410 1.00000000 0.01264471 0.19857626
## children 0.0424690 0.01264471 1.00000000 0.06799823
## expenses 0.2990082 0.19857626 0.06799823 1.00000000
Visualizing relationships among features: scatterplot matrix
pairs(insurance[c("age", "bmi", "children", "expenses")])
Train the model and see the estimated beta coefficients
## Step 3: Training a model on the data ----
ins_model <- lm(expenses ~ age + children + bmi + sex + smoker + region,
data = insurance)
ins_model <- lm(expenses ~ ., data = insurance) # this is equivalent to above
# see the estimated beta coefficients
ins_model
##
## Call:
## lm(formula = expenses ~ ., data = insurance)
##
## Coefficients:
## (Intercept) age sexmale bmi
## -11941.6 256.8 -131.4 339.3
## children smokeryes regionnorthwest regionsoutheast
## 475.7 23847.5 -352.8 -1035.6
## regionsouthwest
## -959.3
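Coefficient names such as `sexmale` and `regionnorthwest` come from R's default treatment (dummy) coding of factors: a factor with k levels becomes k - 1 indicator columns, and the omitted first level (here `female` and `northeast`, alphabetically first) is the baseline absorbed into the intercept. The toy data frame below is hypothetical and only reuses the insurance factor levels; `model.matrix()` shows the columns `lm()` builds.

```r
# Toy data frame with the same factor levels as insurance
toy <- data.frame(
  sex    = factor(c("female", "male", "male", "female")),
  region = factor(c("northeast", "northwest", "southeast", "southwest"))
)

# model.matrix() shows the columns lm() builds from the factors:
# the first (alphabetical) level of each factor becomes the reference
mm <- model.matrix(~ sex + region, data = toy)
colnames(mm)
mm
```

So `sexmale = -131.4` is the estimated difference for males relative to females, and each region coefficient is relative to the northeast.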
Step 4: Evaluating model performance
summary(ins_model)
##
## Call:
## lm(formula = expenses ~ ., data = insurance)
##
## Residuals:
## Min 1Q Median 3Q Max
## -11302.7 -2850.9 -979.6 1383.9 29981.7
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) -11941.6 987.8 -12.089 < 2e-16 ***
## age 256.8 11.9 21.586 < 2e-16 ***
## sexmale -131.3 332.9 -0.395 0.693255
## bmi 339.3 28.6 11.864 < 2e-16 ***
## children 475.7 137.8 3.452 0.000574 ***
## smokeryes 23847.5 413.1 57.723 < 2e-16 ***
## regionnorthwest -352.8 476.3 -0.741 0.458976
## regionsoutheast -1035.6 478.7 -2.163 0.030685 *
## regionsouthwest -959.3 477.9 -2.007 0.044921 *
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 6062 on 1329 degrees of freedom
## Multiple R-squared: 0.7509, Adjusted R-squared: 0.7494
## F-statistic: 500.9 on 8 and 1329 DF, p-value: < 2.2e-16
Step 5: Improving model performance
Add a higher-order “age” term
insurance$age2 <- insurance$age^2
Add an indicator for BMI >= 30
insurance$bmi30 <- ifelse(insurance$bmi >= 30, 1, 0)
Create final model
ins_model2 <- lm(expenses ~ age + age2 + children + bmi + sex +
bmi30*smoker + region, data = insurance)
summary(ins_model2)
##
## Call:
## lm(formula = expenses ~ age + age2 + children + bmi + sex + bmi30 *
## smoker + region, data = insurance)
##
## Residuals:
## Min 1Q Median 3Q Max
## -17297.1 -1656.0 -1262.7 -727.8 24161.6
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 139.0053 1363.1359 0.102 0.918792
## age -32.6181 59.8250 -0.545 0.585690
## age2 3.7307 0.7463 4.999 6.54e-07 ***
## children 678.6017 105.8855 6.409 2.03e-10 ***
## bmi 119.7715 34.2796 3.494 0.000492 ***
## sexmale -496.7690 244.3713 -2.033 0.042267 *
## bmi30 -997.9355 422.9607 -2.359 0.018449 *
## smokeryes 13404.5952 439.9591 30.468 < 2e-16 ***
## regionnorthwest -279.1661 349.2826 -0.799 0.424285
## regionsoutheast -828.0345 351.6484 -2.355 0.018682 *
## regionsouthwest -1222.1619 350.5314 -3.487 0.000505 ***
## bmi30:smokeryes 19810.1534 604.6769 32.762 < 2e-16 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 4445 on 1326 degrees of freedom
## Multiple R-squared: 0.8664, Adjusted R-squared: 0.8653
## F-statistic: 781.7 on 11 and 1326 DF, p-value: < 2.2e-16
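In a formula, `bmi30*smoker` expands to both main effects plus their product, `bmi30 + smoker + bmi30:smoker`. The first part of the sketch confirms the expansion on a hypothetical toy data frame. The second part adds the three indicator coefficients from `summary(ins_model2)` to get the combined shift for an obese smoker relative to a non-obese non-smoker, on top of whatever the continuous `bmi` term contributes.

```r
# Hypothetical toy data covering all four bmi30/smoker combinations
toy <- data.frame(
  bmi30  = c(0, 1, 0, 1),
  smoker = factor(c("no", "no", "yes", "yes"))
)
colnames(model.matrix(~ bmi30 * smoker, data = toy))

# Coefficients copied from summary(ins_model2) above
b_bmi30       <- -997.9355
b_smokeryes   <- 13404.5952
b_interaction <- 19810.1534

# Combined indicator effect for an obese smoker (excluding the bmi term)
obese_smoker_effect <- b_bmi30 + b_smokeryes + b_interaction
obese_smoker_effect
```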
Making predictions with the regression model
insurance$pred <- predict(ins_model2, insurance)
cor(insurance$pred, insurance$expenses)
## [1] 0.9307999
plot(insurance$pred, insurance$expenses)
abline(a = 0, b = 1, col = "red", lwd = 3, lty = 2)
predict(ins_model2,
data.frame(age = 30, age2 = 30^2, children = 2,
bmi = 30, sex = "male", bmi30 = 1,
smoker = "no", region = "northeast"))
## 1
## 5973.774
predict(ins_model2,
data.frame(age = 30, age2 = 30^2, children = 2,
bmi = 30, sex = "female", bmi30 = 1,
smoker = "no", region = "northeast"))
## 1
## 6470.543
predict(ins_model2,
data.frame(age = 30, age2 = 30^2, children = 0,
bmi = 30, sex = "female", bmi30 = 1,
smoker = "no", region = "northeast"))
## 1
## 5113.34
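Because the model is linear in its terms, the differences between these three predictions are just coefficients from `summary(ins_model2)`. The sketch below copies the printed predictions and checks that switching sex moves the prediction by the `sexmale` coefficient, and dropping from 2 children to 0 moves it by twice the `children` coefficient.

```r
# Predictions copied from the three predict() calls above
male_2kids   <- 5973.774
female_2kids <- 6470.543
female_0kids <- 5113.34

# Switching sex changes only the sexmale term (coefficient -496.7690)
female_2kids - male_2kids

# Dropping from 2 children to 0 removes 2 x the children coefficient (678.6017)
female_2kids - female_0kids
2 * 678.6017
```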
Part 2: Regression Trees and Model Trees
Understanding regression trees and model trees
Example: Calculating SDR
# set up the data
tee <- c(1, 1, 1, 2, 2, 3, 4, 5, 5, 6, 6, 7, 7, 7, 7)
at1 <- c(1, 1, 1, 2, 2, 3, 4, 5, 5)
at2 <- c(6, 6, 7, 7, 7, 7)
bt1 <- c(1, 1, 1, 2, 2, 3, 4)
bt2 <- c(5, 5, 6, 6, 7, 7, 7, 7)
Compute the SDR
sdr_a <- sd(tee) - (length(at1) / length(tee) * sd(at1) + length(at2) / length(tee) * sd(at2))
sdr_b <- sd(tee) - (length(bt1) / length(tee) * sd(bt1) + length(bt2) / length(tee) * sd(bt2))
Compare the SDR for each split
sdr_a
## [1] 1.202815
sdr_b
## [1] 1.392751
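The SDR is the parent's standard deviation minus the size-weighted standard deviations of the subsets. The two long expressions above can be wrapped in a small function that accepts any number of subsets; `sdr()` is a hypothetical helper, not part of any package.

```r
# Hypothetical helper: SDR of splitting `parent` into the subsets in `parts`
sdr <- function(parent, parts) {
  weights <- sapply(parts, length) / length(parent)  # share of cases in each subset
  sd(parent) - sum(weights * sapply(parts, sd))
}

tee <- c(1, 1, 1, 2, 2, 3, 4, 5, 5, 6, 6, 7, 7, 7, 7)
sdr_a <- sdr(tee, list(c(1, 1, 1, 2, 2, 3, 4, 5, 5), c(6, 6, 7, 7, 7, 7)))
sdr_b <- sdr(tee, list(c(1, 1, 1, 2, 2, 3, 4), c(5, 5, 6, 6, 7, 7, 7, 7)))
c(sdr_a, sdr_b)
```

Split B reduces the standard deviation more (about 1.39 vs. 1.20), so a regression tree using this criterion would choose it.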
Exercise No 3: Estimating Wine Quality
Step 2: Exploring and preparing the data
wine <- read.csv("whitewines.csv")
Examine the wine data
str(wine)
## 'data.frame': 4898 obs. of 12 variables:
## $ fixed.acidity : num 6.7 5.7 5.9 5.3 6.4 7 7.9 6.6 7 6.5 ...
## $ volatile.acidity : num 0.62 0.22 0.19 0.47 0.29 0.14 0.12 0.38 0.16 0.37 ...
## $ citric.acid : num 0.24 0.2 0.26 0.1 0.21 0.41 0.49 0.28 0.3 0.33 ...
## $ residual.sugar : num 1.1 16 7.4 1.3 9.65 0.9 5.2 2.8 2.6 3.9 ...
## $ chlorides : num 0.039 0.044 0.034 0.036 0.041 0.037 0.049 0.043 0.043 0.027 ...
## $ free.sulfur.dioxide : num 6 41 33 11 36 22 33 17 34 40 ...
## $ total.sulfur.dioxide: num 62 113 123 74 119 95 152 67 90 130 ...
## $ density : num 0.993 0.999 0.995 0.991 0.993 ...
## $ pH : num 3.41 3.22 3.49 3.48 2.99 3.25 3.18 3.21 2.88 3.28 ...
## $ sulphates : num 0.32 0.46 0.42 0.54 0.34 0.43 0.47 0.47 0.47 0.39 ...
## $ alcohol : num 10.4 8.9 10.1 11.2 10.9 ...
## $ quality : int 5 6 6 4 6 6 6 6 6 7 ...
The distribution of quality ratings
hist(wine$quality)
Summary statistics of the wine data
summary(wine)
## fixed.acidity volatile.acidity citric.acid residual.sugar
## Min. : 3.800 Min. :0.0800 Min. :0.0000 Min. : 0.600
## 1st Qu.: 6.300 1st Qu.:0.2100 1st Qu.:0.2700 1st Qu.: 1.700
## Median : 6.800 Median :0.2600 Median :0.3200 Median : 5.200
## Mean : 6.855 Mean :0.2782 Mean :0.3342 Mean : 6.391
## 3rd Qu.: 7.300 3rd Qu.:0.3200 3rd Qu.:0.3900 3rd Qu.: 9.900
## Max. :14.200 Max. :1.1000 Max. :1.6600 Max. :65.800
## chlorides free.sulfur.dioxide total.sulfur.dioxide density
## Min. :0.00900 Min. : 2.00 Min. : 9.0 Min. :0.9871
## 1st Qu.:0.03600 1st Qu.: 23.00 1st Qu.:108.0 1st Qu.:0.9917
## Median :0.04300 Median : 34.00 Median :134.0 Median :0.9937
## Mean :0.04577 Mean : 35.31 Mean :138.4 Mean :0.9940
## 3rd Qu.:0.05000 3rd Qu.: 46.00 3rd Qu.:167.0 3rd Qu.:0.9961
## Max. :0.34600 Max. :289.00 Max. :440.0 Max. :1.0390
## pH sulphates alcohol quality
## Min. :2.720 Min. :0.2200 Min. : 8.00 Min. :3.000
## 1st Qu.:3.090 1st Qu.:0.4100 1st Qu.: 9.50 1st Qu.:5.000
## Median :3.180 Median :0.4700 Median :10.40 Median :6.000
## Mean :3.188 Mean :0.4898 Mean :10.51 Mean :5.878
## 3rd Qu.:3.280 3rd Qu.:0.5500 3rd Qu.:11.40 3rd Qu.:6.000
## Max. :3.820 Max. :1.0800 Max. :14.20 Max. :9.000
wine_train <- wine[1:3750, ]
wine_test <- wine[3751:4898, ]
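Splitting by row position, as above, is only safe if the rows are not sorted in any meaningful order. If that is uncertain, a random split is a common alternative. The sketch below uses a synthetic data frame with the same row count as `whitewines.csv` (an assumption for illustration) and the same 3,750 / 1,148 sizes.

```r
# Synthetic stand-in with the same number of rows as whitewines.csv
df <- data.frame(id = seq_len(4898), quality = rep(3:9, length.out = 4898))

set.seed(123)                                  # make the split reproducible
train_idx <- sample(nrow(df), 3750)            # 3,750 random row numbers
df_train <- df[train_idx, ]
df_test  <- df[-train_idx, ]

c(train = nrow(df_train), test = nrow(df_test))
```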
Step 3: Training a model on the data
Regression tree using rpart
library(rpart)
m.rpart <- rpart(quality ~ ., data = wine_train)
Get basic information about the tree
m.rpart
## n= 3750
##
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 3750 2945.53200 5.870933
## 2) alcohol< 10.85 2372 1418.86100 5.604975
## 4) volatile.acidity>=0.2275 1611 821.30730 5.432030
## 8) volatile.acidity>=0.3025 688 278.97670 5.255814 *
## 9) volatile.acidity< 0.3025 923 505.04230 5.563380 *
## 5) volatile.acidity< 0.2275 761 447.36400 5.971091 *
## 3) alcohol>=10.85 1378 1070.08200 6.328737
## 6) free.sulfur.dioxide< 10.5 84 95.55952 5.369048 *
## 7) free.sulfur.dioxide>=10.5 1294 892.13600 6.391036
## 14) alcohol< 11.76667 629 430.11130 6.173291
## 28) volatile.acidity>=0.465 11 10.72727 4.545455 *
## 29) volatile.acidity< 0.465 618 389.71680 6.202265 *
## 15) alcohol>=11.76667 665 403.99400 6.596992 *
Get more detailed information about the tree
summary(m.rpart)
## Call:
## rpart(formula = quality ~ ., data = wine_train)
## n= 3750
##
## CP nsplit rel error xerror xstd
## 1 0.15501053 0 1.0000000 1.0001169 0.02445263
## 2 0.05098911 1 0.8449895 0.8481575 0.02337580
## 3 0.02796998 2 0.7940004 0.8056471 0.02267098
## 4 0.01970128 3 0.7660304 0.7875000 0.02177570
## 5 0.01265926 4 0.7463291 0.7603867 0.02068508
## 6 0.01007193 5 0.7336698 0.7527081 0.02056573
## 7 0.01000000 6 0.7235979 0.7459574 0.02057589
##
## Variable importance
## alcohol density volatile.acidity
## 34 21 15
## chlorides total.sulfur.dioxide free.sulfur.dioxide
## 11 7 6
## residual.sugar sulphates citric.acid
## 3 1 1
##
## Node number 1: 3750 observations, complexity param=0.1550105
## mean=5.870933, MSE=0.7854751
## left son=2 (2372 obs) right son=3 (1378 obs)
## Primary splits:
## alcohol < 10.85 to the left, improve=0.15501050, (0 missing)
## density < 0.992035 to the right, improve=0.10915940, (0 missing)
## chlorides < 0.0395 to the right, improve=0.07682258, (0 missing)
## total.sulfur.dioxide < 158.5 to the right, improve=0.04089663, (0 missing)
## citric.acid < 0.235 to the left, improve=0.03636458, (0 missing)
## Surrogate splits:
## density < 0.991995 to the right, agree=0.869, adj=0.644, (0 split)
## chlorides < 0.0375 to the right, agree=0.757, adj=0.339, (0 split)
## total.sulfur.dioxide < 103.5 to the right, agree=0.690, adj=0.155, (0 split)
## residual.sugar < 5.375 to the right, agree=0.667, adj=0.094, (0 split)
## sulphates < 0.345 to the right, agree=0.647, adj=0.038, (0 split)
##
## Node number 2: 2372 observations, complexity param=0.05098911
## mean=5.604975, MSE=0.5981709
## left son=4 (1611 obs) right son=5 (761 obs)
## Primary splits:
## volatile.acidity < 0.2275 to the right, improve=0.10585250, (0 missing)
## free.sulfur.dioxide < 13.5 to the left, improve=0.03390500, (0 missing)
## citric.acid < 0.235 to the left, improve=0.03204075, (0 missing)
## alcohol < 10.11667 to the left, improve=0.03136524, (0 missing)
## chlorides < 0.0585 to the right, improve=0.01633599, (0 missing)
## Surrogate splits:
## pH < 3.485 to the left, agree=0.694, adj=0.047, (0 split)
## sulphates < 0.755 to the left, agree=0.685, adj=0.020, (0 split)
## total.sulfur.dioxide < 105.5 to the right, agree=0.683, adj=0.011, (0 split)
## residual.sugar < 0.75 to the right, agree=0.681, adj=0.007, (0 split)
## chlorides < 0.0285 to the right, agree=0.680, adj=0.003, (0 split)
##
## Node number 3: 1378 observations, complexity param=0.02796998
## mean=6.328737, MSE=0.7765472
## left son=6 (84 obs) right son=7 (1294 obs)
## Primary splits:
## free.sulfur.dioxide < 10.5 to the left, improve=0.07699080, (0 missing)
## alcohol < 11.76667 to the left, improve=0.06210660, (0 missing)
## total.sulfur.dioxide < 67.5 to the left, improve=0.04438619, (0 missing)
## residual.sugar < 1.375 to the left, improve=0.02905351, (0 missing)
## fixed.acidity < 7.35 to the right, improve=0.02613259, (0 missing)
## Surrogate splits:
## total.sulfur.dioxide < 53.5 to the left, agree=0.952, adj=0.214, (0 split)
## volatile.acidity < 0.875 to the right, agree=0.940, adj=0.024, (0 split)
##
## Node number 4: 1611 observations, complexity param=0.01265926
## mean=5.43203, MSE=0.5098121
## left son=8 (688 obs) right son=9 (923 obs)
## Primary splits:
## volatile.acidity < 0.3025 to the right, improve=0.04540111, (0 missing)
## alcohol < 10.05 to the left, improve=0.03874403, (0 missing)
## free.sulfur.dioxide < 13.5 to the left, improve=0.03338886, (0 missing)
## chlorides < 0.0495 to the right, improve=0.02574623, (0 missing)
## citric.acid < 0.195 to the left, improve=0.02327981, (0 missing)
## Surrogate splits:
## citric.acid < 0.215 to the left, agree=0.633, adj=0.141, (0 split)
## free.sulfur.dioxide < 20.5 to the left, agree=0.600, adj=0.063, (0 split)
## chlorides < 0.0595 to the right, agree=0.593, adj=0.047, (0 split)
## residual.sugar < 1.15 to the left, agree=0.583, adj=0.023, (0 split)
## total.sulfur.dioxide < 219.25 to the right, agree=0.582, adj=0.022, (0 split)
##
## Node number 5: 761 observations
## mean=5.971091, MSE=0.5878633
##
## Node number 6: 84 observations
## mean=5.369048, MSE=1.137613
##
## Node number 7: 1294 observations, complexity param=0.01970128
## mean=6.391036, MSE=0.6894405
## left son=14 (629 obs) right son=15 (665 obs)
## Primary splits:
## alcohol < 11.76667 to the left, improve=0.06504696, (0 missing)
## chlorides < 0.0395 to the right, improve=0.02758705, (0 missing)
## fixed.acidity < 7.35 to the right, improve=0.02750932, (0 missing)
## pH < 3.055 to the left, improve=0.02307356, (0 missing)
## total.sulfur.dioxide < 191.5 to the right, improve=0.02186818, (0 missing)
## Surrogate splits:
## density < 0.990885 to the right, agree=0.720, adj=0.424, (0 split)
## volatile.acidity < 0.2675 to the left, agree=0.637, adj=0.253, (0 split)
## chlorides < 0.0365 to the right, agree=0.630, adj=0.238, (0 split)
## residual.sugar < 1.475 to the left, agree=0.575, adj=0.126, (0 split)
## total.sulfur.dioxide < 128.5 to the right, agree=0.574, adj=0.124, (0 split)
##
## Node number 8: 688 observations
## mean=5.255814, MSE=0.4054895
##
## Node number 9: 923 observations
## mean=5.56338, MSE=0.5471747
##
## Node number 14: 629 observations, complexity param=0.01007193
## mean=6.173291, MSE=0.6838017
## left son=28 (11 obs) right son=29 (618 obs)
## Primary splits:
## volatile.acidity < 0.465 to the right, improve=0.06897561, (0 missing)
## total.sulfur.dioxide < 200 to the right, improve=0.04223066, (0 missing)
## residual.sugar < 0.975 to the left, improve=0.03061714, (0 missing)
## fixed.acidity < 7.35 to the right, improve=0.02978501, (0 missing)
## sulphates < 0.575 to the left, improve=0.02165970, (0 missing)
## Surrogate splits:
## citric.acid < 0.045 to the left, agree=0.986, adj=0.182, (0 split)
## total.sulfur.dioxide < 279.25 to the right, agree=0.986, adj=0.182, (0 split)
##
## Node number 15: 665 observations
## mean=6.596992, MSE=0.6075098
##
## Node number 28: 11 observations
## mean=4.545455, MSE=0.9752066
##
## Node number 29: 618 observations
## mean=6.202265, MSE=0.6306098
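For a regression (anova) tree, the `improve` value reported for a split is the fraction of the parent node's deviance (sum of squared errors) removed by that split, and a node's MSE is its deviance divided by its case count. The sketch below recomputes both for the root split on alcohol, using the deviances printed by `m.rpart` earlier.

```r
# Deviances (sums of squared errors) printed by m.rpart above
dev_root  <- 2945.532   # node 1, 3750 wines
dev_left  <- 1418.861   # node 2: alcohol < 10.85
dev_right <- 1070.082   # node 3: alcohol >= 10.85

# "improve": fraction of the parent's deviance removed by the split
improve <- 1 - (dev_left + dev_right) / dev_root

# Node MSE is its deviance divided by the number of cases
mse_root <- dev_root / 3750

round(c(improve = improve, mse_root = mse_root), 7)
```

The result matches `improve=0.15501050` and `MSE=0.7854751` for node 1, and also the first CP value in the complexity table.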
Installing packages
install.packages("rpart.plot")
## Installing package into '/cloud/lib/x86_64-pc-linux-gnu-library/4.5'
## (as 'lib' is unspecified)
Use the rpart.plot package to create a visualization
library(rpart.plot)
A basic decision tree diagram
rpart.plot(m.rpart, digits = 3)
A few adjustments to the diagram
rpart.plot(m.rpart, digits = 4, fallen.leaves = TRUE, type = 3, extra = 101)
Step 4: Evaluate model performance
# generate predictions for the testing dataset
p.rpart <- predict(m.rpart, wine_test)
Compare the distribution of predicted values vs. actual values
summary(p.rpart)
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 4.545 5.563 5.971 5.893 6.202 6.597
summary(wine_test$quality)
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 3.000 5.000 6.000 5.901 6.000 9.000
Compare the correlation
cor(p.rpart, wine_test$quality)
## [1] 0.5369525
Function to calculate the mean absolute error
MAE <- function(actual, predicted) {
mean(abs(actual - predicted))
}
Mean absolute error between predicted and actual values
MAE(wine_test$quality, p.rpart)
## [1] 0.5872652
Mean absolute error between actual values and mean value
mean(wine_train$quality) # result = 5.87
## [1] 5.870933
MAE(5.87, wine_test$quality)
## [1] 0.6722474
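One way to summarize this comparison is the ratio of the tree's MAE to the MAE of always predicting the training mean; a value below 1 means the tree beats the constant baseline. As we understand the Cubist documentation, the "Relative |error|" line in the Cubist summary reports an analogous ratio on the training data. The sketch uses the two MAE values printed above.

```r
# MAE values from the output above
mae_tree     <- 0.5872652  # regression tree predictions
mae_baseline <- 0.6722474  # always predicting the training mean, 5.87

# Ratio below 1 means the tree beats the constant baseline
relative_error <- mae_tree / mae_baseline
relative_error
```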
Step 5: Improving model performance
install.packages("plyr")
## Installing package into '/cloud/lib/x86_64-pc-linux-gnu-library/4.5'
## (as 'lib' is unspecified)
install.packages("Cubist")
## Installing package into '/cloud/lib/x86_64-pc-linux-gnu-library/4.5'
## (as 'lib' is unspecified)
Train a Cubist Model Tree
library(Cubist)
## Loading required package: lattice
m.cubist <- cubist(x = wine_train[-12], y = wine_train$quality)
Display basic information about the model tree
m.cubist
##
## Call:
## cubist.default(x = wine_train[-12], y = wine_train$quality)
##
## Number of samples: 3750
## Number of predictors: 11
##
## Number of committees: 1
## Number of rules: 25
Display the tree itself
summary(m.cubist)
##
## Call:
## cubist.default(x = wine_train[-12], y = wine_train$quality)
##
##
## Cubist [Release 2.07 GPL Edition] Tue Feb 3 01:11:20 2026
## ---------------------------------
##
## Target attribute `outcome'
##
## Read 3750 cases (12 attributes) from undefined.data
##
## Model:
##
## Rule 1: [21 cases, mean 5.0, range 4 to 6, est err 0.5]
##
## if
## free.sulfur.dioxide > 30
## total.sulfur.dioxide > 195
## total.sulfur.dioxide <= 235
## sulphates > 0.64
## alcohol > 9.1
## then
## outcome = 573.6 + 0.0478 total.sulfur.dioxide - 573 density
## - 0.788 alcohol + 0.186 residual.sugar - 4.73 volatile.acidity
##
## Rule 2: [28 cases, mean 5.0, range 4 to 8, est err 0.7]
##
## if
## volatile.acidity > 0.31
## citric.acid <= 0.36
## residual.sugar <= 1.45
## total.sulfur.dioxide <= 97
## alcohol > 9.1
## then
## outcome = 168.2 + 4.75 citric.acid + 0.0123 total.sulfur.dioxide
## - 170 density + 0.057 residual.sugar - 6.4 chlorides + 0.84 pH
## + 0.14 fixed.acidity
##
## Rule 3: [171 cases, mean 5.1, range 3 to 6, est err 0.3]
##
## if
## volatile.acidity > 0.205
## chlorides <= 0.054
## density <= 0.99839
## alcohol <= 9.1
## then
## outcome = 147.4 - 144 density + 0.08 residual.sugar + 0.117 alcohol
## - 0.87 volatile.acidity - 0.09 pH - 0.01 fixed.acidity
##
## Rule 4: [37 cases, mean 5.3, range 3 to 6, est err 0.5]
##
## if
## free.sulfur.dioxide > 30
## total.sulfur.dioxide > 235
## alcohol > 9.1
## then
## outcome = 19.5 - 0.013 total.sulfur.dioxide - 2.7 volatile.acidity
## - 10 density + 0.005 residual.sugar + 0.008 alcohol
##
## Rule 5: [64 cases, mean 5.3, range 5 to 6, est err 0.3]
##
## if
## volatile.acidity > 0.205
## residual.sugar > 17.85
## then
## outcome = -23.6 + 0.233 alcohol - 5.2 chlorides - 0.75 citric.acid
## + 28 density - 0.81 volatile.acidity - 0.19 pH
## - 0.002 residual.sugar
##
## Rule 6: [56 cases, mean 5.3, range 4 to 7, est err 0.6]
##
## if
## fixed.acidity <= 7.1
## volatile.acidity > 0.205
## chlorides > 0.054
## density <= 0.99839
## alcohol <= 9.1
## then
## outcome = 40.6 + 0.374 alcohol - 1.62 volatile.acidity
## + 0.026 residual.sugar - 38 density - 0.21 pH
## - 0.01 fixed.acidity
##
## Rule 7: [337 cases, mean 5.3, range 3 to 7, est err 0.4]
##
## if
## fixed.acidity <= 7.8
## volatile.acidity > 0.305
## chlorides <= 0.09
## free.sulfur.dioxide <= 82.5
## total.sulfur.dioxide > 130
## total.sulfur.dioxide <= 235
## sulphates <= 0.64
## alcohol <= 10.4
## then
## outcome = -32.1 + 0.233 alcohol - 9.7 chlorides
## + 0.0038 total.sulfur.dioxide - 0.0081 free.sulfur.dioxide
## + 35 density + 0.81 volatile.acidity
##
## Rule 8: [30 cases, mean 5.5, range 3 to 7, est err 0.5]
##
## if
## fixed.acidity > 7.1
## volatile.acidity > 0.205
## chlorides > 0.054
## density <= 0.99839
## alcohol <= 9.1
## then
## outcome = 244 - 1.56 fixed.acidity - 228 density
## + 0.0252 free.sulfur.dioxide - 7.3 chlorides
## - 0.19 volatile.acidity + 0.003 residual.sugar
##
## Rule 9: [98 cases, mean 5.5, range 4 to 8, est err 0.5]
##
## if
## volatile.acidity > 0.155
## chlorides > 0.09
## total.sulfur.dioxide <= 235
## sulphates <= 0.64
## then
## outcome = 55.9 - 3.85 volatile.acidity - 52 density
## + 0.023 residual.sugar + 0.092 alcohol + 0.35 pH
## + 0.05 fixed.acidity + 0.3 sulphates
## + 0.001 free.sulfur.dioxide
##
## Rule 10: [446 cases, mean 5.6, range 4 to 8, est err 0.5]
##
## if
## fixed.acidity <= 7.8
## volatile.acidity > 0.155
## volatile.acidity <= 0.305
## chlorides <= 0.09
## free.sulfur.dioxide <= 82.5
## total.sulfur.dioxide > 130
## total.sulfur.dioxide <= 235
## sulphates <= 0.64
## alcohol > 9.1
## alcohol <= 10.4
## then
## outcome = 15.1 + 0.35 alcohol - 3.09 volatile.acidity - 14.7 chlorides
## + 1.16 sulphates - 0.0022 total.sulfur.dioxide
## + 0.11 fixed.acidity + 0.45 pH + 0.5 citric.acid - 14 density
## + 0.006 residual.sugar
##
## Rule 11: [31 cases, mean 5.6, range 3 to 8, est err 0.8]
##
## if
## volatile.acidity > 0.31
## citric.acid > 0.36
## free.sulfur.dioxide <= 30
## total.sulfur.dioxide <= 97
## then
## outcome = 3.2 + 0.0584 total.sulfur.dioxide + 7.77 volatile.acidity
## + 0.328 alcohol - 9 density + 0.003 residual.sugar
##
## Rule 12: [20 cases, mean 5.7, range 3 to 8, est err 0.9]
##
## if
## free.sulfur.dioxide > 82.5
## total.sulfur.dioxide <= 235
## sulphates <= 0.64
## alcohol > 9.1
## then
## outcome = -8.9 + 109.3 chlorides + 0.948 alcohol
##
## Rule 13: [331 cases, mean 5.8, range 4 to 8, est err 0.5]
##
## if
## volatile.acidity > 0.31
## free.sulfur.dioxide <= 30
## total.sulfur.dioxide > 97
## alcohol > 9.1
## then
## outcome = 89.8 + 0.0234 free.sulfur.dioxide + 0.324 alcohol
## + 0.07 residual.sugar - 90 density - 1.47 volatile.acidity
## + 0.48 pH
##
## Rule 14: [116 cases, mean 5.8, range 3 to 8, est err 0.6]
##
## if
## fixed.acidity > 7.8
## volatile.acidity > 0.155
## free.sulfur.dioxide > 30
## total.sulfur.dioxide > 130
## total.sulfur.dioxide <= 235
## sulphates <= 0.64
## alcohol > 9.1
## then
## outcome = 6 + 0.346 alcohol - 0.41 fixed.acidity - 1.69 volatile.acidity
## - 2.9 chlorides + 0.19 sulphates + 0.07 pH
##
## Rule 15: [115 cases, mean 5.8, range 4 to 7, est err 0.5]
##
## if
## volatile.acidity > 0.205
## residual.sugar <= 17.85
## density > 0.99839
## alcohol <= 9.1
## then
## outcome = -110.2 + 120 density - 3.46 volatile.acidity - 0.97 pH
## - 0.022 residual.sugar + 0.088 alcohol - 0.6 citric.acid
## - 0.01 fixed.acidity
##
## Rule 16: [986 cases, mean 5.9, range 3 to 9, est err 0.6]
##
## if
## volatile.acidity <= 0.31
## free.sulfur.dioxide <= 30
## alcohol > 9.1
## then
## outcome = 280.4 - 282 density + 0.128 residual.sugar
## + 0.0264 free.sulfur.dioxide - 3 volatile.acidity + 1.2 pH
## + 0.65 citric.acid + 0.09 fixed.acidity + 0.56 sulphates
## + 0.015 alcohol
##
## Rule 17: [49 cases, mean 6.0, range 5 to 8, est err 0.5]
##
## if
## volatile.acidity > 0.155
## residual.sugar > 8.8
## free.sulfur.dioxide > 30
## total.sulfur.dioxide <= 130
## pH <= 3.26
## alcohol > 9.1
## then
## outcome = 173.5 - 169 density + 0.055 alcohol + 0.38 sulphates
## + 0.002 residual.sugar
##
## Rule 18: [114 cases, mean 6.1, range 3 to 9, est err 0.6]
##
## if
## volatile.acidity > 0.31
## citric.acid <= 0.36
## residual.sugar > 1.45
## total.sulfur.dioxide <= 97
## alcohol > 9.1
## then
## outcome = 302.3 - 305 density + 0.0128 total.sulfur.dioxide
## + 0.096 residual.sugar + 1.94 citric.acid + 1.05 pH
## + 0.17 fixed.acidity - 6.7 chlorides
## + 0.0022 free.sulfur.dioxide - 0.21 volatile.acidity
## + 0.013 alcohol + 0.09 sulphates
##
## Rule 19: [145 cases, mean 6.1, range 5 to 8, est err 0.6]
##
## if
## volatile.acidity > 0.155
## free.sulfur.dioxide > 30
## total.sulfur.dioxide <= 195
## sulphates > 0.64
## then
## outcome = 206 - 209 density + 0.069 residual.sugar + 0.38 fixed.acidity
## + 2.79 sulphates + 0.0155 free.sulfur.dioxide
## - 0.0051 total.sulfur.dioxide - 1.71 citric.acid + 1.04 pH
##
## Rule 20: [555 cases, mean 6.1, range 3 to 9, est err 0.6]
##
## if
## total.sulfur.dioxide > 130
## total.sulfur.dioxide <= 235
## sulphates <= 0.64
## alcohol > 10.4
## then
## outcome = 108 + 0.276 alcohol - 109 density + 0.05 residual.sugar
## + 0.77 pH - 1.02 volatile.acidity - 4.2 chlorides
## + 0.78 sulphates + 0.08 fixed.acidity
## + 0.0016 free.sulfur.dioxide - 0.0003 total.sulfur.dioxide
##
## Rule 21: [73 cases, mean 6.2, range 4 to 8, est err 0.4]
##
## if
## volatile.acidity > 0.155
## citric.acid <= 0.28
## residual.sugar <= 8.8
## free.sulfur.dioxide > 30
## total.sulfur.dioxide <= 130
## pH <= 3.26
## sulphates <= 0.64
## alcohol > 9.1
## then
## outcome = 4.2 + 0.147 residual.sugar + 0.47 alcohol + 3.75 sulphates
## - 2.5 volatile.acidity - 5 density
##
## Rule 22: [244 cases, mean 6.3, range 4 to 8, est err 0.6]
##
## if
## citric.acid > 0.28
## residual.sugar <= 8.8
## free.sulfur.dioxide > 30
## total.sulfur.dioxide <= 130
## pH <= 3.26
## then
## outcome = 40.1 + 0.278 alcohol + 1.3 sulphates - 39 density
## + 0.017 residual.sugar + 0.001 total.sulfur.dioxide + 0.17 pH
## + 0.03 fixed.acidity
##
## Rule 23: [106 cases, mean 6.3, range 4 to 8, est err 0.6]
##
## if
## volatile.acidity <= 0.155
## free.sulfur.dioxide > 30
## then
## outcome = 139.1 - 138 density + 0.058 residual.sugar + 0.71 pH
## + 0.92 sulphates + 0.11 fixed.acidity - 0.73 volatile.acidity
## + 0.055 alcohol - 0.0012 total.sulfur.dioxide
## + 0.0007 free.sulfur.dioxide
##
## Rule 24: [137 cases, mean 6.5, range 4 to 9, est err 0.6]
##
## if
## volatile.acidity > 0.155
## free.sulfur.dioxide > 30
## total.sulfur.dioxide <= 130
## pH > 3.26
## sulphates <= 0.64
## alcohol > 9.1
## then
## outcome = 114.2 + 0.0142 total.sulfur.dioxide - 107 density
## - 11.8 chlorides - 1.57 pH + 0.124 alcohol + 1.21 sulphates
## + 1.16 volatile.acidity + 0.021 residual.sugar
## + 0.04 fixed.acidity
##
## Rule 25: [92 cases, mean 6.5, range 4 to 8, est err 0.6]
##
## if
## volatile.acidity <= 0.205
## alcohol <= 9.1
## then
## outcome = -200.7 + 210 density + 5.88 volatile.acidity + 23.9 chlorides
## - 2.83 citric.acid - 1.17 pH
##
##
## Evaluation on training data (3750 cases):
##
## Average |error| 0.5
## Relative |error| 0.67
## Correlation coefficient 0.66
##
##
## Attribute usage:
## Conds Model
##
## 84% 93% alcohol
## 80% 89% volatile.acidity
## 70% 61% free.sulfur.dioxide
## 63% 50% total.sulfur.dioxide
## 44% 70% sulphates
## 26% 44% chlorides
## 22% 76% fixed.acidity
## 16% 87% residual.sugar
## 11% 86% pH
## 11% 45% citric.acid
## 8% 97% density
##
##
## Time: 0.2 secs
Generate predictions for the model
p.cubist <- predict(m.cubist, wine_test)
Summary statistics about the predictions
summary(p.cubist)
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 3.677 5.416 5.906 5.848 6.238 7.393
Correlation between the predicted and true values
cor(p.cubist, wine_test$quality)
## [1] 0.6201015
Mean absolute error of predicted and true values
# (uses a custom function defined above)
MAE(wine_test$quality, p.cubist)
## [1] 0.5339725