Machine Learning

In this tutorial we describe a set of analyses that, though disparate, are often collected under the heading of machine learning. The subject can loosely be described as either finding patterns in data, or building predictive models from ‘high dimensional’ data - that is, data with many variables. These analyses are often useful for very large datasets because they tend to scale better than standard statistical or graphical techniques (NB not always). The first activity, ‘finding patterns’, is often called unsupervised learning (or sometimes data mining) because we do not characterise or label the data before looking for patterns - e.g. creating a clustered heatmap of gene expression. By contrast supervised learning (also known as classification) uses data of known type to build ‘predictive models’ that can then be applied to new data - the classic example is a model that recognises spam email.

Whilst machine learning is sometimes thought of as a subject belonging to engineering or computer science - it overlaps with, but should not be confused with, AI - in many ways it is more helpfully viewed as a branch of statistics. It used to be the case that software such as Matlab and SAS had the more powerful machine learning features, but I think R is now in the ascendancy.

Unsupervised Learning: Clustering

With many types of data we may be interested in finding whether there are groups of similar samples or records. If we have a large number of variables this can be hard to discover graphically, as you cannot plot more than a few variable dimensions on paper (x-axis, y-axis, colour, then point size could be used for continuous variable scaling, but that’s about it). This is what is meant by high dimensional data. To find groups we therefore need first to create a set of distance measures with which to compare samples, and second to create a rule that clusters or groups samples based on those distances. All clustering methods (that I am aware of) follow this basic procedure.

One of the most widely used clustering methods is hierarchical clustering. Here we look for groups of similar cars amongst the 1974 US Motor Trend survey of car specifications that we looked at earlier:

# depending on the distance metric scaling is sometimes helpful. The
# euclidean distance is not standardised
mtcars.sc = scale(mtcars)
head(mtcars.sc)
##                       mpg    cyl     disp      hp    drat      wt    qsec
## Mazda RX4          0.1509 -0.105 -0.57062 -0.5351  0.5675 -0.6104 -0.7772
## Mazda RX4 Wag      0.1509 -0.105 -0.57062 -0.5351  0.5675 -0.3498 -0.4638
## Datsun 710         0.4495 -1.225 -0.99018 -0.7830  0.4740 -0.9170  0.4260
## Hornet 4 Drive     0.2173 -0.105  0.22009 -0.5351 -0.9661 -0.0023  0.8905
## Hornet Sportabout -0.2307  1.015  1.04308  0.4129 -0.8352  0.2277 -0.4638
## Valiant           -0.3303 -0.105 -0.04617 -0.6080 -1.5646  0.2481  1.3270
##                       vs      am    gear    carb
## Mazda RX4         -0.868  1.1899  0.4236  0.7352
## Mazda RX4 Wag     -0.868  1.1899  0.4236  0.7352
## Datsun 710         1.116  1.1899  0.4236 -1.1222
## Hornet 4 Drive     1.116 -0.8141 -0.9318 -1.1222
## Hornet Sportabout -0.868 -0.8141 -0.9318 -0.5030
## Valiant            1.116 -0.8141 -0.9318 -1.1222
# the default dist is euclidean
d = dist(mtcars.sc)
# We won't print d here as it is messy and scrolls off the page. You can if
# you like.
class(d)
## [1] "dist"

Above is step one: creating a set of distances. By default the dist function calculates the euclidean distances between all the rows of a matrix, which in the case of the mtcars data means the dissimilarity between car models. As the euclidean distance is non-standardised I scale the data first so that the variables contribute equally to the distance metric. There are many other distance functions, including the popular 1 - correlation (correlation is a similarity; subtracting it from 1 turns it into a dissimilarity) and the Mahalanobis distance. The result d is of class dist, a special table format which many R functions will recognise.
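
For example, a 1 - correlation distance could be computed in a single line (a sketch only; the euclidean d above is what we carry forward below):

# a sketch: 1 - correlation as a dissimilarity between the cars. cor() works
# column-wise, so we transpose the scaled matrix first.
d.cor = as.dist(1 - cor(t(mtcars.sc)))
class(d.cor)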

hc = hclust(d, method = "complete")
hc
## 
## Call:
## hclust(d = d, method = "complete")
## 
## Cluster method   : complete 
## Distance         : euclidean 
## Number of objects: 32 
## 

The hclust function accepts a ‘dist’ class object and uses it to iteratively agglomerate the samples, starting from the point where every car is in its own cluster, through successive joining of the most similar clusters, to the point where all clusters are joined. The method=’complete’ argument specifies complete linkage: the distance between two clusters is taken to be the maximum distance between any car in one and any car in the other, so a join is only as close as its furthest pair of members. Other methods include ‘single’ linkage, which uses the minimum pairwise distance between the groups, and ‘average’ linkage, which uses the mean of all pairwise distances. This is more easily understood graphically using a dendrogram:

plot(hc, hang = -1)

plot of chunk plotCluster
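
Different linkage rules can give noticeably different trees. As a quick sketch, re-using the distance table d from above (and not used further below), the same data can be clustered with single and average linkage:

# a sketch comparing linkage rules on the same distance table
hc.single = hclust(d, method = "single")    # joins on the minimum pairwise distance
hc.average = hclust(d, method = "average")  # joins on the mean pairwise distance
par(mfrow = c(1, 2))
plot(hc.single, hang = -1, main = "single linkage")
plot(hc.average, hang = -1, main = "average linkage")
par(mfrow = c(1, 1))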

The plot function automatically recognises hclust class objects and produces a tree-like plot called a dendrogram. The form of this graph closely follows the agglomeration algorithm, with iterative joining of groups from the bottom up to a final root join at the top that unifies all clusters. Of course, totally random data - indeed any data - will produce a dendrogram. This is the perennial difficulty with unsupervised learning: as there is essentially no hypothesis, it is impossible to say whether a pattern is significant. Nevertheless, as a guide to whether our clustering is meaningful, we can calculate the cophenetic distances - the distances between samples implied by the dendrogram - and compare them with the input distance table.

cophenetic.d = cophenetic(hc)
cor.test(d, cophenetic.d)
## 
##  Pearson's product-moment correlation
## 
## data:  d and cophenetic.d 
## t = 27.66, df = 494, p-value < 2.2e-16
## alternative hypothesis: true correlation is not equal to 0 
## 95 percent confidence interval:
##  0.7424 0.8118 
## sample estimates:
##    cor 
## 0.7795 
## 

If there were no real information in our clustering then the original distances and the cophenetic distances would be uncorrelated; if they are correlated then the dendrogram has captured real structure. Here we have a high cophenetic correlation and a significant test, so our clustering is ‘real’ - which is to say only that it accurately summarises the data, no more than that.
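
If we want explicit group memberships rather than a tree, we can cut the dendrogram at a chosen number of clusters. A sketch, with k = 4 chosen arbitrarily:

# a sketch: cut the tree into k groups to obtain an explicit cluster
# membership for each car (k = 4 is an arbitrary choice here)
groups = cutree(hc, k = 4)
table(groups)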

Plotting high dimensional data:

We will return to some other clustering methods later, but first consider some graphical representations of high dimensional data. The dendrogram above is a useful representation of the clustering but it does not reveal the relations between the variables in the data. The heatmap does show the variables, compressing all of the information into a table-like plot with a colour scale.

library(gplots)
## Loading required package: gtools
## Loading required package: gdata
## gdata: read.xls support for 'XLS' (Excel 97-2004) files ENABLED.
## gdata: Unable to load perl libaries needed by read.xls() gdata: to support
## 'XLSX' (Excel 2007+) files.
## gdata: Run the function 'installXLSXsupport()' gdata: to automatically
## download and install the perl gdata: libaries needed to support Excel XLS
## and XLSX formats.
## Attaching package: 'gdata'
## The following object(s) are masked from 'package:stats':
## 
## nobs
## The following object(s) are masked from 'package:utils':
## 
## object.size
## Loading required package: caTools
## Loading required package: bitops
## Loading required package: grid
## Loading required package: KernSmooth
## KernSmooth 2.23 loaded Copyright M. P. Wand 1997-2009
## Loading required package: MASS
## Attaching package: 'gplots'
## The following object(s) are masked from 'package:stats':
## 
## lowess
heatmap.2(as.matrix(mtcars), trace = "n", scale = "col", Colv = F, 
    Rowv = F, margins = c(4, 8))
## Warning: Discrepancy: Rowv is FALSE, while dendrogram is `both'. Omitting
## row dendogram.
## Warning: Discrepancy: Colv is FALSE, while dendrogram is `none'. Omitting
## column dendogram.

plot of chunk heatmap

The heatmap.2 function takes a matrix as its first argument and shows a colour-scaled map of the matrix. Here we have also chosen to scale the values by column, again because it does not make sense to compare car variables measured on different scales directly. The margins argument is a display parameter controlling the size of the margins around the plot. The heatmap is useful because it packs all the information on the variables into a small space. In this plot clustering has been suppressed by the Colv=F and Rowv=F arguments.
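
As an aside, heatmap.2 also accepts a col argument if you prefer a different palette; a sketch using the bluered palette from gplots (the clustered heatmap below sticks with the defaults):

# a sketch: the same unclustered heatmap with a blue-white-red palette
heatmap.2(as.matrix(mtcars), trace = "n", scale = "col", Colv = F, 
    Rowv = F, margins = c(4, 8), col = bluered(75))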

heatmap.2(as.matrix(mtcars), trace = "n", scale = "col", margins = c(4, 
    8))

plot of chunk heatmapClustered

Here we have clustered both the cars and the variables, because the relation between, say, miles per gallon (mpg) and cylinders (cyl) may be interesting. Strictly speaking this plot contains no more information than the first heatmap - it is just a rearrangement - and yet it is more 'informative'. Of course when we have a huge dataset it may become impossible even to view the data in a clustered heatmap (without scrolling on a computer screen), and then we need to try a graphical method of dimensional reduction.

Dimensional Reduction: multidimensional scaling (MDS)

The conceptually simplest form of dimensional reduction is multidimensional scaling, or MDS. Naively we might say that if we cannot plot all the samples and variables then we can at least plot the distance matrix which describes the relations between samples (such as the one we used for clustering above). However it does not follow that a set of distances calculated between samples with many variable dimensions (e.g. n >> 2) can be plotted on a 2-dimensional screen. Consider four cities - London, Paris, Tokyo and Chicago - and a table of the flying distances between them. We could not simply plot points on a 2-dimensional surface that represent those cities and their distances apart, any more than we could peel the surface off a globe and flatten it onto a piece of paper. And just as you cannot plot 3D onto a 2D plane without distortion, you cannot plot n-dimensional distances either.

MDS, however, finds the 2-dimensional plot co-ordinates which least distort the higher dimensional distance table.

library(ggplot2)
cmd = cmdscale(d)
cmd.frame = data.frame(cars = rownames(mtcars), x = cmd[, 1], y = cmd[, 
    2])
# there is a lot of overplotting of labels unfortunately
qplot(x = x, y = y, data = cmd.frame, label = cars, col = I("red"), 
    size = I(4)) + geom_text(vjust = 1, col = "black")

plot of chunk cmdscale

People sometimes ask: what are the x and y axes of the MDS plot, and how should we label them? In truth it does not matter - the axes are on an arbitrary scale, and only the relative distances between the points are meaningful.
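
If you do want a sense of how much the 2-dimensional representation distorts the original distances, cmdscale can also return goodness-of-fit values. A sketch:

# a sketch: with eig = TRUE cmdscale returns goodness-of-fit values, a rough
# measure of how faithfully 2 dimensions reproduce the distance table
cmd.eig = cmdscale(d, k = 2, eig = TRUE)
cmd.eig$GOF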

Principal Component Analysis

Like MDS, principal component analysis (PCA) can be used to reduce high dimensional data for plotting, but its purpose is slightly different. Rather than minimise distortion of the original distance relations, PCA iteratively finds the axis through the data with the highest variance, then the next axis perpendicular to it (NB it can rotate fully about the first axis whilst remaining perpendicular) with the highest variance, and so on. The reasoning behind this is twofold. The axes, or principal components, can be viewed as composite variables which capture the information of the other variables and so reduce the dimensionality of the data. Alternatively you can use the principal components (particularly the first two) as axes for plotting, because the data will be maximally dispersed along them.

# again note that PCA is sensitive to the scaling of the data
pca = prcomp(mtcars, scale = T)
print(summary(pca), digits = 2)
## Importance of components:
##                        PC1  PC2   PC3   PC4  PC5   PC6   PC7   PC8   PC9
## Standard deviation     2.6 1.63 0.792 0.519 0.47 0.460 0.368 0.351 0.278
## Proportion of Variance 0.6 0.24 0.057 0.025 0.02 0.019 0.012 0.011 0.007
## Cumulative Proportion  0.6 0.84 0.899 0.923 0.94 0.963 0.975 0.986 0.993
##                          PC10  PC11
## Standard deviation     0.2281 0.148
## Proportion of Variance 0.0047 0.002
## Cumulative Proportion  0.9980 1.000

The summary of the PCA here shows that the first 2 principal components capture 84% of the variance. We chose to scale the data because the variables differ widely in range and we want them to contribute equally to the PCA. If we now use the top two components as plotting axes we will capture as much variability as possible.
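
A scree plot is a common way to visualise how the variance falls away across the components; a quick sketch:

# a sketch: a scree plot of the variance explained by each component
screeplot(pca, type = "lines", main = "mtcars PCA")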

biplot(pca)

plot of chunk biplot

# plot on next page

The biplot goes a step further: not only does it plot the samples as points along the top two principal component axes, it also plots the variables as vectors. Simply put, if two variable vectors point in similar directions (the angle between them is small) then those variables are highly correlated.
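
If the arrows are hard to read you can inspect the loadings directly; a sketch:

# a sketch: the variable loadings on the first two components, which
# correspond to the arrows drawn by biplot()
round(pca$rotation[, 1:2], 2)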

Supervised Learning

In supervised learning we use example or training data of known type to build a model, which we then use to predict the type of new data. Usually these are categorical predictions, such as whether an email is spam or not, whether a patient’s gene expression pattern denotes a good or poor outcome, or whether a bank account is possibly fraudulent. This is also often termed classification.

Actually we have already covered the basic R procedure for this earlier, when we used logistic regression via the glm function to model the Titanic survivor data and then the predict function to calculate the survival probability of a hypothetical passenger. For larger datasets, however, different fitting models are sometimes appropriate and we usually estimate the classification error directly using held-back or test data, because the potential for overfitting increases as the number of predictor variables grows. Commonly, if data is scarce, we may estimate error by cross validation: iteratively splitting our data into randomly sampled training and test sets and estimating the classification error and its confidence limits from these multiple folds of the data. Finally, if there are a lot of variables we may also try to remove some from our model using feature selection.

That was a lot of verbiage to take in - hopefully the following examples will make it clearer. Unfortunately, whilst the series of steps described above is common to many forms of machine learning model, R - in contrast to its unified approach for, say, linear models (e.g. lm and glm) - does not have a consistent framework for the different types of machine learning model. Rather they are split across many different packages (e.g. e1071, class, rpart, randomForest), often with different arguments, data input formats, and model and predict class outputs. However there are two packages that try to integrate the varied models under a common framework and make comparison of different models consistent: ipred and MLInterfaces.

ipred

First we will prepare the NCI cancer microarray dataset (nci) from the ElemStatLearn package that accompanies the classic machine learning text The Elements of Statistical Learning. Each column is a tumour sample and each row a gene expression level (the genes are not labelled). Our aim is to create a model which predicts (classifies) tumour type from the many rows of gene expression predictors.

rm(cmd, cmd.frame, mtcars.sc, cophenetic.d, d, hc, pca)
library(ipred)
## Loading required package: rpart
## Loading required package: MASS
## Loading required package: mlbench
## Loading required package: survival
## Loading required package: splines
## Loading required package: nnet
## Loading required package: class
## Loading required package: prodlim
## Loading required package: KernSmooth
## KernSmooth 2.23 loaded Copyright M. P. Wand 1997-2009
library(ElemStatLearn)
nci = nci
# We are going to remove some predictors in a prefiltering. Strictly
# speaking this is a BAD THING. For the purposes of a tutorial however it
# makes sense to work with a smaller dataset. Note how the apply function
# calculates sd for each row without an explicit loop.
gene.sd = apply(nci, 1, sd)
# pick the genes with the highest variance. NB the filtering should not
# relate to the response or else the error estimate will be biased.
high.varying.genes = order(gene.sd, decreasing = T)[1:400]
nci = nci[high.varying.genes, ]
cl = as.factor(colnames(nci))
nci = t(nci)
nci = data.frame(cl = cl, nci)
## Warning: some row.names duplicated:
## 2,3,6,7,8,10,11,12,13,14,15,16,17,18,19,20,25,26,27,28,29,30,31,32,33,37,38,39,40,41,43,44,45,46,47,48,50,52,53,54,55,56,57,58,59,60,61,62,63,64
## --> row.names NOT used
head(nci[, 1:5])
##       cl     X1     X2     X3     X4
## 1    CNS  5.770  5.820  5.480 -1.910
## 2    CNS  5.040  5.070  5.730 -1.980
## 3    CNS -1.440 -1.030  4.630 -0.580
## 4  RENAL -2.420 -2.420  1.300 -1.880
## 5 BREAST -2.955 -2.955 -1.055 -1.235
## 6    CNS  0.000 -3.180  0.590 -3.050

We have transposed the data so that each sample or tumour is a row (record) rather than a column, and made a response column cl from what were the column names. We will try a k-nearest neighbours (knn) model using the ipredknn function. Very simply, for each test sample it calculates the euclidean distance to all of the training samples; when \( k=1 \) it assigns the test sample the class of the closest training sample, and when \( k > 1 \) it either predicts the majority vote of the k nearest neighbours or, if tied, predicts the test sample unknown. The knn is a good all-round classifier with simple parameters that makes no prior assumptions about the distribution of your data.

# there are 64 rows of data, that is 64 cancer samples
dim(nci)
## [1]  64 401
# we randomly pick two thirds of those rows
s = sample(1:64, 48)
# we train the model on 2/3rds of the data.. this may take a moment
knn.model = ipredknn(cl ~ ., data = nci[s, ], k = 1)
#  we make prediction on the remaining test data
knn.pred = predict.ipredknn(knn.model, newdata = nci[-s, ], type = "class")
# we compare the prediction to the true class
data.frame(knn.pred = knn.pred, cl = cl[-s])
##    knn.pred       cl
## 1       CNS      CNS
## 2       CNS      CNS
## 3     RENAL   BREAST
## 4       CNS      CNS
## 5       CNS      CNS
## 6     RENAL    RENAL
## 7  PROSTATE   BREAST
## 8  MELANOMA MELANOMA
## 9     COLON  OVARIAN
## 10 LEUKEMIA LEUKEMIA
## 11 LEUKEMIA LEUKEMIA
## 12 LEUKEMIA LEUKEMIA
## 13    COLON    COLON
## 14    COLON    COLON
## 15 MELANOMA MELANOMA
## 16 MELANOMA MELANOMA
# the correct prediction rate is calculated
mean(knn.pred == cl[-s])
## [1] 0.8125

Note that the form here follows that of linear models: we specify a response variable via the model formula (cl ~ . essentially says model cl on all other variables .), and then use the predict function to test on newdata. We estimate a correct prediction rate of 0.8125 (NB this will vary with the random training/test split). Estimating the error rate from a single split like this is highly variable; we would do better to repeat it many times using cross validation, for which we can use the errorest function.

knn.mymodel = function(formula, data, k = 1) ipredknn(formula, data, 
    k = k)
knn.mypred = function(object, newdata) predict.ipredknn(object, newdata, 
    type = "class")
# this takes about 20 secs on my computer as the function loops through
# training and testing many models on a fairly large dataset. This is one
# reason R is sometimes eschewed for machine learning problems.
knn.cv = errorest(cl ~ ., data = nci, model = knn.mymodel, predict = knn.mypred)
knn.cv
## 
## Call:
## errorest.data.frame(formula = cl ~ ., data = nci, model = knn.mymodel, 
##     predict = knn.mypred)
## 
##   10-fold cross-validation estimator of misclassification error 
## 
## Misclassification error:  0.3438 
## 

You may notice this takes a few moments to run as we are in a loop of building and testing many models (10-fold cross validation by default). This code may be a little unclear, so let's discuss how it works. Earlier I said that the inputs and outputs of different R machine learning models were inconsistent. The first two lines merely specify the model and the prediction method in a consistent way: I specify the parameters I want in my knn model, then I specify that the predict function will return class labels only. All prediction functions must return only class labels in order for errorest to work. So if you have a machine learning method that returns some special class object, you need to write a wrapper function that extracts just the classification result to pass to errorest. Finally I call errorest with arguments for the data, the model function, and the predict function. Note that the misclassification error of 0.3438 reported here is 1 minus the correct prediction rate; it is noticeably higher than our single point estimate above (1 - 0.8125 = 0.1875), which shows how variable a single random split can be and why the cross-validated estimate is preferable.
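
Because the wrappers are ordinary functions, comparing parameter settings is easy. For example, a hypothetical variant of the same wrapper with k = 3 (the names knn3.mymodel and knn3.cv are my own) can be cross validated in exactly the same way:

# a sketch: the same wrapper with k = 3 neighbours, cross validated identically
knn3.mymodel = function(formula, data, k = 3) ipredknn(formula, data, k = k)
knn3.cv = errorest(cl ~ ., data = nci, model = knn3.mymodel, predict = knn.mypred)
knn3.cv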

Using this approach we can consistently compare models. This time we simply use the default parameters of the linear discriminant analysis (lda) function, so we have no need to specify our own model wrapper. We still have to create a prediction function to extract the class values from the standard predict output.

# note just the class labels are extracted from the lda model object
lda.mypred = function(object, newdata) predict(object, newdata = newdata)$class
lda.cv = errorest(cl ~ ., data = nci, model = lda, predict = lda.mypred)
lda.cv
## 
## Call:
## errorest.data.frame(formula = cl ~ ., data = nci, model = lda, 
##     predict = lda.mypred)
## 
##   10-fold cross-validation estimator of misclassification error 
## 
## Misclassification error:  0.4688 
## 

The linear discriminant analysis does less well on this data - or at least not with the default parameters.
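
Incidentally, errorest is not restricted to cross validation; it can also estimate error by bootstrap resampling via its estimator argument. A sketch re-using the knn wrappers from above:

# a sketch: the same knn wrappers assessed by bootstrap resampling instead
# of 10-fold cross validation
knn.boot = errorest(cl ~ ., data = nci, model = knn.mymodel, 
    predict = knn.mypred, estimator = "boot")
knn.boot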

MLInterfaces

With feature selection, the number of predictor variables used in each round of model training is reduced by removing those unrelated to the response. If the data is being cross validated then the feature selection must be done anew in each fold, or the error will be underestimated: if you selected features prior to cross validation you would be building an important part of the model on the whole dataset, including the test folds. You can perform feature selection with ipred by writing a few lines of code to make your own feature-selecting variant of any model you like, but the MLInterfaces package handles much of this for you. First we will prepare another, simpler example from the leukaemia gene expression data in the ALL package.

library(ALL)
## Loading required package: Biobase
## Loading required package: BiocGenerics
## Attaching package: 'BiocGenerics'
## The following object(s) are masked from 'package:gdata':
## 
## combine
## The following object(s) are masked from 'package:stats':
## 
## xtabs
## The following object(s) are masked from 'package:base':
## 
## Filter, Find, Map, Position, Reduce, anyDuplicated, cbind, colnames,
## duplicated, eval, get, intersect, lapply, mapply, mget, order, paste,
## pmax, pmax.int, pmin, pmin.int, rbind, rep.int, rownames, sapply, setdiff,
## table, tapply, union, unique
## Welcome to Bioconductor
## 
## Vignettes contain introductory material; view with 'browseVignettes()'. To
## cite Bioconductor, see 'citation("Biobase")', and for packages
## 'citation("pkgname")'.
data(ALL)
# first some data preparation so that we are only looking at BCR/ABL and
# NEG control samples
mol.biol = as.vector(pData(ALL)$mol.biol)
ALL = ALL[, which(mol.biol == "BCR/ABL" | mol.biol == "NEG")]
mol.biol = mol.biol[which(mol.biol == "BCR/ABL" | mol.biol == "NEG")]
# we use the median absolute deviation (mad) to pre-filter the data this time
mads = apply(exprs(ALL), 1, mad)
ALL = ALL[order(mads, decreasing = TRUE)[1:800], ]
ALL.df = data.frame(mol.biol = mol.biol, t(exprs(ALL)))

The xvalSpec function controls the cross validation. Here we specify one scheme with 5 folds, each leaving out a balanced 1/5th of the data as the test set, and a second scheme which is the same but additionally selects, within each fold, the 50 features with the highest absolute t statistic. That is, one cross validation on all the pre-filtered predictors, and one with per-fold feature selection.

# xvalSpec, MLearn and confuMat come from the MLInterfaces package
library(MLInterfaces)
xval.procedure1 = xvalSpec(type = "LOG", niter = 5, partitionFunc = balKfold.xvspec(5))
xval.procedure2 = xvalSpec(type = "LOG", niter = 5, partitionFunc = balKfold.xvspec(5), 
    fsFun = fs.absT(50))

The MLearn function trains the model, with cross validation and/or feature selection if these arguments are selected.

knn.mod1 = MLearn(mol.biol ~ ., data = ALL.df, .method = knnI(k = 5), 
    trainInd = xval.procedure1)
knn.mod2 = MLearn(mol.biol ~ ., data = ALL.df, .method = knnI(k = 5), 
    trainInd = xval.procedure2)
confuMat(knn.mod1)
confuMat(knn.mod2)

The confuMat function produces a confusion matrix showing the breakdown of errors by class; here the error is very slightly better when we use feature selection… for this data and this model at least.
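
The overall error rate can always be recovered from a confusion matrix as 1 minus the proportion of counts on the diagonal. A sketch, assuming knn.mod1 was fitted successfully above:

# a sketch: overall cross-validated error from the confusion matrix
cm1 = confuMat(knn.mod1)
1 - sum(diag(cm1)) / sum(cm1)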

Tree Methods

Lastly in this (another lengthy) tutorial we will take a quick look at the popular tree based methods. They have their own graphical representations, they handle mixed data consistently (i.e. categorical, ordinal and numeric features), and they implicitly perform their own form of feature selection. We will use another dataset from the ElemStatLearn package: SAheart, a retrospective sample of patients from a region at high risk of heart disease. Our response variable to be predicted is whether or not a patient has heart disease.

The rpart algorithm (recursive partitioning), somewhat like stepwise multiple regression, selects the single predictor variable which best splits (or partitions) the data into the two response groups: heart disease or no heart disease. In its simplest form the process is then repeated for each of the subgroups, and so on, until the subgroups either reach a minimum size or no further improvement can be made.

The best split is usually defined by

the Gini index: \( 1 - p^2_A - p^2_B \)

where \( p_A \) is the proportion of class A in the node and \( p_B \) the proportion of class B. The more homogeneous the split, the smaller the Gini index. We’ll show an example of this splitting in a moment.
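
As a quick numerical check, the root node of the tree fitted below has class proportions of roughly 0.654 and 0.346, giving a Gini index of about 0.45:

# a sketch: Gini index for the root node class proportions reported by
# rpart below (0.654 without heart disease, 0.346 with)
p = c(0.654, 0.346)
1 - sum(p^2)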

library(rpart)
library(ElemStatLearn)
SAheart = SAheart
tree.mod1 = rpart(chd ~ ., data = SAheart, method = "class", cp = 0.01)
plot(tree.mod1, margin = 0.05)
text(tree.mod1, use.n = TRUE)

plot of chunk rpart

print(tree.mod1)
## n= 462 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
##  1) root 462 160 0 (0.65368 0.34632)  
##    2) age< 50.5 290  64 0 (0.77931 0.22069)  
##      4) age< 30.5 108   8 0 (0.92593 0.07407) *
##      5) age>=30.5 182  56 0 (0.69231 0.30769)  
##       10) typea< 68.5 170  46 0 (0.72941 0.27059) *
##       11) typea>=68.5 12   2 1 (0.16667 0.83333) *
##    3) age>=50.5 172  76 1 (0.44186 0.55814)  
##      6) famhist=Absent 82  33 0 (0.59756 0.40244)  
##       12) tobacco< 7.605 58  16 0 (0.72414 0.27586) *
##       13) tobacco>=7.605 24   7 1 (0.29167 0.70833) *
##      7) famhist=Present 90  27 1 (0.30000 0.70000)  
##       14) ldl< 4.99 39  18 1 (0.46154 0.53846)  
##         28) adiposity>=27.98 20   7 0 (0.65000 0.35000)  
##           56) tobacco< 4.15 10   1 0 (0.90000 0.10000) *
##           57) tobacco>=4.15 10   4 1 (0.40000 0.60000) *
##         29) adiposity< 27.98 19   5 1 (0.26316 0.73684) *
##       15) ldl>=4.99 51   9 1 (0.17647 0.82353) *

The plot function automatically recognises the rpart fit object and makes its own tree diagram. The text function is then used to add the variable split thresholds to the branches and the response mixture at the terminal nodes. We use the print command to look at the response mixture at every branch. The default splitting criterion for rpart is the Gini index, and the first split above is on age < 50.5. The resulting node contains 290 patients, of whom 226 are without heart disease (code = 0) and 64 have heart disease (code = 1), i.e. class proportions of 0.779 and 0.221.

This corresponds to a Gini index of \( 1 - p^2_A - p^2_B = 1 - 0.779^2 - 0.221^2 = 0.344 \).

In this case the tree has continued splitting until a stopping criterion, cp = 0.01 (the complexity parameter), is met. The idea is to prevent overfitting - subdividing the data into smaller and smaller splits that probably model the particularities of this sample rather than the population in general. Rather than arbitrarily picking a value, we can introduce cross validation again to help us select the right complexity parameter (cp). First we build a model with cp = 0, i.e. no stopping criterion:

tree.mod2 = rpart(chd ~ ., data = SAheart, method = "class", cp = 0, 
    xval = 10)
plot(tree.mod2)

plot of chunk rpartXval

# This tree is over-fit but the complexity parameter for each split will
# have been calculated.
printcp(tree.mod2)
## 
## Classification tree:
## rpart(formula = chd ~ ., data = SAheart, method = "class", cp = 0, 
##     xval = 10)
## 
## Variables actually used in tree construction:
## [1] adiposity age       alcohol   famhist   ldl       obesity   tobacco  
## [8] typea    
## 
## Root node error: 160/462 = 0.35
## 
## n= 462 
## 
##        CP nsplit rel error xerror  xstd
## 1  0.1250      0      1.00   1.00 0.064
## 2  0.1000      1      0.88   0.96 0.063
## 3  0.0625      2      0.78   0.84 0.061
## 4  0.0250      3      0.71   0.76 0.059
## 5  0.0188      5      0.66   0.79 0.060
## 6  0.0125      7      0.62   0.83 0.061
## 7  0.0094      8      0.61   0.86 0.061
## 8  0.0063     16      0.53   0.94 0.063
## 9  0.0031     19      0.51   0.97 0.063
## 10 0.0000     21      0.51   0.99 0.064
# From the table the tree model with the best cross validation error
# (xerror) has 4 end nodes and 3 splits. We can now prune() this tree back
# to this point by setting a stopping value cp=0.02.
pruned.tree.mod2 = prune(tree.mod2, cp = 0.02)
plot(pruned.tree.mod2, margin = 0.05)
text(pruned.tree.mod2, use.n = TRUE)

plot of chunk rpartXval

This is probably the best generalisation we can build from this data.
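
As a final sketch, the pruned tree can be used with predict in the usual way; note that the resubstitution accuracy computed here is optimistic compared with the cross validated xerror in the table above:

# a sketch: class predictions from the pruned tree and its (optimistic)
# resubstitution accuracy on the same data it was fitted to
pruned.pred = predict(pruned.tree.mod2, newdata = SAheart, type = "class")
mean(as.character(pruned.pred) == as.character(SAheart$chd))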

Exercises 6A

  1. Find the commands to extract the a) residuals b) coefficients from the fit
  2. What is the equation of the fit given those coefficients?