First things first - regression modelling isn’t just one algorithm, but a family within generalised linear models (GLM). We’ll discuss only the most basic ones here.
We essentially have a dependent variable \(y\) which we wish to model as a function of one or more independent variables, such that we can make predictions for \(y\).
Such models are used for things like predicting \(y\) for new observations, and for quantifying how strongly each independent variable is associated with \(y\).
The most basic form, simple linear regression, uses a single predictor: \(y = \alpha + \beta x\). Multiple regression extends this to several independent variables, e.g. \(y = \alpha + \beta_1 x_1 + \beta_2 x_2 + \beta_3 x_3\).
We typically use ordinary least squares (OLS) estimation to estimate \(\alpha\) and \(\beta\). OLS minimises the sum of the squared residuals (a residual being the distance between an actual data point and the value predicted by the regression line).
For simple linear regression, after a bit of arithmetic we get:
\(\alpha = \bar{y} - \beta \bar{x}\) and \(\beta = \frac{\sum_i (x_i - \bar{x})(y_i - \bar{y})}{\sum_i (x_i - \bar{x})^2} = \frac{\mathrm{Cov}(x, y)}{\mathrm{Var}(x)}\), where \(\bar{x}\) and \(\bar{y}\) are the sample means.
And for multiple linear regression model above, we could just re-write that formula mentioned earlier as
\(y = \beta_0 + \beta_1 x_1 + \beta_2 x_2 + ... + \beta_n x_n + \epsilon\)
Or in condensed, matrix form:
\(Y = X \beta + \epsilon\)
where \(Y\) is a column vector with one element per observation, \(X\) is the design matrix with one row per observation (its first column, \(X_0\), is all ones to carry the intercept \(\beta_0\)), and \(\epsilon\) is a column vector of errors, again one per observation. OLS can be applied here too, giving:
\(\hat{\beta} = (X^{T}X)^{-1} X^{T} Y\)
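As a quick aside (not part of the original notes), we can sanity-check this formula on a small synthetic dataset and compare it against lm() - we'll wrap the same idea into a reusable function further down. The variable names and coefficients below are made up purely for illustration:
# synthetic data: 50 observations, two predictors, known coefficients plus a little noise
set.seed(42)
x1 = runif(50); x2 = runif(50)
y = 2 + 3 * x1 - 1.5 * x2 + rnorm(50, sd = 0.1)
X = cbind(1, x1, x2)               # design matrix with a leading column of ones
solve(t(X) %*% X) %*% t(X) %*% y   # (X^T X)^-1 X^T Y
coef(lm(y ~ x1 + x2))              # should agree (to rounding) with the line above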
All of these are readily solved in R, so let's work through a couple of quick linear regression examples using data from the Challenger disaster.
Data is available here. This is the information that was available to the Challenger mission the night before the launch. The shuttle had 6 O-rings, none of which could be allowed to fail. O-rings are sensitive to temperature, becoming brittle in the cold. Engineers were concerned that temperatures on the morning of the launch would be around 30 degrees F, but there was a lot of political pressure to go ahead.
# Alright, let's do a simple linear regression example first:
launch = read.csv("D://dev//R//mlwr/chap6-regression/challenger.csv")
str(launch)
## 'data.frame': 23 obs. of 5 variables:
## $ o_ring_ct : int 6 6 6 6 6 6 6 6 6 6 ...
## $ distress_ct: int 0 1 0 0 0 0 0 0 1 1 ...
## $ temperature: int 66 70 69 68 67 72 73 70 57 63 ...
## $ pressure : int 50 50 50 50 50 50 100 100 200 200 ...
## $ launch_id : int 1 2 3 4 5 6 7 8 9 10 ...
# o_ring_ct is the total number of rings present for each launch
# distress_ct is the number of distressed O-rings
# The str() results show that it's all integer data, let's see more:
summary(launch)
## o_ring_ct distress_ct temperature pressure
## Min. :6 Min. :0.0000 Min. :53.00 Min. : 50.0
## 1st Qu.:6 1st Qu.:0.0000 1st Qu.:67.00 1st Qu.: 75.0
## Median :6 Median :0.0000 Median :70.00 Median :200.0
## Mean :6 Mean :0.3043 Mean :69.57 Mean :152.2
## 3rd Qu.:6 3rd Qu.:0.5000 3rd Qu.:75.00 3rd Qu.:200.0
## Max. :6 Max. :2.0000 Max. :81.00 Max. :200.0
## launch_id
## Min. : 1.0
## 1st Qu.: 6.5
## Median :12.0
## Mean :12.0
## 3rd Qu.:17.5
## Max. :23.0
# we'll try to model distress_ct as a function of temperature
# this is the controversial bit: the lowest temperature previously observed was 53F, well above the forecast 30F. Let's plot this:
library(ggplot2)
ggplot(data=launch,
aes(x=temperature, y=distress_ct)) +
geom_point() +
geom_smooth(method='lm', formula='y ~ x') # this draws a regression line
# That doesn't look promising...
# What's correlation between distress_ct and temperature?
cor(launch$distress_ct, launch$temperature)
## [1] -0.725671
# Ok, so it's pretty strong, and negative (i.e. drop temperature -> increase distress_ct)
# Linear model: y = alpha + beta * x
# from the notes above, recall that:
# alpha = y-mean - beta*x-mean
# beta = cov(x, y)/var(x)
beta = cov(launch$temperature, launch$distress_ct) / var(launch$temperature)
alpha = mean(launch$distress_ct) - beta * mean(launch$temperature)
c(alpha, beta)
## [1] 4.30158730 -0.05746032
# R can do a linear-regression for us using lm() (for linear-model).
# Let's compare our alpha/beta with R's lm() method:
slm = lm(data=launch, formula = distress_ct ~ temperature)
slm
##
## Call:
## lm(formula = distress_ct ~ temperature, data = launch)
##
## Coefficients:
## (Intercept) temperature
## 4.30159 -0.05746
# As if by magic!
# So what's the predicted value then for our expected 30F temperature?
predicted_ring_fails = predict(slm, data.frame(temperature=30))
predicted_ring_fails
## 1
## 2.577778
# so we'd expect 2.6 rings to fail. Not good.
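As a quick sanity check (my addition), plugging 30 into \(y = \alpha + \beta x\) with the hand-computed coefficients gives the same figure:
alpha + beta * 30   # should match the predict() result of ~2.578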
What if we wanted to discover other correlations within our dataset? Welcome to GGally::ggpairs - a great tool to draw a scatterplot matrix:
#install.packages("GGally")
library(GGally)
ggpairs(launch)
## Warning in cor(x, y, method = method, use = use): the standard deviation is
## zero
## Warning in cor(x, y, method = method, use = use): the standard deviation is
## zero
## Warning in cor(x, y, method = method, use = use): the standard deviation is
## zero
## Warning in cor(x, y, method = method, use = use): the standard deviation is
## zero
Here we can see that pressure also has a moderate correlation with distress_ct - not surprising. (The warnings come from o_ring_ct, which is constant at 6, so its standard deviation is zero.)
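If you prefer numbers to pictures, a plain correlation matrix over the non-constant columns tells the same story (a small aside, not in the original):
cor(launch[c("distress_ct", "temperature", "pressure")])   # o_ring_ct is constant, so it's left out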
Let’s finish with an example of multiple linear regression. We’ll do this two ways, one using our own very basic hand-crafted function, and then using lm(). For the former, recall that \(\hat{\beta} = (X^{T}X)^{-1} X^{T} Y\)
# This function returns the beta coefficients for a multiple linear regression
# @param y: vector of dependent variable values
# @param x: frame of independent variable values
solve_multiple_regression = function(y, x) {
  x = as.matrix(x)            # coerce x into matrix form (from, say, a data.frame or vector)
  x = cbind(Intercept = 1, x) # prepend an all-ones Intercept column to carry beta_0
  # A few new functions here:
  #   t(x)     returns the transpose of x
  #   %*%      performs matrix multiplication
  #   solve(x) returns the inverse of x
  # Altogether this computes (X^T X)^-1 X^T Y:
  solve(t(x) %*% x) %*% t(x) %*% y
}
# Let's give it a whirl - we'll test it by repeating the simple linear regression from above:
solve_multiple_regression(y = launch$distress_ct,
x = launch[c("temperature")])
## [,1]
## Intercept 4.30158730
## temperature -0.05746032
# Identical to before!
# Let's try a multiple regression now:
solve_multiple_regression(y = launch$distress_ct,
x = launch[c("temperature", "pressure")])
## [,1]
## Intercept 4.045202762
## temperature -0.058247321
## pressure 0.002044586
The above function gives us some useful info, but R's built-in tools are far more informative:
# Finally, let's use R's lm():
mlmodel = lm(data=launch,
formula = distress_ct ~ temperature + pressure)
mlmodel
##
## Call:
## lm(formula = distress_ct ~ temperature + pressure, data = launch)
##
## Coefficients:
## (Intercept) temperature pressure
## 4.045203 -0.058247 0.002045
# Good, that gives us the same coefficients as our own function. Let's see what else:
summary(mlmodel)
##
## Call:
## lm(formula = distress_ct ~ temperature + pressure, data = launch)
##
## Residuals:
## Min 1Q Median 3Q Max
## -0.55155 -0.17948 -0.07578 0.11829 0.92988
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 4.045203 0.807267 5.011 6.70e-05 ***
## temperature -0.058247 0.011363 -5.126 5.15e-05 ***
## pressure 0.002045 0.001175 1.739 0.0973 .
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 0.3758 on 20 degrees of freedom
## Multiple R-squared: 0.5888, Adjusted R-squared: 0.5477
## F-statistic: 14.32 on 2 and 20 DF, p-value: 0.0001382
As we can see, R's lm() gives us all kinds of extra information too, including significance levels and \(R^2\).
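As a small aside (not from the original notes), those figures can also be pulled out of the summary object programmatically:
ml_sum = summary(mlmodel)   # ml_sum is just a name chosen here
ml_sum$r.squared            # multiple R-squared
ml_sum$adj.r.squared        # adjusted R-squared
coef(ml_sum)                # coefficient table with std. errors, t values and p-values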
Let's now move on to a bigger example: modelling medical insurance charges as a function of patient characteristics. The data is available here.
ins = read.csv("d://dev//R//mlwr/chap6-regression/insurance.csv")
str(ins)
## 'data.frame': 1338 obs. of 7 variables:
## $ age : int 19 18 28 33 32 31 46 37 37 60 ...
## $ sex : Factor w/ 2 levels "female","male": 1 2 2 2 2 1 1 1 2 1 ...
## $ bmi : num 27.9 33.8 33 22.7 28.9 ...
## $ children: int 0 1 3 0 0 0 1 3 2 0 ...
## $ smoker : Factor w/ 2 levels "no","yes": 2 1 1 1 1 1 1 1 1 1 ...
## $ region : Factor w/ 4 levels "northeast","northwest",..: 4 3 3 2 2 3 3 2 1 2 ...
## $ charges : num 16885 1726 4449 21984 3867 ...
# So a few of these are factors; we'll see how R's regression tools deal with them
summary(ins)
## age sex bmi children smoker
## Min. :18.00 female:662 Min. :15.96 Min. :0.000 no :1064
## 1st Qu.:27.00 male :676 1st Qu.:26.30 1st Qu.:0.000 yes: 274
## Median :39.00 Median :30.40 Median :1.000
## Mean :39.21 Mean :30.66 Mean :1.095
## 3rd Qu.:51.00 3rd Qu.:34.69 3rd Qu.:2.000
## Max. :64.00 Max. :53.13 Max. :5.000
## region charges
## northeast:324 Min. : 1122
## northwest:325 1st Qu.: 4740
## southeast:364 Median : 9382
## southwest:325 Mean :13270
## 3rd Qu.:16640
## Max. :63770
# Roughly even split of boys/girls and regional variation
# Mostly non-smokers
# Wild BMI extremes
# good age range
# Overall looks like a pretty normal sample of the population as a whole
# we want to be able to tell how charges will vary as a function of other features.
# Let's see how charges are distributed
ggplot(data=ins,
aes(x=charges)) +
geom_histogram(binwidth=1000)
# pretty heavily skewed right. Let's look at logs
ggplot(data=ins,
aes(x=charges)) +
geom_histogram() +
scale_x_log10()
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
# we can see a bit more structure in there
# Let's run this through lm()
ins_mod = lm(data=ins, formula = charges ~ .)
summary(ins_mod)
##
## Call:
## lm(formula = charges ~ ., data = ins)
##
## Residuals:
## Min 1Q Median 3Q Max
## -11304.9 -2848.1 -982.1 1393.9 29992.8
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) -11938.5 987.8 -12.086 < 2e-16 ***
## age 256.9 11.9 21.587 < 2e-16 ***
## sexmale -131.3 332.9 -0.394 0.693348
## bmi 339.2 28.6 11.860 < 2e-16 ***
## children 475.5 137.8 3.451 0.000577 ***
## smokeryes 23848.5 413.1 57.723 < 2e-16 ***
## regionnorthwest -353.0 476.3 -0.741 0.458769
## regionsoutheast -1035.0 478.7 -2.162 0.030782 *
## regionsouthwest -960.0 477.9 -2.009 0.044765 *
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 6062 on 1329 degrees of freedom
## Multiple R-squared: 0.7509, Adjusted R-squared: 0.7494
## F-statistic: 500.8 on 8 and 1329 DF, p-value: < 2.2e-16
Let's stop for a moment and examine the output above. R has automatically dummy-coded the factor variables (hence terms like sexmale, smokeryes and regionsoutheast), smoking status dwarfs every other effect on charges, and the model explains roughly 75% of the variance (\(R^2 \approx 0.75\)).
One drawback of linear regression is that it requires some domain knowledge to get the most out of it, unlike the classification methods we looked at earlier, which largely did it all on their own.
We can improve this further.
It's unlikely that charges are a linear function of age - healthcare costs tend to rise disproportionately in later life. We can probably improve the model by adding an age-squared term:
ins$age2 = ins$age^2
ins_mod_age2 = lm(data=ins, formula = charges ~ .)
summary(ins_mod_age2)
##
## Call:
## lm(formula = charges ~ ., data = ins)
##
## Residuals:
## Min 1Q Median 3Q Max
## -11665.1 -2855.8 -944.1 1295.9 30826.0
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) -6596.665 1689.444 -3.905 9.91e-05 ***
## age -54.575 80.991 -0.674 0.500532
## sexmale -138.428 331.197 -0.418 0.676043
## bmi 335.211 28.467 11.775 < 2e-16 ***
## children 642.024 143.617 4.470 8.47e-06 ***
## smokeryes 23859.745 410.988 58.055 < 2e-16 ***
## regionnorthwest -367.812 473.783 -0.776 0.437692
## regionsoutheast -1031.503 476.172 -2.166 0.030470 *
## regionsouthwest -957.546 475.417 -2.014 0.044198 *
## age2 3.927 1.010 3.887 0.000107 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 6030 on 1328 degrees of freedom
## Multiple R-squared: 0.7537, Adjusted R-squared: 0.752
## F-statistic: 451.6 on 9 and 1328 DF, p-value: < 2.2e-16
So a slight improvement in \(R^2\); the residual error hasn't moved much either.
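One way to check whether the age2 term earns its keep (a standard trick, not part of the original walkthrough) is a nested-model F-test, alongside the adjusted \(R^2\) values:
anova(ins_mod, ins_mod_age2)   # does adding age2 significantly reduce the residual sum of squares?
c(summary(ins_mod)$adj.r.squared, summary(ins_mod_age2)$adj.r.squared)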
We can also say with some confidence that BMI values up to 25 are fine, values above that are progressively worse, and values over 30 (obesity) are the real concern. Let's capture that with an indicator variable:
ins$bmi_severity = ifelse(ins$bmi > 30, 1, 0)
ins_mod_age2_bmi = lm(data=ins, formula = charges ~ .)
summary(ins_mod_age2_bmi)
##
## Call:
## lm(formula = charges ~ ., data = ins)
##
## Residuals:
## Min 1Q Median 3Q Max
## -12525.0 -3381.3 160.9 1311.3 29258.3
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) -2852.423 1825.191 -1.563 0.118336
## age -26.238 80.416 -0.326 0.744261
## sexmale -159.152 328.097 -0.485 0.627703
## bmi 148.123 46.045 3.217 0.001327 **
## children 627.304 142.291 4.409 1.12e-05 ***
## smokeryes 23864.986 407.112 58.620 < 2e-16 ***
## regionnorthwest -401.016 469.358 -0.854 0.393041
## regionsoutheast -884.246 472.549 -1.871 0.061534 .
## regionsouthwest -929.552 470.963 -1.974 0.048620 *
## age2 3.573 1.003 3.562 0.000382 ***
## bmi_severity 2813.900 547.483 5.140 3.16e-07 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 5973 on 1327 degrees of freedom
## Multiple R-squared: 0.7585, Adjusted R-squared: 0.7567
## F-statistic: 416.8 on 10 and 1327 DF, p-value: < 2.2e-16
Finally, it's likely that obesity combined with smoking is worse than the sum of its parts. This is called an interaction. Let's add that in too:
ins_mod_fatsmokers = lm(data=ins, formula =
charges ~ age2 + age + sex + bmi + children + region +
smoker*bmi_severity) # use * to represent interaction
summary(ins_mod_fatsmokers)
##
## Call:
## lm(formula = charges ~ age2 + age + sex + bmi + children + region +
## smoker * bmi_severity, data = ins)
##
## Residuals:
## Min 1Q Median 3Q Max
## -4260.3 -1644.6 -1272.7 -784.7 24192.7
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 69.2494 1353.2349 0.051 0.959195
## age2 3.5978 0.7422 4.847 1.40e-06 ***
## age -21.6786 59.4956 -0.364 0.715638
## sexmale -475.6760 242.9293 -1.958 0.050430 .
## bmi 114.2920 34.0816 3.353 0.000821 ***
## children 661.5105 105.2784 6.283 4.48e-10 ***
## regionnorthwest -275.6659 347.2730 -0.794 0.427453
## regionsoutheast -826.1187 349.6181 -2.363 0.018275 *
## regionsouthwest -1164.8152 348.5123 -3.342 0.000854 ***
## smokeryes 13421.6370 435.9158 30.790 < 2e-16 ***
## bmi_severity -938.5116 420.5807 -2.231 0.025817 *
## smokeryes:bmi_severity 19912.6072 600.8493 33.141 < 2e-16 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 4419 on 1326 degrees of freedom
## Multiple R-squared: 0.8679, Adjusted R-squared: 0.8668
## F-statistic: 792.1 on 11 and 1326 DF, p-value: < 2.2e-16
And there you have it: an \(R^2\) of 0.87, with the residual range tightened considerably. We've used our domain understanding to improve the model.
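If we wanted to use this final model for prediction, we'd need to supply the engineered columns too. A sketch with an entirely hypothetical customer (the name and values below are made up):
# hypothetical 40-year-old male smoker from the southeast, BMI 32, two children
new_person = data.frame(age = 40, age2 = 40^2, sex = "male", bmi = 32,
                        children = 2, smoker = "yes", region = "southeast",
                        bmi_severity = 1)   # BMI > 30, so the severity flag is on
predict(ins_mod_fatsmokers, new_person)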
The decision trees we saw earlier predict categorical outcomes and can't (neatly) be applied to numeric targets. There are, however, a couple of hybrids of regression modelling and decision trees which are often overlooked, yet often preferable to plain regression: regression trees and model trees.
Whereas decision trees (often) measure information gain using entropy, which only works for categorical outcomes, trees for a numeric outcome can use criteria such as standard deviation reduction (SDR).
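To make SDR concrete, here's a rough sketch of my own, following the usual definition (the outcome's standard deviation minus the size-weighted standard deviations of the groups a split produces); the toy numbers are made up:
sdr = function(y, condition) {
  parts = split(y, condition)
  sd(y) - sum(sapply(parts, function(p) length(p) / length(y) * sd(p)))
}
quality_toy  = c(5, 5, 6, 6, 7, 7, 7, 8)   # pretend quality scores
high_alcohol = c(F, F, F, T, T, T, T, T)   # pretend split condition
sdr(quality_toy, high_alcohol)             # larger values indicate a more useful split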
Benefits of these tree-based approaches include no need for pre-processing or normalisation, and they handle messy data well (up to a point, of course).
Let’s try these out in an example.
We've some white wine data made up of 11 physicochemical properties and one response variable (quality) for almost 5000 wines. These were used in a research project measuring the efficacy of various machine learning techniques (multiple regression, neural nets, SVMs). The study found that SVMs offered a significant improvement over the others - unfortunately the resulting model is difficult to interpret. We'll use regression trees and model trees, which produce much easier-to-understand models, and see how well they compare.
wine = read.csv("D://dev/R/mlwr/chap6-regression/whitewines.csv")
str(wine)
## 'data.frame': 4898 obs. of 12 variables:
## $ fixed.acidity : num 6.7 5.7 5.9 5.3 6.4 7 7.9 6.6 7 6.5 ...
## $ volatile.acidity : num 0.62 0.22 0.19 0.47 0.29 0.14 0.12 0.38 0.16 0.37 ...
## $ citric.acid : num 0.24 0.2 0.26 0.1 0.21 0.41 0.49 0.28 0.3 0.33 ...
## $ residual.sugar : num 1.1 16 7.4 1.3 9.65 0.9 5.2 2.8 2.6 3.9 ...
## $ chlorides : num 0.039 0.044 0.034 0.036 0.041 0.037 0.049 0.043 0.043 0.027 ...
## $ free.sulfur.dioxide : num 6 41 33 11 36 22 33 17 34 40 ...
## $ total.sulfur.dioxide: num 62 113 123 74 119 95 152 67 90 130 ...
## $ density : num 0.993 0.999 0.995 0.991 0.993 ...
## $ pH : num 3.41 3.22 3.49 3.48 2.99 3.25 3.18 3.21 2.88 3.28 ...
## $ sulphates : num 0.32 0.46 0.42 0.54 0.34 0.43 0.47 0.47 0.47 0.39 ...
## $ alcohol : num 10.4 8.9 10.1 11.2 10.9 ...
## $ quality : int 5 6 6 4 6 6 6 6 6 7 ...
# What does our distribution look like?
hist(wine$quality)
# let's also dot the i's and cross the t's for weirdness:
summary(wine)
## fixed.acidity volatile.acidity citric.acid residual.sugar
## Min. : 3.800 Min. :0.0800 Min. :0.0000 Min. : 0.600
## 1st Qu.: 6.300 1st Qu.:0.2100 1st Qu.:0.2700 1st Qu.: 1.700
## Median : 6.800 Median :0.2600 Median :0.3200 Median : 5.200
## Mean : 6.855 Mean :0.2782 Mean :0.3342 Mean : 6.391
## 3rd Qu.: 7.300 3rd Qu.:0.3200 3rd Qu.:0.3900 3rd Qu.: 9.900
## Max. :14.200 Max. :1.1000 Max. :1.6600 Max. :65.800
## chlorides free.sulfur.dioxide total.sulfur.dioxide
## Min. :0.00900 Min. : 2.00 Min. : 9.0
## 1st Qu.:0.03600 1st Qu.: 23.00 1st Qu.:108.0
## Median :0.04300 Median : 34.00 Median :134.0
## Mean :0.04577 Mean : 35.31 Mean :138.4
## 3rd Qu.:0.05000 3rd Qu.: 46.00 3rd Qu.:167.0
## Max. :0.34600 Max. :289.00 Max. :440.0
## density pH sulphates alcohol
## Min. :0.9871 Min. :2.720 Min. :0.2200 Min. : 8.00
## 1st Qu.:0.9917 1st Qu.:3.090 1st Qu.:0.4100 1st Qu.: 9.50
## Median :0.9937 Median :3.180 Median :0.4700 Median :10.40
## Mean :0.9940 Mean :3.188 Mean :0.4898 Mean :10.51
## 3rd Qu.:0.9961 3rd Qu.:3.280 3rd Qu.:0.5500 3rd Qu.:11.40
## Max. :1.0390 Max. :3.820 Max. :1.0800 Max. :14.20
## quality
## Min. :3.000
## 1st Qu.:5.000
## Median :6.000
## Mean :5.878
## 3rd Qu.:6.000
## Max. :9.000
We can see that the quality distribution is roughly normal, as we'd expect from a random sample of wines. There are no NAs, and the ranges look (mostly) sound, except perhaps residual.sugar and the free/total sulfur dioxide values, whose maxima sit far above their third quartiles.
Let’s create our training and test datasets in a 3:1 ratio:
wine.train = wine[1:3750, ]
wine.test = wine[3751:4898, ]
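A sequential split like this only makes sense if the rows are already in random order; if in doubt, a random split is the usual alternative. Here's a sketch using new names (the seed and names are my own), so the results below, which use the sequential split, still apply:
set.seed(123)
rand_idx = sample(nrow(wine), round(0.75 * nrow(wine)))
wine.train.rand = wine[rand_idx, ]
wine.test.rand  = wine[-rand_idx, ]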
We'll be using the rpart (recursive partitioning) package's implementation of regression trees (apparently the most faithful to the original CART idea), which is well documented and comes with tools for visualising and evaluating rpart models.
#install.packages("rpart")
#install.packages("rpart.plot")
library(rpart)
model.rpart = rpart(data = wine.train, quality ~ .)
# let's see some info:
model.rpart
## n= 3750
##
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 3750 2945.53200 5.870933
## 2) alcohol< 10.85 2372 1418.86100 5.604975
## 4) volatile.acidity>=0.2275 1611 821.30730 5.432030
## 8) volatile.acidity>=0.3025 688 278.97670 5.255814 *
## 9) volatile.acidity< 0.3025 923 505.04230 5.563380 *
## 5) volatile.acidity< 0.2275 761 447.36400 5.971091 *
## 3) alcohol>=10.85 1378 1070.08200 6.328737
## 6) free.sulfur.dioxide< 10.5 84 95.55952 5.369048 *
## 7) free.sulfur.dioxide>=10.5 1294 892.13600 6.391036
## 14) alcohol< 11.76667 629 430.11130 6.173291
## 28) volatile.acidity>=0.465 11 10.72727 4.545455 *
## 29) volatile.acidity< 0.465 618 389.71680 6.202265 *
## 15) alcohol>=11.76667 665 403.99400 6.596992 *
The output shows the tree. At the top we have the root node, containing all 3750 records. The first split is on the most important predictor, alcohol level. An * at the end of a line indicates a leaf node. We can get a more detailed report using summary(), though a more interesting visual is provided by rpart.plot:
#summary(model.rpart)
library(rpart.plot)
rpart.plot(model.rpart, digits=3)
How cool is that? We can specify more modifiers too, such as type, fallen.leaves, etc:
rpart.plot(model.rpart, digits=3,
fallen.leaves=T, # show all leaves at the base of the diagram
type=3, # different graphical effects
extra=101) # add extra info in leaf nodes
Very different looking - and there are plenty more options where those came from.
Ok, so let’s try this out on our test set:
# predict() returns the estimated numeric value for the outcome var (quality, in this case)
pred.rpart = predict(model.rpart, wine.test)
summary(pred.rpart)
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 4.545 5.563 5.971 5.893 6.202 6.597
Right away we can see that the min and max of our prediction seem too tight, after all:
summary(wine.test$quality)
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 3.000 5.000 6.000 5.901 6.000 9.000
The true range is [3, 9]. This suggests that the model isn’t correctly identifying extreme cases. A quick and effective way of assessing the model’s performance is by checking the correlation between predicted and actual quality values:
cor(wine.test$quality, pred.rpart)
## [1] 0.5369525
Meh, this is acceptable, but not mind-blowing. Note that this isn't a measure of how far off the predictions are from the true values. We can measure that with the Mean Absolute Error, \(MAE = \frac{1}{n}\sum_{i=1}^{n}|e_i|\), where \(n\) is the number of predictions and \(e_i\) is the error for record \(i\). Let's write a quick function:
MAE = function(actual, predicted) {
mean(abs(actual - predicted))
}
MAE(wine.test$quality, pred.rpart)
## [1] 0.5872652
So this tells us that, on average, our predictions were off by 0.59, which, given that quality scores in this data run from 3 to 9, isn't too bad. If we had to be picky though…
Let’s suppose we had a really dumb predictor that just predicted the mean from the training set. Its MAE would be:
# this is the MAE of a dumb classifier which always guesses the mean of the training set
MAE(mean(wine.train$quality), wine.test$quality)
## [1] 0.6719238
This says the dumb predictor's error would be 0.67 - so our regression tree, with an MAE of 0.59, didn't really do very much better, truth be told. As it happens, the researchers behind this dataset reported an MAE of 0.58 for their neural network model and 0.45 for the SVM. This suggests that there is (only a little) room for improvement.
Our 'improvement' this time will be to build a different kind of model altogether - a model tree. Recall that these improve on regression trees by replacing the leaf nodes with regression models, often leading to more accurate results. The current state of the art in model trees is the M5' (M5-prime) algorithm, which is available in R's RWeka package.
#install.packages("RWeka")
library(RWeka)
# similar syntax to rpart
model.m5p = M5P(data = wine.train, quality ~ .)
# let's see the tree
# model.m5p
# that command's actually VERY verbose, so here are the first few lines:
#alcohol <= 10.85 :
#| volatile.acidity <= 0.238 :
#| | fixed.acidity <= 6.85 : LM1 (406/66.024%)
#| | fixed.acidity > 6.85 :
#| | | free.sulfur.dioxide <= 24.5 : LM2 (113/87.697%)
#| | | free.sulfur.dioxide > 24.5 :
#| | | | alcohol <= 9.15 :
#| | | | | citric.acid <= 0.305 :
#
# bla bla bla
# LM num: 1
# quality =
# 0.266 * fixed.acidity
# - 2.3082 * volatile.acidity
# - 0.012 * citric.acid
# + 0.0421 * residual.sugar
# + 0.1126 * chlorides
# + 0 * free.sulfur.dioxide
# - 0.0015 * total.sulfur.dioxide
# - 109.8813 * density
# + 0.035 * pH
# + 1.4122 * sulphates
# - 0.0046 * alcohol
# + 113.1021
# LM num: 2
# quality =
# -0.2557 * fixed.acidity
# - 0.8082 * volatile.acidity
# - 0.1062 * citric.acid
# + 0.0738 * residual.sugar
# + 0.0973 * chlorides
# + 0.0006 * free.sulfur.dioxide
# + 0.0003 * total.sulfur.dioxide
# - 210.1018 * density
# + 0.0323 * pH
# - 0.9604 * sulphates
# - 0.0231 * alcohol
# + 216.8857
We can see that, once again, alcohol content is the key split, even here. Note though that the leaf nodes are no longer single predicted values - they're linear models ('LM num: 1', 'LM num: 2' and so on). And summary() on this model really is a summary this time, unlike printing the model itself:
summary(model.m5p)
##
## === Summary ===
##
## Correlation coefficient 0.6666
## Mean absolute error 0.5151
## Root mean squared error 0.6614
## Relative absolute error 76.4921 %
## Root relative squared error 74.6259 %
## Total Number of Instances 3750
This summary only reflects performance on the training data - we still need to evaluate on the test data:
pred.m5p = predict(model.m5p, wine.test)
# let's see if this has a wider range reflecting the [3,9] range we're expecting:
summary(pred.m5p)
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 4.389 5.430 5.863 5.874 6.305 7.437
So this looks a little bit better - how does it correlate?
cor(pred.m5p, wine.test$quality)
## [1] 0.6272973
Ok, we’ve got better correlation, nice - and MAE…
MAE(pred.m5p, wine.test$quality)
## [1] 0.5463023
Ok, so an error of 0.546. This now beats the neural net (0.58) and is closing in on the SVM (0.45), whilst using a simpler learning method that produces an easier-to-understand model. Yay for simplicity!
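For the record, here's a quick side-by-side recap of the two tree models on the held-out wines (just re-using the objects computed above):
data.frame(model = c("rpart regression tree", "M5P model tree"),
           correlation = c(cor(wine.test$quality, pred.rpart), cor(wine.test$quality, pred.m5p)),
           MAE = c(MAE(wine.test$quality, pred.rpart), MAE(wine.test$quality, pred.m5p)))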