The idea

First things first - regression modelling isn’t just one algorithm, but a family within generalised linear models (GLM). We’ll discuss only the most basic ones here.

We essentially have a dependent variable \(y\) which we wish to model as a function of one or more independent variables, such that we can make predictions for \(y\).

Such models are used for:

The details

The most basic form, simple linear regression, uses a formula like \(y = \alpha + \beta x\), with a single independent variable. Multiple regression uses several independent variables, e.g. \(y = \alpha + \beta_1 x_1 + \beta_2 x_2 + \beta_3 x_3\).

We typically use ordinary least squares (OLS) estimation to estimate \(\alpha\) and \(\beta\). OLS chooses the values that minimise the sum of the squared residuals, where a residual is the vertical distance between an actual data point and the value predicted by the regression line.
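The sketch below makes "minimising the sum of squared residuals" concrete. It uses a made-up five-point dataset and evaluates the sum of squared residuals (SSE) over a grid of candidate \(\alpha\), \(\beta\) pairs. The smallest value sits at the estimates `lm()` returns. The grid search is purely illustrative, since OLS itself is solved in closed form.

```r
# Toy data, invented purely for illustration
x = c(1, 2, 3, 4, 5)
y = c(2, 4, 5, 4, 5)

# Sum of squared residuals for a candidate line y = a + b * x
sse = function(a, b) sum((y - (a + b * x))^2)

# Evaluate SSE over a grid of candidate intercepts and slopes
alphas = seq(1.2, 3.2, by = 0.1)
betas  = seq(0.1, 1.1, by = 0.05)
grid   = outer(alphas, betas, Vectorize(sse))

# Locate the grid cell with the smallest SSE
best = which(grid == min(grid), arr.ind = TRUE)
c(alpha = alphas[best[1, "row"]], beta = betas[best[1, "col"]])   # 2.2 and 0.6

# lm() finds the same minimum analytically
coef(lm(y ~ x))
```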

For simple linear regression, a little algebra gives:

\(\alpha = \bar{y} - \beta \bar{x}\) and \(\beta = \frac{\sum (x_i - \bar{x})(y_i - \bar{y})}{\sum (x_i - \bar{x})^2} = \frac{Cov(x, y)}{Var(x)}\), where \(\bar{x}\) and \(\bar{y}\) are the means of \(x\) and \(y\).
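Here is a quick sanity check of these formulas on a made-up five-point dataset (the values of `x` and `y` are purely illustrative). Both forms of \(\beta\) agree with each other and with `lm()`.

```r
# Toy data, invented purely for illustration
x = c(1, 2, 3, 4, 5)
y = c(2, 4, 5, 4, 5)

# beta from the sum-of-products formula
beta_sums = sum((x - mean(x)) * (y - mean(y))) / sum((x - mean(x))^2)
# beta from Cov(x, y) / Var(x); the n - 1 denominators cancel
beta_cov = cov(x, y) / var(x)
# alpha from the means
alpha = mean(y) - beta_cov * mean(x)

c(alpha = alpha, beta = beta_cov)   # 2.2 and 0.6
coef(lm(y ~ x))                     # same values from lm()
```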

For multiple linear regression, we can rewrite the earlier formula more generally, with \(\beta_0\) in place of \(\alpha\) and an error term \(\epsilon\), as

\(y = \beta_0 + \beta_1 x_1 + \beta_2 x_2 + ... + \beta_n x_n + \epsilon\)

Or in condensed, matrix form:

\(Y = X \beta + \epsilon\)

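where \(Y\) is a column vector with one element per example (say \(m\) examples), \(X\) is an \(m \times (n+1)\) matrix whose first column \(X_0\) is all ones (so that it multiplies \(\beta_0\)), \(\beta\) is the column vector of coefficients \(\beta_0, \dots, \beta_n\), and \(\epsilon\) is a column vector of \(m\) error terms. OLS can be applied here too, giving: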

\(\hat{\beta} = (X^{T}X)^{-1} X^{T} Y\)
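This formula comes from setting the derivative of the sum of squared residuals to zero. That gives the *normal equations* \(X^{T}X\hat{\beta} = X^{T}Y\).

The minimal sketch below uses R's built-in `mtcars` data as a stand-in (it is unrelated to the examples that follow). It builds \(X\) with a leading column of ones and solves the normal equations directly with `solve(A, b)`, which avoids forming the inverse explicitly. The result should match `lm()`.

```r
# Design matrix: a column of ones (for beta_0) plus two predictors
X = cbind(Intercept = 1, as.matrix(mtcars[, c("wt", "hp")]))
Y = mtcars$mpg

# Solve the normal equations (X^T X) beta = X^T Y
# crossprod(X) is t(X) %*% X; crossprod(X, Y) is t(X) %*% Y
beta_hat = solve(crossprod(X), crossprod(X, Y))
beta_hat

# Should agree with lm()
coef(lm(mpg ~ wt + hp, data = mtcars))
```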

All of these are readily solved with R, so let’s go through a couple of quick linear regression examples using data from the Challenger disaster.

Example - Challenger data

Data is available here. This is the information that was available to the Challenger mission the night before the launch. The shuttle had 6 O-rings, none of which were allowed to fail. These O-rings are sensitive to temperature, becoming brittle in the cold. There were concerns that the temperature on the morning of the launch would be around 30 degrees F, but there was a lot of political pressure to go ahead with the launch.

# Alright, let's do a simple linear regression example first:

launch = read.csv("D://dev//R//mlwr/chap6-regression/challenger.csv")
str(launch)
## 'data.frame':    23 obs. of  5 variables:
##  $ o_ring_ct  : int  6 6 6 6 6 6 6 6 6 6 ...
##  $ distress_ct: int  0 1 0 0 0 0 0 0 1 1 ...
##  $ temperature: int  66 70 69 68 67 72 73 70 57 63 ...
##  $ pressure   : int  50 50 50 50 50 50 100 100 200 200 ...
##  $ launch_id  : int  1 2 3 4 5 6 7 8 9 10 ...
# o_ring_ct is the total number of rings present for each launch
# distress_ct is the number of distressed O-rings

# The str() results show that it's all integer data, let's see more:
summary(launch)
##    o_ring_ct  distress_ct      temperature       pressure    
##  Min.   :6   Min.   :0.0000   Min.   :53.00   Min.   : 50.0  
##  1st Qu.:6   1st Qu.:0.0000   1st Qu.:67.00   1st Qu.: 75.0  
##  Median :6   Median :0.0000   Median :70.00   Median :200.0  
##  Mean   :6   Mean   :0.3043   Mean   :69.57   Mean   :152.2  
##  3rd Qu.:6   3rd Qu.:0.5000   3rd Qu.:75.00   3rd Qu.:200.0  
##  Max.   :6   Max.   :2.0000   Max.   :81.00   Max.   :200.0  
##    launch_id   
##  Min.   : 1.0  
##  1st Qu.: 6.5  
##  Median :12.0  
##  Mean   :12.0  
##  3rd Qu.:17.5  
##  Max.   :23.0
# we'll try to model distress_ct as a function of temperature

# Note: the lowest launch temperature previously observed was 53F, well above the ~30F expected. Let's plot this:
library(ggplot2)
ggplot(data=launch,
        aes(x=temperature, y=distress_ct)) +
        geom_point() +
        geom_smooth(method='lm', formula=y ~ x) # this draws a regression line

# That doesn't look promising...
# What's correlation between distress_ct and temperature?
cor(launch$distress_ct, launch$temperature)
## [1] -0.725671
# Ok, so it's pretty strong, and negative (i.e. drop temperature -> increase distress_ct)

# Linear model: y = alpha + beta * x
# from the notes above, recall that:
# alpha = y-mean - beta*x-mean
# beta = cov(x, y)/var(x)
beta = cov(launch$temperature, launch$distress_ct) / var(launch$temperature)
alpha = mean(launch$distress_ct) - beta * mean(launch$temperature)
c(alpha, beta)
## [1]  4.30158730 -0.05746032
# R can do a linear regression for us using lm() (short for linear model).
# Let's compare our alpha/beta with R's lm() method:
slm = lm(data=launch, formula = distress_ct ~ temperature)
slm
## 
## Call:
## lm(formula = distress_ct ~ temperature, data = launch)
## 
## Coefficients:
## (Intercept)  temperature  
##     4.30159     -0.05746
# As if by magic!
# So what's the predicted value then for our expected 30F temperature?

predicted_ring_fails = predict(slm, data.frame(temperature=30))
predicted_ring_fails
##        1 
## 2.577778
# so we'd expect 2.6 rings to fail. Not good.
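One caveat: 30°F is far below the coldest launch in the data (53°F), so this prediction extrapolates well outside the range the model was fitted on.

The minimal sketch below shows one way to make that visible. The hypothetical helper `predict_with_check()` warns when a new value falls outside the training range. It also asks `predict()` for a prediction interval (`interval = "prediction"` is a standard argument of `predict.lm`). It runs on R's built-in `cars` data (speeds 4 to 25 mph) so that it works on its own.

```r
# Hypothetical helper: predict from a single-predictor lm, warning if the new
# value lies outside the range of the predictor seen during fitting
predict_with_check = function(model, predictor, new_values) {
  observed = range(model.frame(model)[[predictor]])
  outside = new_values < observed[1] | new_values > observed[2]
  if (any(outside)) {
    warning("extrapolating beyond the observed range of ", predictor,
            " [", observed[1], ", ", observed[2], "]")
  }
  newdata = setNames(data.frame(new_values), predictor)
  # 95% prediction interval for a new observation (the default level)
  predict(model, newdata, interval = "prediction")
}

fit = lm(dist ~ speed, data = cars)
pred = predict_with_check(fit, "speed", 30)   # warns: 30 > max speed of 25
pred
```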

What if we wanted to discover other correlations within our dataset? Welcome to GGally::ggpairs - a great tool to draw a scatterplot matrix:

#install.packages("GGally")
library(GGally)
ggpairs(launch)
## Warning in cor(x, y, method = method, use = use): the standard deviation is
## zero
## Warning in cor(x, y, method = method, use = use): the standard deviation is
## zero
## Warning in cor(x, y, method = method, use = use): the standard deviation is
## zero
## Warning in cor(x, y, method = method, use = use): the standard deviation is
## zero

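Here we can see that pressure also has a moderate correlation with distress_ct, which is not surprising. The warnings appear because o_ring_ct is 6 for every launch: its standard deviation is zero, so its correlation with anything is undefined.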
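A constant column (like `o_ring_ct`) has zero standard deviation, and that triggers the `cor()` warnings. One way to sidestep them is to drop constant columns before computing correlations. The data frame below is a made-up stand-in with the same column types, not the Challenger data.

```r
# Made-up stand-in data: 'ct' is constant, like o_ring_ct in the launch data
df = data.frame(ct    = rep(6L, 5),
                temp  = c(66L, 70L, 57L, 63L, 75L),
                fails = c(0L, 0L, 1L, 1L, 0L))

# Keep only columns whose standard deviation is non-zero
non_constant = vapply(df, function(col) sd(col) > 0, logical(1))
names(df)[!non_constant]     # "ct"

# Correlation matrix without "standard deviation is zero" warnings
cor(df[, non_constant])
```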

Let’s finish with an example of multiple linear regression. We’ll do this in two ways: first with our own very basic hand-crafted function, then with lm(). For the former, recall that \(\hat{\beta} = (X^{T}X)^{-1} X^{T} Y\)

# This function returns the beta coefficients for a multiple linear regression
# @param y: vector of dependent variable values
# @param x: data frame (or vector) of independent variable values
solve_multiple_regression = function(y, x) {
  x = as.matrix(x)            # coerces x into matrix form (from, say, data.frame or vector)
  x = cbind(Intercept = 1, x) # adds a feature to x called Intercept, set to 1, in every record of frame x
  # Lots of new functions:
  #   t(x)   returns the transpose of x
  #   %*%      performs matrix multiplication
  #   solve(x) returns the inverse of x
  #      X^T      X        X^T     Y
  solve( t(x) %*% x ) %*% t(x) %*% y
}

# Let's give her a whirl - let's test it by repeating that simple linear regression above:
solve_multiple_regression(y = launch$distress_ct,
                          x = launch[c("temperature")])
##                    [,1]
## Intercept    4.30158730
## temperature -0.05746032
# Identical to before!

# Let's try a multiple regression now:
solve_multiple_regression(y = launch$distress_ct,
                          x = launch[c("temperature", "pressure")])
##                     [,1]
## Intercept    4.045202762
## temperature -0.058247321
## pressure     0.002044586

Our function only gives us the coefficients, whereas R’s lm() is far more informative:

# Finally, let's use R's lm():
mlmodel = lm(data=launch,
             formula = distress_ct ~ temperature + pressure)
mlmodel
## 
## Call:
## lm(formula = distress_ct ~ temperature + pressure, data = launch)
## 
## Coefficients:
## (Intercept)  temperature     pressure  
##    4.045203    -0.058247     0.002045
# Good, that gives us the same coefficients as our own function. Let's see what else:
summary(mlmodel)
## 
## Call:
## lm(formula = distress_ct ~ temperature + pressure, data = launch)
## 
## Residuals:
##      Min       1Q   Median       3Q      Max 
## -0.55155 -0.17948 -0.07578  0.11829  0.92988 
## 
## Coefficients:
##              Estimate Std. Error t value Pr(>|t|)    
## (Intercept)  4.045203   0.807267   5.011 6.70e-05 ***
## temperature -0.058247   0.011363  -5.126 5.15e-05 ***
## pressure     0.002045   0.001175   1.739   0.0973 .  
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 0.3758 on 20 degrees of freedom
## Multiple R-squared:  0.5888, Adjusted R-squared:  0.5477 
## F-statistic: 14.32 on 2 and 20 DF,  p-value: 0.0001382

As we can see, R’s lm() gives us all kinds of extra information too, including standard errors, significance levels and \(R^2\).
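\(R^2\) is the proportion of the variance in \(y\) explained by the model, \(R^2 = 1 - SS_{res}/SS_{tot}\). Adjusted \(R^2\) penalises it for the number of predictors \(p\): \(1 - (1 - R^2)\frac{n-1}{n-p-1}\). For the Challenger model above, \(1 - 0.4112 \times 22/20 \approx 0.5477\).

A minimal sketch recomputing both by hand, using the built-in `mtcars` data as a stand-in:

```r
fit = lm(mpg ~ wt + hp, data = mtcars)

y      = mtcars$mpg
ss_res = sum(residuals(fit)^2)        # residual sum of squares
ss_tot = sum((y - mean(y))^2)         # total sum of squares
n = length(y)                         # number of observations
p = length(coef(fit)) - 1             # number of predictors (excluding intercept)

r2     = 1 - ss_res / ss_tot
adj_r2 = 1 - (1 - r2) * (n - 1) / (n - p - 1)

c(r2 = r2, adj_r2 = adj_r2)
c(summary(fit)$r.squared, summary(fit)$adj.r.squared)   # same values
```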

Example - Medical Insurance

The data is available here.

ins = read.csv("d://dev//R//mlwr/chap6-regression/insurance.csv",
               stringsAsFactors = TRUE) # R >= 4.0 no longer converts strings to factors by default
str(ins)
## 'data.frame':    1338 obs. of  7 variables:
##  $ age     : int  19 18 28 33 32 31 46 37 37 60 ...
##  $ sex     : Factor w/ 2 levels "female","male": 1 2 2 2 2 1 1 1 2 1 ...
##  $ bmi     : num  27.9 33.8 33 22.7 28.9 ...
##  $ children: int  0 1 3 0 0 0 1 3 2 0 ...
##  $ smoker  : Factor w/ 2 levels "no","yes": 2 1 1 1 1 1 1 1 1 1 ...
##  $ region  : Factor w/ 4 levels "northeast","northwest",..: 4 3 3 2 2 3 3 2 1 2 ...
##  $ charges : num  16885 1726 4449 21984 3867 ...
# So a few of these are factors; we'll see how R's regression tools deal with them

summary(ins)
##       age            sex           bmi           children     smoker    
##  Min.   :18.00   female:662   Min.   :15.96   Min.   :0.000   no :1064  
##  1st Qu.:27.00   male  :676   1st Qu.:26.30   1st Qu.:0.000   yes: 274  
##  Median :39.00                Median :30.40   Median :1.000             
##  Mean   :39.21                Mean   :30.66   Mean   :1.095             
##  3rd Qu.:51.00                3rd Qu.:34.69   3rd Qu.:2.000             
##  Max.   :64.00                Max.   :53.13   Max.   :5.000             
##        region       charges     
##  northeast:324   Min.   : 1122  
##  northwest:325   1st Qu.: 4740  
##  southeast:364   Median : 9382  
##  southwest:325   Mean   :13270  
##                  3rd Qu.:16640  
##                  Max.   :63770
# Roughly even split of males/females, and across regions
# Mostly non-smokers
# Wide BMI range (about 16 to 53), including some extreme values
# Good age range (18 to 64)
# Overall looks like a pretty normal sample of the population as a whole

# we want to be able to tell how charges will vary as a function of other features.
# Let's see how charges are distributed
ggplot(data=ins,
       aes(x=charges)) +
  geom_histogram(binwidth=1000)

# Pretty heavily right-skewed. Let's look at it on a log scale
ggplot(data=ins,
       aes(x=charges)) +
  geom_histogram() +
  scale_x_log10()
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

# we can see a bit more structure in there

# Let's run this through lm()
ins_mod = lm(data=ins, formula = charges ~ .)
summary(ins_mod)
## 
## Call:
## lm(formula = charges ~ ., data = ins)
## 
## Residuals:
##      Min       1Q   Median       3Q      Max 
## -11304.9  -2848.1   -982.1   1393.9  29992.8 
## 
## Coefficients:
##                 Estimate Std. Error t value Pr(>|t|)    
## (Intercept)     -11938.5      987.8 -12.086  < 2e-16 ***
## age                256.9       11.9  21.587  < 2e-16 ***
## sexmale           -131.3      332.9  -0.394 0.693348    
## bmi                339.2       28.6  11.860  < 2e-16 ***
## children           475.5      137.8   3.451 0.000577 ***
## smokeryes        23848.5      413.1  57.723  < 2e-16 ***
## regionnorthwest   -353.0      476.3  -0.741 0.458769    
## regionsoutheast  -1035.0      478.7  -2.162 0.030782 *  
## regionsouthwest   -960.0      477.9  -2.009 0.044765 *  
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 6062 on 1329 degrees of freedom
## Multiple R-squared:  0.7509, Adjusted R-squared:  0.7494 
## F-statistic: 500.8 on 8 and 1329 DF,  p-value: < 2.2e-16

Let’s stop for a moment and examine the output above:

The catch with linear regression is that it takes some knowledge of the field to get the most out of it, unlike the classification algorithms, which did it all on their own.
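On the question of how factors are handled: `lm()` turns each factor into 0/1 dummy (indicator) columns, one per level except the first. The first level becomes the reference level, absorbed into the intercept. That's why the output has `sexmale`, `smokeryes` and three `region` rows but no `sexfemale`, `smokerno` or `regionnortheast`. So `smokeryes` is the estimated extra charge for smokers relative to non-smokers, other things being equal.

Below is a small sketch with `model.matrix()`, which builds the design matrix `lm()` uses. It runs on a made-up four-row data frame.

```r
# Made-up example: one row per region
df = data.frame(
  region  = factor(c("northeast", "northwest", "southeast", "southwest")),
  charges = c(1000, 2000, 3000, 4000)
)

# The design matrix lm() would build: "northeast" (first level) is the
# reference, so it gets no column of its own
mm = model.matrix(charges ~ region, data = df)
mm

# To compare against a different baseline, relevel the factor first
df$region = relevel(df$region, ref = "southeast")
colnames(model.matrix(charges ~ region, data = df))
```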

We can improve this further.

Charges probably aren't a linear function of age: healthcare costs tend to rise disproportionately as people get older. We can probably improve the model by adding age squared as a feature:

ins$age2 = ins$age^2
ins_mod_age2 = lm(data=ins, formula = charges ~ .)
summary(ins_mod_age2)
## 
## Call:
## lm(formula = charges ~ ., data = ins)
## 
## Residuals:
##      Min       1Q   Median       3Q      Max 
## -11665.1  -2855.8   -944.1   1295.9  30826.0 
## 
## Coefficients:
##                  Estimate Std. Error t value Pr(>|t|)    
## (Intercept)     -6596.665   1689.444  -3.905 9.91e-05 ***
## age               -54.575     80.991  -0.674 0.500532    
## sexmale          -138.428    331.197  -0.418 0.676043    
## bmi               335.211     28.467  11.775  < 2e-16 ***
## children          642.024    143.617   4.470 8.47e-06 ***
## smokeryes       23859.745    410.988  58.055  < 2e-16 ***
## regionnorthwest  -367.812    473.783  -0.776 0.437692    
## regionsoutheast -1031.503    476.172  -2.166 0.030470 *  
## regionsouthwest  -957.546    475.417  -2.014 0.044198 *  
## age2                3.927      1.010   3.887 0.000107 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 6030 on 1328 degrees of freedom
## Multiple R-squared:  0.7537, Adjusted R-squared:  0.752 
## F-statistic: 451.6 on 9 and 1328 DF,  p-value: < 2.2e-16

So there's a slight improvement in \(R^2\) (0.7509 to 0.7537), and the residual standard error has barely moved (6062 to 6030).
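Adding a new `age2` column isn't the only way to do this. R's formula syntax can express the squared term inline with `I()`, which stops `^` from being read as formula notation. A side benefit is that `predict()` then computes the squared term itself for new data.

A minimal sketch on the built-in `mtcars` data, with horsepower standing in for age, showing that both approaches give the same coefficients:

```r
# Approach 1: add an explicit squared column, as done with ins$age2
mt2 = mtcars
mt2$hp2 = mt2$hp^2
fit_col = lm(mpg ~ hp + hp2, data = mt2)

# Approach 2: write the squared term inside the formula with I()
fit_inline = lm(mpg ~ hp + I(hp^2), data = mtcars)

# Same fit, different coefficient names ("hp2" vs "I(hp^2)")
cbind(coef(fit_col), coef(fit_inline))
```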

BMI values up to 25 are generally considered healthy, values above 25 overweight and values above 30 obese, so the effect of BMI on charges probably isn't linear either. Let’s add an indicator variable flagging a BMI above 30:

ins$bmi_severity = ifelse(ins$bmi > 30, 1, 0)
ins_mod_age2_bmi = lm(data=ins, formula = charges ~ .)
summary(ins_mod_age2_bmi)
## 
## Call:
## lm(formula = charges ~ ., data = ins)
## 
## Residuals:
##      Min       1Q   Median       3Q      Max 
## -12525.0  -3381.3    160.9   1311.3  29258.3 
## 
## Coefficients:
##                  Estimate Std. Error t value Pr(>|t|)    
## (Intercept)     -2852.423   1825.191  -1.563 0.118336    
## age               -26.238     80.416  -0.326 0.744261    
## sexmale          -159.152    328.097  -0.485 0.627703    
## bmi               148.123     46.045   3.217 0.001327 ** 
## children          627.304    142.291   4.409 1.12e-05 ***
## smokeryes       23864.986    407.112  58.620  < 2e-16 ***
## regionnorthwest  -401.016    469.358  -0.854 0.393041    
## regionsoutheast  -884.246    472.549  -1.871 0.061534 .  
## regionsouthwest  -929.552    470.963  -1.974 0.048620 *  
## age2                3.573      1.003   3.562 0.000382 ***
## bmi_severity     2813.900    547.483   5.140 3.16e-07 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 5973 on 1327 degrees of freedom
## Multiple R-squared:  0.7585, Adjusted R-squared:  0.7567 
## F-statistic: 416.8 on 10 and 1327 DF,  p-value: < 2.2e-16
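The indicator above only flags BMI over 30, so the 25 threshold isn't used. A minimal sketch of encoding both cut-offs as a three-level factor with `cut()` follows (the band labels are made up for illustration). `lm()` would then dummy-code it like any other factor.

```r
bmi = c(18.5, 24.9, 25, 27.3, 30, 30.1, 41)

# right = TRUE gives intervals (-Inf, 25], (25, 30], (30, Inf]
bmi_band = cut(bmi,
               breaks = c(-Inf, 25, 30, Inf),
               labels = c("up_to_25", "25_to_30", "over_30"),
               right  = TRUE)
data.frame(bmi, bmi_band)

# The top band matches the bmi_severity indicator, ifelse(bmi > 30, 1, 0)
bmi_severity = ifelse(bmi > 30, 1, 0)
all(as.integer(bmi_band == "over_30") == bmi_severity)   # TRUE
```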

Finally, it’s likely that being obese (BMI over 30) and smoking together is worse than the sum of its parts. This is called an interaction. Let’s add that in too:

ins_mod_fatsmokers = lm(data=ins, formula = 
                          charges ~ age2 + age + sex + bmi + children + region +
                          smoker*bmi_severity) # a*b expands to a + b + a:b (main effects plus interaction)
summary(ins_mod_fatsmokers)
## 
## Call:
## lm(formula = charges ~ age2 + age + sex + bmi + children + region + 
##     smoker * bmi_severity, data = ins)
## 
## Residuals:
##     Min      1Q  Median      3Q     Max 
## -4260.3 -1644.6 -1272.7  -784.7 24192.7 
## 
## Coefficients:
##                          Estimate Std. Error t value Pr(>|t|)    
## (Intercept)               69.2494  1353.2349   0.051 0.959195    
## age2                       3.5978     0.7422   4.847 1.40e-06 ***
## age                      -21.6786    59.4956  -0.364 0.715638    
## sexmale                 -475.6760   242.9293  -1.958 0.050430 .  
## bmi                      114.2920    34.0816   3.353 0.000821 ***
## children                 661.5105   105.2784   6.283 4.48e-10 ***
## regionnorthwest         -275.6659   347.2730  -0.794 0.427453    
## regionsoutheast         -826.1187   349.6181  -2.363 0.018275 *  
## regionsouthwest        -1164.8152   348.5123  -3.342 0.000854 ***
## smokeryes              13421.6370   435.9158  30.790  < 2e-16 ***
## bmi_severity            -938.5116   420.5807  -2.231 0.025817 *  
## smokeryes:bmi_severity 19912.6072   600.8493  33.141  < 2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 4419 on 1326 degrees of freedom
## Multiple R-squared:  0.8679, Adjusted R-squared:  0.8668 
## F-statistic: 792.1 on 11 and 1326 DF,  p-value: < 2.2e-16

And there you have it: an \(R^2\) of 0.87 (up from 0.76), with the residual standard error down from 5973 to 4419. We’ve used our domain understanding to improve the model.
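In R's formula syntax, `a*b` is shorthand for `a + b + a:b`: both main effects plus their interaction `a:b`. That's why `smokeryes`, `bmi_severity` and `smokeryes:bmi_severity` all appear in the output. A quick sketch on the built-in `mtcars` data (weight and transmission type standing in) confirms the two forms fit the same model:

```r
fit_star     = lm(mpg ~ wt * am, data = mtcars)
fit_expanded = lm(mpg ~ wt + am + wt:am, data = mtcars)

names(coef(fit_star))                            # "(Intercept)" "wt" "am" "wt:am"
all.equal(coef(fit_star), coef(fit_expanded))    # TRUE
```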

Regression Trees and Model Trees

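Classification trees predict categorical outcomes, so they can’t be (neatly) applied to a numeric target. There are, however, a couple of hybrids of regression modelling and decision trees which are often overlooked, yet frequently preferable to plain regression. Regression trees have leaves that predict the average value of the training examples reaching them. Model trees have leaves that each hold a linear regression model.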

Classification trees often measure the quality of a split by information gain (based on entropy), but that only works for a categorical target. For a numeric target, we can use metrics such as standard deviation reduction (SDR) instead.
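SDR measures how much a split reduces the spread of the target: \(SDR = sd(T) - \sum_i \frac{|T_i|}{|T|} sd(T_i)\). Here \(T\) is the set of target values before the split and \(T_i\) are the subsets after it. The tree picks the split with the largest SDR.

A minimal sketch comparing two made-up candidate splits of the same 15 target values follows. Note that rpart, used below, splits on the closely related reduction in sum of squares rather than on SDR itself.

```r
# Standard deviation reduction for splitting 'parent' into the given subsets
sdr = function(parent, subsets) {
  weights = vapply(subsets, length, integer(1)) / length(parent)
  sd(parent) - sum(weights * vapply(subsets, sd, numeric(1)))
}

# Made-up target values before any split
tee = c(1, 1, 1, 2, 2, 3, 4, 5, 5, 6, 6, 7, 7, 7, 7)

# Two candidate splits of the same values
split_a = list(c(1, 1, 1, 2, 2, 3, 4, 5, 5), c(6, 6, 7, 7, 7, 7))
split_b = list(c(1, 1, 1, 2, 2, 3, 4), c(5, 5, 6, 6, 7, 7, 7, 7))

sdr(tee, split_a)   # about 1.203
sdr(tee, split_b)   # about 1.393, so split B would be chosen
```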

Benefits of these trees include needing no pre-processing or normalisation, and they handle messy data well (up to a point, of course).

Let’s try these out in an example.

Example - Wine Quality Estimation

The data

We have some white wine data made up of 11 physicochemical properties and one response variable (quality) for almost 5,000 wines. These were used in a research project measuring the efficacy of various machine learning techniques (multiple regression, neural nets, SVMs). The study found that SVMs offered significant improvements over the others; unfortunately, the SVM model is difficult to interpret. We’ll see how regression trees and model trees, which produce models that are easier to understand, compare.

wine = read.csv("D://dev/R/mlwr/chap6-regression/whitewines.csv")
str(wine)
## 'data.frame':    4898 obs. of  12 variables:
##  $ fixed.acidity       : num  6.7 5.7 5.9 5.3 6.4 7 7.9 6.6 7 6.5 ...
##  $ volatile.acidity    : num  0.62 0.22 0.19 0.47 0.29 0.14 0.12 0.38 0.16 0.37 ...
##  $ citric.acid         : num  0.24 0.2 0.26 0.1 0.21 0.41 0.49 0.28 0.3 0.33 ...
##  $ residual.sugar      : num  1.1 16 7.4 1.3 9.65 0.9 5.2 2.8 2.6 3.9 ...
##  $ chlorides           : num  0.039 0.044 0.034 0.036 0.041 0.037 0.049 0.043 0.043 0.027 ...
##  $ free.sulfur.dioxide : num  6 41 33 11 36 22 33 17 34 40 ...
##  $ total.sulfur.dioxide: num  62 113 123 74 119 95 152 67 90 130 ...
##  $ density             : num  0.993 0.999 0.995 0.991 0.993 ...
##  $ pH                  : num  3.41 3.22 3.49 3.48 2.99 3.25 3.18 3.21 2.88 3.28 ...
##  $ sulphates           : num  0.32 0.46 0.42 0.54 0.34 0.43 0.47 0.47 0.47 0.39 ...
##  $ alcohol             : num  10.4 8.9 10.1 11.2 10.9 ...
##  $ quality             : int  5 6 6 4 6 6 6 6 6 7 ...
# What does our distribution look like?
hist(wine$quality)

# let's also dot the i's and cross the t's for weirdness:
summary(wine)
##  fixed.acidity    volatile.acidity  citric.acid     residual.sugar  
##  Min.   : 3.800   Min.   :0.0800   Min.   :0.0000   Min.   : 0.600  
##  1st Qu.: 6.300   1st Qu.:0.2100   1st Qu.:0.2700   1st Qu.: 1.700  
##  Median : 6.800   Median :0.2600   Median :0.3200   Median : 5.200  
##  Mean   : 6.855   Mean   :0.2782   Mean   :0.3342   Mean   : 6.391  
##  3rd Qu.: 7.300   3rd Qu.:0.3200   3rd Qu.:0.3900   3rd Qu.: 9.900  
##  Max.   :14.200   Max.   :1.1000   Max.   :1.6600   Max.   :65.800  
##    chlorides       free.sulfur.dioxide total.sulfur.dioxide
##  Min.   :0.00900   Min.   :  2.00      Min.   :  9.0       
##  1st Qu.:0.03600   1st Qu.: 23.00      1st Qu.:108.0       
##  Median :0.04300   Median : 34.00      Median :134.0       
##  Mean   :0.04577   Mean   : 35.31      Mean   :138.4       
##  3rd Qu.:0.05000   3rd Qu.: 46.00      3rd Qu.:167.0       
##  Max.   :0.34600   Max.   :289.00      Max.   :440.0       
##     density             pH          sulphates         alcohol     
##  Min.   :0.9871   Min.   :2.720   Min.   :0.2200   Min.   : 8.00  
##  1st Qu.:0.9917   1st Qu.:3.090   1st Qu.:0.4100   1st Qu.: 9.50  
##  Median :0.9937   Median :3.180   Median :0.4700   Median :10.40  
##  Mean   :0.9940   Mean   :3.188   Mean   :0.4898   Mean   :10.51  
##  3rd Qu.:0.9961   3rd Qu.:3.280   3rd Qu.:0.5500   3rd Qu.:11.40  
##  Max.   :1.0390   Max.   :3.820   Max.   :1.0800   Max.   :14.20  
##     quality     
##  Min.   :3.000  
##  1st Qu.:5.000  
##  Median :6.000  
##  Mean   :5.878  
##  3rd Qu.:6.000  
##  Max.   :9.000

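The quality distribution looks roughly normal (bell-shaped and centred around 6), much as you'd expect from a random sample of wines. There are no NAs and the ranges look (mostly) sound. The possible exceptions are the maxima of residual.sugar (65.8), free.sulfur.dioxide (289) and total.sulfur.dioxide (440), which sit far above their third quartiles.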

Model training

Let’s create our training and test datasets in a roughly 3:1 ratio (3,750 and 1,148 rows):

wine.train = wine[1:3750, ]
wine.test = wine[3751:4898, ]
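Taking the first 3,750 rows like this only gives representative training and test sets if the rows are already in random order. If they weren't, a random split would be safer. In the minimal sketch below, the helper name `split_train_test` and its 75% default are illustrative choices, not from any package.

```r
# Hypothetical helper: randomly assign a fraction of rows to training, the rest to test
split_train_test = function(df, train_frac = 0.75, seed = 123) {
  set.seed(seed)                                   # make the split reproducible
  n_train = floor(train_frac * nrow(df))
  train_rows = sample(seq_len(nrow(df)), n_train)  # sample without replacement
  list(train = df[train_rows, , drop = FALSE],
       test  = df[-train_rows, , drop = FALSE])
}

# Usage on a toy data frame of 100 rows
parts = split_train_test(data.frame(id = 1:100))
c(nrow(parts$train), nrow(parts$test))   # 75 25
```

For the wine data this would be `parts = split_train_test(wine)`.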

We’ll use the regression trees implemented in the rpart (recursive partitioning) package, which is reportedly the most faithful R implementation of the original CART approach. It is well documented and supported by tools for visualising and evaluating rpart models.

#install.packages("rpart")
#install.packages("rpart.plot")
library(rpart)

model.rpart = rpart(data = wine.train, quality ~ .)

# let's see some info:
model.rpart
## n= 3750 
## 
## node), split, n, deviance, yval
##       * denotes terminal node
## 
##  1) root 3750 2945.53200 5.870933  
##    2) alcohol< 10.85 2372 1418.86100 5.604975  
##      4) volatile.acidity>=0.2275 1611  821.30730 5.432030  
##        8) volatile.acidity>=0.3025 688  278.97670 5.255814 *
##        9) volatile.acidity< 0.3025 923  505.04230 5.563380 *
##      5) volatile.acidity< 0.2275 761  447.36400 5.971091 *
##    3) alcohol>=10.85 1378 1070.08200 6.328737  
##      6) free.sulfur.dioxide< 10.5 84   95.55952 5.369048 *
##      7) free.sulfur.dioxide>=10.5 1294  892.13600 6.391036  
##       14) alcohol< 11.76667 629  430.11130 6.173291  
##         28) volatile.acidity>=0.465 11   10.72727 4.545455 *
##         29) volatile.acidity< 0.465 618  389.71680 6.202265 *
##       15) alcohol>=11.76667 665  403.99400 6.596992 *

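The output shows the tree. At the top is the root node, containing all 3,750 training records. Each line gives:

- the split condition;
- n, the number of records;
- the deviance (the sum of squared differences from the node mean);
- yval, the mean quality of the records in that node.

The first split is on alcohol, the most important predictor. A * at the end of a line indicates a leaf (terminal) node. We can see a more detailed report using summary(), though rpart.plot provides a more interesting visual: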

#summary(model.rpart)
library(rpart.plot)
rpart.plot(model.rpart, digits=3)

How cool is that? We can specify more modifiers too, such as type, fallen.leaves, etc:

rpart.plot(model.rpart, digits=3,
           fallen.leaves=TRUE, # draw all leaves at the bottom of the diagram
           type=3,             # draw separate split labels for the left and right branches
           extra=101)          # show the number and percentage of observations in each node

The result looks quite different. See ?rpart.plot for more options.
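For a regression tree, the `yval` column (and each leaf's prediction) is the mean of the training outcomes that reach that leaf. A quick sketch confirms this on the built-in `mtcars` data. It uses the `where` component of an rpart fit, which records the leaf each training row ends up in.

```r
library(rpart)

fit = rpart(mpg ~ ., data = mtcars)

# Mean mpg of the training rows in each leaf, repeated for every row
leaf_means = ave(mtcars$mpg, fit$where)

# rpart's fitted values are exactly those leaf means
all.equal(unname(predict(fit)), leaf_means)   # TRUE
```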

Model performance

Ok, so let’s try this out on our test set:

# predict() returns the estimated numeric value for the outcome var (quality, in this case)
pred.rpart = predict(model.rpart, wine.test)
summary(pred.rpart)
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##   4.545   5.563   5.971   5.893   6.202   6.597

Right away we can see that the range of our predictions looks too narrow, given the actual test values:

summary(wine.test$quality)
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##   3.000   5.000   6.000   5.901   6.000   9.000

The true range is [3, 9]. This suggests that the model isn’t correctly identifying extreme cases. A quick and effective way of assessing the model’s performance is by checking the correlation between predicted and actual quality values:

cor(wine.test$quality, pred.rpart)
## [1] 0.5369525

Meh, this is acceptable, but not mind blowing. Note that this isn’t a measure of how far off the predictions are from the true values. We can do that with the Mean Absolute Error, which is \(MAE = \frac{1}{n}\sum_{i=1}^{n}|e_i|\) where \(n\) is the number of predictions, \(e_i\) is the error for record i. Let’s write a quick function:

MAE = function(actual, predicted) {
  mean(abs(actual - predicted))
}

MAE(wine.test$quality, pred.rpart)
## [1] 0.5872652

So, on average, our predictions were 0.59 away from the true quality score, which, on a scale of 1-10, isn’t too bad. If we had to be picky, though…

Let’s suppose we had a really dumb predictor that just predicted the mean from the training set. Its MAE would be:

# this is the MAE of a dumb predictor which always guesses the mean of the training set
MAE(wine.test$quality, mean(wine.train$quality))
## [1] 0.6719238

This says that a dumb predictor would be off by 0.67 on average, so our regression tree, with an MAE of 0.59, didn’t really do much better, truth be told. As it happens, the study behind this dataset reported an MAE of 0.58 for the neural net model and 0.45 for the SVM, which suggests there is still room for improvement.
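The comparison with the mean-only baseline can be folded into a single number: the *relative absolute error*. It is the model's total absolute error divided by that of always predicting the mean, so anything below 100% beats the baseline. Weka's evaluation summaries report a metric under that name. A minimal sketch with made-up values:

```r
# Relative absolute error: model's absolute error relative to a constant baseline
relative_absolute_error = function(actual, predicted, baseline) {
  sum(abs(actual - predicted)) / sum(abs(actual - baseline))
}

actual    = c(3, 5, 6, 9, 7)       # made-up true values
predicted = c(4, 5, 6, 8, 7)       # made-up model predictions
baseline  = mean(actual)           # the "dumb" mean predictor (6 here)

relative_absolute_error(actual, predicted, baseline)   # 0.25, i.e. 25%
```

With the same test rows, this equals the ratio of the two MAEs. For our regression tree that is roughly 0.587 / 0.672, or about 87%.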

Improving model performance

Our ‘improvement’ this time will be to build a different model, haha: a model tree. Recall that these improve on regression trees by replacing the leaf nodes with regression models, often leading to more accurate results. The M5' algorithm, often cited as the state of the art for model trees, is available in R through the RWeka package's M5P() function (RWeka requires Java).
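Before reaching for RWeka, the core idea can be sketched with tools we've already used. Grow a regression tree with rpart, then fit a separate `lm()` to the training rows in each leaf. This is a crude illustration, not M5' (which, among other things, prunes the tree and smooths the leaf models). The helper name `fit_leaf_models` and the synthetic data are made up.

On training data, each leaf's regression can only lower the sum of squared errors compared with the leaf mean. That's because an intercept-only model is a special case of `lm()`.

```r
library(rpart)

# Hypothetical helper: regression tree whose leaves each get their own lm()
fit_leaf_models = function(formula, data) {
  tree = rpart(formula, data = data)
  leaf = tree$where                                   # leaf index for each training row
  models = lapply(split(data, leaf), function(d) lm(formula, data = d))
  # Fitted values from each leaf's lm(), put back into the original row order
  fitted_values = unname(unsplit(lapply(models, fitted), leaf))
  list(tree = tree, models = models, fitted = fitted_values)
}

# Synthetic data: y rises with x, then falls (a piecewise-linear shape)
set.seed(42)
d = data.frame(x = runif(200, 0, 10))
d$y = ifelse(d$x < 5, 2 * d$x, 20 - 2 * d$x) + rnorm(200)

mt = fit_leaf_models(y ~ x, d)
c(tree_sse       = sum((d$y - predict(mt$tree))^2),
  leaf_model_sse = sum((d$y - mt$fitted)^2))
```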

#install.packages("RWeka")
library(RWeka)
# similar syntax to rpart
model.m5p = M5P(data = wine.train, quality ~ .)

# let's see the tree
# model.m5p
# that command's actually VERY verbose, so here are the first few lines:
#alcohol <= 10.85 : 
#|   volatile.acidity <= 0.238 : 
#|   |   fixed.acidity <= 6.85 : LM1 (406/66.024%)
#|   |   fixed.acidity >  6.85 : 
#|   |   |   free.sulfur.dioxide <= 24.5 : LM2 (113/87.697%)
#|   |   |   free.sulfur.dioxide >  24.5 : 
#|   |   |   |   alcohol <= 9.15 : 
#|   |   |   |   |   citric.acid <= 0.305 : 
#
# bla bla bla
# LM num: 1
# quality = 
#   0.266 * fixed.acidity 
#   - 2.3082 * volatile.acidity 
#   - 0.012 * citric.acid 
#   + 0.0421 * residual.sugar 
#   + 0.1126 * chlorides 
#   + 0 * free.sulfur.dioxide 
#   - 0.0015 * total.sulfur.dioxide 
#   - 109.8813 * density 
#   + 0.035 * pH 
#   + 1.4122 * sulphates 
#   - 0.0046 * alcohol 
#   + 113.1021

# LM num: 2
# quality = 
#   -0.2557 * fixed.acidity 
#   - 0.8082 * volatile.acidity 
#   - 0.1062 * citric.acid 
#   + 0.0738 * residual.sugar 
#   + 0.0973 * chlorides 
#   + 0.0006 * free.sulfur.dioxide 
#   + 0.0003 * total.sulfur.dioxide 
#   - 210.1018 * density 
#   + 0.0323 * pH 
#   - 0.9604 * sulphates 
#   - 0.0231 * alcohol 
#   + 216.8857

Once again, alcohol content is the first split, even for these. Note, though, that the leaf nodes are no longer predicted values. Instead they point to linear models (LM1, LM2 and so on; LM = linear model), whose equations are listed further down the output. This time, summary() also reports performance metrics:

summary(model.m5p)
## 
## === Summary ===
## 
## Correlation coefficient                  0.6666
## Mean absolute error                      0.5151
## Root mean squared error                  0.6614
## Relative absolute error                 76.4921 %
## Root relative squared error             74.6259 %
## Total Number of Instances             3750

This summary is computed on the training data only, so we still need to evaluate the model on the test data:

pred.m5p = predict(model.m5p, wine.test)
# let's see if this has a wider range reflecting the [3,9] range we're expecting:
summary(pred.m5p)
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##   4.389   5.430   5.863   5.874   6.305   7.437

So this looks a little bit better - how does it correlate?

cor(pred.m5p, wine.test$quality)
## [1] 0.6272973

Ok, we’ve got better correlation, nice - and MAE…

MAE(wine.test$quality, pred.m5p)
## [1] 0.5463023

Ok, so an error of 0.546. This beats the neural net MAE reported in the study (0.58) and is heading towards the SVM's 0.45. It does so with a simpler learning method that produces an easier-to-understand model. Yay for simplicity!