The idea of “Linear Regression Trees” is to grow a tree, similar to a decision tree, in which every terminal node is associated with a linear regression model built from some or all of the variables in the data.
The approach was first proposed and implemented by Ross Quinlan (of C4.5 fame) in his M5 program. The following R packages contain implementations for building linear regression trees:
Package ‘Cubist’ with the cubist() function
(Cubist was the name of the tool that Quinlan sold through his RuleQuest company.)
Package ‘partykit’ provides functions mob() and lmtree() for
“Model-based recursive partitioning based on least squares regression.”
M5P() (“M5 Prime”) in the RWeka package, an R interface to the Weka software, which contains a Java reimplementation of the M5 algorithm. (A minimal call is sketched below.)
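For completeness, a minimal M5P() call might look like the following sketch. This is an assumption for illustration only (it requires the RWeka package and a working Java installation, and it is not part of the session shown here); ‘mtcars’ is just an arbitrary built-in dataset.
library(RWeka)
# fit an M5 model tree: splits in the tree, linear models in the leaves
m5 = M5P(mpg ~ ., data = mtcars)
print(m5)   # prints the tree together with the linear model in each leaf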
We will use the “Boston Housing” data as an example.
library(mlbench)
data("BostonHousing")
House = BostonHousing[, -14]   # predictors (drop column 14, 'medv')
value = BostonHousing$medv     # target: median home value (in 1000s of USD)
library(Cubist)
The cubist() function takes as input a data frame (or matrix) of predictors and a numeric response vector. We can also choose the number of ‘committees’, i.e. boosting-like ensemble members, it will build.
mod1  = cubist(x = House, y = value)                   # single model (default)
mod10 = cubist(x = House, y = value, committees = 10)  # committee of 10 models
summary() shows the rules of the generated model: the conditions that define each rule and the regression equation fitted at the corresponding node.
summary(mod1)
Call:
cubist.default(x = House, y = value)
Cubist [Release 2.07 GPL Edition] Mon Dec 7 19:47:25 2020
---------------------------------
Target attribute `outcome'
Read 506 cases (14 attributes) from undefined.data
Model:
Rule 1: [101 cases, mean 13.84, range 5 to 27.5, est err 1.98]
if
nox > 0.668
then
outcome = -1.11 + 2.93 dis + 21.4 nox - 0.33 lstat + 0.008 b
- 0.13 ptratio - 0.02 crim - 0.003 age + 0.1 rm
Rule 2: [203 cases, mean 19.42, range 7 to 31, est err 2.10]
if
nox <= 0.668
lstat > 9.59
then
outcome = 23.57 + 3.1 rm - 0.81 dis - 0.71 ptratio - 0.048 age
- 0.15 lstat + 0.01 b - 0.0041 tax - 5.2 nox + 0.05 crim
+ 0.02 rad
Rule 3: [43 cases, mean 24.00, range 11.9 to 50, est err 2.56]
if
rm <= 6.226
lstat <= 9.59
then
outcome = 1.18 + 3.83 crim + 4.3 rm - 0.06 age - 0.11 lstat - 0.003 tax
- 0.09 dis - 0.08 ptratio
Rule 4: [163 cases, mean 31.46, range 16.5 to 50, est err 2.78]
if
rm > 6.226
lstat <= 9.59
then
outcome = -4.71 + 2.22 crim + 9.2 rm - 0.83 lstat - 0.0182 tax
- 0.72 ptratio - 0.71 dis - 0.04 age + 0.03 rad - 1.7 nox
+ 0.008 zn
Evaluation on training data (506 cases):
Average |error| 2.10
Relative |error| 0.32
Correlation coefficient 0.94
Attribute usage:
Conds Model
80% 100% lstat
60% 92% nox
40% 100% rm
100% crim
100% age
100% dis
100% ptratio
80% tax
72% rad
60% b
32% zn
Time: 0.0 secs
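The rmse() helper used in the following comparisons is not defined above; it may come from a package such as ‘Metrics’, or it can be defined directly. A minimal definition (an assumption, matching the standard formula):
# root mean squared error: square root of the mean squared difference
rmse = function(actual, predicted) sqrt(mean((actual - predicted)^2))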
We use the “root mean squared error” (RMSE) to compare the two models. Asking for more committees will improve the fit noticeably,
rmse(value, predict(mod1, House))
[1] 3.025956
rmse(value, predict(mod10, House))
[1] 2.53904
while the plot of actual versus fitted prices does not show such a clear advantage.
plot(value, predict(mod1, House), col = "black")    # single model
points(value, predict(mod10, House), col = "red")   # 10 committees
grid()
lmtree() in package ‘partykit’
The package ‘partykit’ has two functions for this purpose, mob() and lmtree(). The algorithmic work is performed by mob(), while lmtree() simplifies the user call.
library(partykit)
Loading required package: grid
Loading required package: libcoin
Loading required package: mvtnorm
The formula for lmtree() in general looks like:
y ~ z1 + ... + zl or
y ~ x1 + ... + xk | z1 + ... + zl
where z1, ..., zl are the variables used for building the tree (partitioning) and x1, ..., xk are the variables used in the linear regressions at the nodes; the two sets may overlap. In the first form, without the ‘|’, the same variables are used for partitioning and only an intercept is fitted in each node.
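As an illustration of the two-part form, the following call fits a linear model in rm and lstat within each node and uses a few of the remaining variables as partitioning candidates; the choice of variables here is arbitrary, for illustration only.
# hypothetical example: regressors before '|', partitioning variables after it
mod2 = lmtree(medv ~ rm + lstat | crim + nox + age + ptratio + tax,
              data = BostonHousing)
plot(mod2)   # tree with the fitted model shown in each terminal node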
mod3 = lmtree(medv ~ . | ., data = BostonHousing)   # all variables on both sides
rmse(value, predict(mod3, House))
[1] 3.241452
plot(value, predict(mod3, House), col = "black")    # lmtree
points(value, predict(mod10, House), col = "red")   # Cubist, 10 committees
grid()
# plot(value, abs(value - predict(mod10, House)), type ='h')
We can also compare the absolute percentage errors of the two models for each observation.
mape3  = round(100 * abs(value - predict(mod3, House))  / value, 3)   # APE (%) per case, lmtree
mape10 = round(100 * abs(value - predict(mod10, House)) / value, 3)   # APE (%) per case, Cubist
plot(value, mape10, pch = 20, col = "black")
points(value, mape3, col = 2)
grid()
For comparison, we apply a default Random Forest to the data and calculate the RMS error.
library(randomForest)
rf = randomForest(medv ~ ., data = BostonHousing)
rf
Call:
randomForest(formula = medv ~ ., data = BostonHousing)
Type of random forest: regression
Number of trees: 500
No. of variables tried at each split: 4
Mean of squared residuals: 10.07977
% Var explained: 88.06
rmse(value, predict(rf))   # without newdata, predict() returns out-of-bag predictions
[1] 3.174865
Of course, this is not a reliable accuracy value; it only shows that Quinlan’s M5, as implemented in ‘Cubist’, generates an excellent fit to the training data.
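A more honest comparison would use data that the models have not seen during fitting. A minimal hold-out sketch (the 70/30 split and the seed are arbitrary assumptions, not part of the original session):
set.seed(1)                                        # arbitrary seed for reproducibility
idx   = sample(nrow(BostonHousing), round(0.7 * nrow(BostonHousing)))
train = BostonHousing[idx, ]
test  = BostonHousing[-idx, ]

cb = cubist(x = train[, -14], y = train$medv, committees = 10)
lt = lmtree(medv ~ . | ., data = train)

rmse(test$medv, predict(cb, test[, -14]))          # Cubist on held-out data
rmse(test$medv, predict(lt, test))                 # lmtree on held-out data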