EPI 288 Lecture 5: Classification and Regression Trees (CART)

References

The specification of a loss matrix for the rpart package is poorly documented in its help pages; see the link in the "Specification of a loss matrix" section below. Loading mvpart appears to transpose the loss matrix (columns for reality and rows for (mis)classification); the node losses printed in the MI example below are consistent with that orientation.

Load chest pain dataset

library(sas7bdat)
cpain <- read.sas7bdat("./cpdata06.sas7bdat")
names(cpain) <- tolower(names(cpain))

Load mvpart package

library(mvpart)

The mvpart package is an extension to the rpart package and is dependent on rpart.

Chest pain dataset: Prediction of admission

## Variable manipulation
cpain <- within(cpain, {
    admit <- factor(admit, 0:1, c("Not admitted", "Admitted"))
    mi    <- factor(mi,    0:1, c("No", "MI"))
})
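The order of the factor levels matters later, because the rows and columns of a loss matrix follow that order. The following is a minimal sketch on a made-up vector (`mi_raw` is hypothetical, not part of the chest pain data) showing how `levels` and `labels` fix the order:

```r
## Hypothetical toy vector standing in for a 0/1 outcome column
mi_raw <- c(0, 1, 1, 0, 0)

## levels = 0:1 fixes the order; labels rename those levels in the same order
mi_fac <- factor(mi_raw, levels = 0:1, labels = c("No", "MI"))

levels(mi_fac)      # level 1 is "No", level 2 is "MI"
as.integer(mi_fac)  # underlying integer codes follow the level order
table(mi_fac)
```

With this coding, "level 1" in a loss matrix refers to "No" and "level 2" to "MI".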

## Perform rpart
rpart.cpain <- rpart(
    formula = admit ~ systol + diastol + qual1 + qual2 + qual3 + qual4 + qual5 + qual6 + qual7 + asso1 + asso2 + asso3 + asso4 + asso5 + past1 + past2 + htn + diab + hichol + famhx + smok_cur + smok_pst + smok_nev + precath + preptca + precabg + repro1 + repro2 + repro3 + age + radp + funnyp + should + remiowtb + begin + lengthp + stel + stdep + charlson + bcbs + hmo + medicare + medicaid + selfpay + white + black + hispan + male + begin12 + begin24 + begin48 + lengthp120 + lengthp60 + lengthp30 + lengthp15,

    data       = cpain,
    ## weights = ,          # optional case weights. Not yet supported in current version
    na.action  = na.rpart,  # By default, deleted if outcome is missing, kept if predictors are missing
    method     = "class",   # Classification for factor
    model      = FALSE,
    x          = FALSE,
    y          = TRUE,

    ## optional parameters for the splitting function
    parms = list(
        ## prior = c(0.5, 0.5), # prior probabilities. Defaults to observed data frequencies.
        loss = matrix(c(0,1,1,0), ncol = 2), # loss matrix. Penalize false positive or negative more heavily
        split = "gini"      # gini or information
    ),

    ## rpart algorithm options (These are defaults, thus the whole control argument can be omitted)
    control = rpart.control(
        minsplit       = 20,   # minimum number of observations required before split
        minbucket      = 20/3, # minimum number of observations in any terminal node. default = round(minsplit/3)
        cp             = 0.01, # complexity parameter used as the stopping rule
        maxcompete     = 4,    # number of competitor splits retained in the output
        maxsurrogate   = 5,    # number of surrogate splits retained in the output
        usesurrogate   = 2,    # how to use surrogates in the splitting process
        xval           = 10,   # number of cross-validations
        surrogatestyle = 0,    # controls the selection of a best surrogate
        maxdepth       = 30)   # maximum depth of any node of the final tree
    ##,
    ## cost = c() # a vector of cost for each variable
    )

Optional parameters that can be given to parms

   parms:
          For classification splitting, the list can contain any of:
          the vector of prior probabilities (component ‘prior’), the
          loss matrix (component ‘loss’) or the splitting index
          (component ‘split’).  The priors must be positive and sum to
          1.  The loss matrix must have zeros on the diagonal and
          positive off-diagonal elements.  The splitting index can be
          ‘gini’ or ‘information’.  The default priors are proportional
          to the data counts, the losses default to 1, and the split
          defaults to ‘gini’.
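To make the two splitting indices concrete, here is a rough sketch of the textbook Gini and information (entropy) impurities for a vector of class proportions. The function names are made up for illustration, and rpart's internal computation may differ in scaling and in how priors and losses enter, so treat this as a conceptual aid only:

```r
## Impurity of a node given its vector of class proportions p
gini_impurity <- function(p) {
    1 - sum(p^2)
}

information_impurity <- function(p) {
    p <- p[p > 0]          # treat 0 * log(0) as 0
    -sum(p * log(p))
}

## Class proportions at the root of the admission tree (from the printout below)
p_root <- c(0.385319, 0.614681)
gini_impurity(p_root)
information_impurity(p_root)

## A pure node has zero impurity under both indices
gini_impurity(c(1, 0))
information_impurity(c(1, 0))
```

Both indices are largest when the classes are evenly mixed and fall to zero for a pure node, which is why a split that separates the classes well reduces them.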

Specification of a loss matrix

See the link below for loss matrix specification. It is poorly documented in the help.

http://www.louisaslett.com/Courses/Data_Mining_09-10/ST4003-Lab4-New_Tree_Data_Set_and_Loss_Matrices.pdf

In rpart, the rows represent reality and the columns represent the (mis)classification. As noted under References, loading mvpart may reverse this direction.

## level 1 (Not admit) to level 2 (Admit) misclassification 100 times more serious
matrix(c(0, 1, 100, 0), ncol = 2)
     [,1] [,2]
[1,]    0  100
[2,]    1    0

## level 2 (Admit) to level 1 (Not admit) misclassification 100 times more serious
matrix(c(0, 100, 1, 0), ncol = 2)
     [,1] [,2]
[1,]    0    1
[2,]  100    0
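The effect of a loss matrix on the class assigned to a node can be sketched with a small expected-loss calculation. This assumes the rows = reality, columns = classification convention stated above; the node proportions `p` are hypothetical:

```r
## Loss matrix under the rows = reality, columns = classification convention
loss <- matrix(c(0, 1, 100, 0), ncol = 2,
               dimnames = list(reality    = c("Not admitted", "Admitted"),
                               classified = c("Not admitted", "Admitted")))

## Hypothetical class proportions in one node
p <- c("Not admitted" = 0.3, "Admitted" = 0.7)

## Expected loss of assigning each class to the whole node:
## sum over true classes of P(true class) * loss[true class, assigned class]
expected_loss <- drop(p %*% loss)
expected_loss

## The class with the smallest expected loss is assigned to the node
names(which.min(expected_loss))
```

Although 70% of this hypothetical node is "Admitted", the heavy penalty for calling a "Not admitted" patient "Admitted" makes "Not admitted" the cheaper label.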

See results

rpart.cpain
n= 4373 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

 1) root 4373 1685 Admitted (0.38531900 0.61468100)  
   2) remiowtb< 0.5 3209 1601 Admitted (0.49890932 0.50109068)  
     4) stdep< 0.5 2751 1176 Not admitted (0.57251908 0.42748092)  
       8) qual1< 0.5 1453  427 Not admitted (0.70612526 0.29387474)  
        16) age< 55.5 864  170 Not admitted (0.80324074 0.19675926) *
        17) age>=55.5 589  257 Not admitted (0.56366723 0.43633277)  
          34) charlson< 0.5 268   85 Not admitted (0.68283582 0.31716418) *
          35) charlson>=0.5 321  149 Admitted (0.46417445 0.53582555)  
            70) lengthp>=475 86   27 Not admitted (0.68604651 0.31395349) *
            71) lengthp< 475 235   90 Admitted (0.38297872 0.61702128) *
       9) qual1>=0.5 1298  549 Admitted (0.42295840 0.57704160)  
        18) age< 49.5 557  241 Not admitted (0.56732496 0.43267504)  
          36) male< 0.5 312   99 Not admitted (0.68269231 0.31730769) *
          37) male>=0.5 245  103 Admitted (0.42040816 0.57959184) *
        19) age>=49.5 741  233 Admitted (0.31443995 0.68556005) *
     5) stdep>=0.5 458   26 Admitted (0.05676856 0.94323144) *
   3) remiowtb>=0.5 1164   84 Admitted (0.07216495 0.92783505) *

## display cp table
printcp(rpart.cpain)

Classification tree:
rpart(formula = admit ~ systol + diastol + qual1 + qual2 + qual3 + 
    qual4 + qual5 + qual6 + qual7 + asso1 + asso2 + asso3 + asso4 + 
    asso5 + past1 + past2 + htn + diab + hichol + famhx + smok_cur + 
    smok_pst + smok_nev + precath + preptca + precabg + repro1 + 
    repro2 + repro3 + age + radp + funnyp + should + remiowtb + 
    begin + lengthp + stel + stdep + charlson + bcbs + hmo + 
    medicare + medicaid + selfpay + white + black + hispan + 
    male + begin12 + begin24 + begin48 + lengthp120 + lengthp60 + 
    lengthp30 + lengthp15, data = cpain, na.action = na.rpart, 
    method = "class", model = FALSE, x = FALSE, y = TRUE, parms = list(loss = matrix(c(0, 
        1, 1, 0), ncol = 2), split = "gini"), control = rpart.control(minsplit = 20, 
        minbucket = 20/3, cp = 0.01, maxcompete = 4, maxsurrogate = 5, 
        usesurrogate = 2, xval = 10, surrogatestyle = 0, maxdepth = 30))

Variables actually used in tree construction:
[1] age      charlson lengthp  male     qual1    remiowtb stdep   

Root node error: 1685/4373 = 0.38532

n= 4373 

        CP nsplit rel error  xerror     xstd
1 0.118398      0   1.00000 1.00000 0.019100
2 0.044510      3   0.64451 0.64748 0.016982
3 0.023145      4   0.60000 0.61187 0.016659
4 0.010880      5   0.57685 0.59644 0.016511
5 0.010000      8   0.54421 0.58457 0.016395
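The cp table can be used to pick a pruning level. The sketch below retypes the table printed above into a data frame (with a fitted object, the same table should be available as `rpart.cpain$cptable`), converts the relative cross-validated error into an absolute error rate, and applies two common selection rules. The one-standard-error rule is a widely used heuristic and is not something this lecture prescribes:

```r
## cp table copied from the printcp() output above
cptab <- data.frame(
    CP        = c(0.118398, 0.044510, 0.023145, 0.010880, 0.010000),
    nsplit    = c(0, 3, 4, 5, 8),
    rel.error = c(1.00000, 0.64451, 0.60000, 0.57685, 0.54421),
    xerror    = c(1.00000, 0.64748, 0.61187, 0.59644, 0.58457),
    xstd      = c(0.019100, 0.016982, 0.016659, 0.016511, 0.016395)
)
root_error <- 1685 / 4373

## Errors in the table are relative to the root node error;
## multiply by it to get absolute cross-validated error rates
cptab$abs.xerror <- cptab$xerror * root_error

## Rule 1: row with the smallest cross-validated error
i_min <- which.min(cptab$xerror)

## Rule 2 (one-standard-error heuristic): the smallest tree whose xerror
## is within one xstd of the minimum
threshold <- cptab$xerror[i_min] + cptab$xstd[i_min]
i_1se <- min(which(cptab$xerror <= threshold))

cptab[c(i_min, i_1se), ]
```

The `CP` value of the chosen row can then be passed to `prune()`.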

## plot cross-validation results
plotcp(rpart.cpain)


## plot approximate R-squared and relative error for different splits
layout(matrix(1:2, ncol = 2))
rsq.rpart(rpart.cpain)

Classification tree:
rpart(formula = admit ~ systol + diastol + qual1 + qual2 + qual3 + 
    qual4 + qual5 + qual6 + qual7 + asso1 + asso2 + asso3 + asso4 + 
    asso5 + past1 + past2 + htn + diab + hichol + famhx + smok_cur + 
    smok_pst + smok_nev + precath + preptca + precabg + repro1 + 
    repro2 + repro3 + age + radp + funnyp + should + remiowtb + 
    begin + lengthp + stel + stdep + charlson + bcbs + hmo + 
    medicare + medicaid + selfpay + white + black + hispan + 
    male + begin12 + begin24 + begin48 + lengthp120 + lengthp60 + 
    lengthp30 + lengthp15, data = cpain, na.action = na.rpart, 
    method = "class", model = FALSE, x = FALSE, y = TRUE, parms = list(loss = matrix(c(0, 
        1, 1, 0), ncol = 2), split = "gini"), control = rpart.control(minsplit = 20, 
        minbucket = 20/3, cp = 0.01, maxcompete = 4, maxsurrogate = 5, 
        usesurrogate = 2, xval = 10, surrogatestyle = 0, maxdepth = 30))

Variables actually used in tree construction:
[1] age      charlson lengthp  male     qual1    remiowtb stdep   

Root node error: 1685/4373 = 0.38532

n= 4373 

        CP nsplit rel error  xerror     xstd
1 0.118398      0   1.00000 1.00000 0.019100
2 0.044510      3   0.64451 0.64748 0.016982
3 0.023145      4   0.60000 0.61187 0.016659
4 0.010880      5   0.57685 0.59644 0.016511
5 0.010000      8   0.54421 0.58457 0.016395
May not be applicable for this method

layout(1)

## detailed results including surrogate splits
summary(rpart.cpain,
        cp = 0.15) # trim nodes with a complexity of less than ‘cp’
Call:
rpart(formula = admit ~ systol + diastol + qual1 + qual2 + qual3 + 
    qual4 + qual5 + qual6 + qual7 + asso1 + asso2 + asso3 + asso4 + 
    asso5 + past1 + past2 + htn + diab + hichol + famhx + smok_cur + 
    smok_pst + smok_nev + precath + preptca + precabg + repro1 + 
    repro2 + repro3 + age + radp + funnyp + should + remiowtb + 
    begin + lengthp + stel + stdep + charlson + bcbs + hmo + 
    medicare + medicaid + selfpay + white + black + hispan + 
    male + begin12 + begin24 + begin48 + lengthp120 + lengthp60 + 
    lengthp30 + lengthp15, data = cpain, na.action = na.rpart, 
    method = "class", model = FALSE, x = FALSE, y = TRUE, parms = list(loss = matrix(c(0, 
        1, 1, 0), ncol = 2), split = "gini"), control = rpart.control(minsplit = 20, 
        minbucket = 20/3, cp = 0.01, maxcompete = 4, maxsurrogate = 5, 
        usesurrogate = 2, xval = 10, surrogatestyle = 0, maxdepth = 30))
  n= 4373 

          CP nsplit rel error    xerror       xstd
1 0.11839763      0 1.0000000 1.0000000 0.01909963
2 0.04451039      3 0.6445104 0.6474777 0.01698212
3 0.02314540      4 0.6000000 0.6118694 0.01665877
4 0.01088032      5 0.5768546 0.5964392 0.01651123
5 0.01000000      8 0.5442136 0.5845697 0.01639458
Error: incorrect number of dimensions

See results as a tree

## Define a plotting function with decent defaults
plot.rpart.obj <- function(rpart.obj, font.size = 0.8) {
    ## plot decision tree
    plot(rpart.obj,
         uniform   = TRUE,  # if TRUE, uniform vertical spacing of the nodes is used
         branch    = 1,     # controls the shape of the branches from parent to child node
         compress  = FALSE, # if FALSE, the leaf nodes are placed at horizontal coordinates 1:nleaves
         nspace    = 0.1,   # extra space between a node with children and a leaf; applies to compressed trees only
         margin    = 0.1,   # an extra fraction of white space to leave around the borders
         minbranch = 0.3)   # set the minimum length for a branch

    ## Add text
    text(x      = rpart.obj,   # the fitted rpart object
         splits = TRUE,        # if TRUE, splits are labeled with the split criterion
         all    = TRUE,        # if TRUE, all nodes are labeled, otherwise just terminal nodes
         use.n  = TRUE,        # if TRUE, add the number of observations in each class to the labels
         cex    = font.size)   # font size
}
## Plot
plot.rpart.obj(rpart.cpain, 1)

Prune the tree at different complexity parameter values

## Pruned at complexity parameter of 0.02
plot.rpart.obj(prune(tree = rpart.cpain, cp = 0.02), 1)

## Pruned at complexity parameter of 0.03
plot.rpart.obj(prune(tree = rpart.cpain, cp = 0.03), 1)

## Pruned at complexity parameter of 0.1
plot.rpart.obj(prune(tree = rpart.cpain, cp = 0.1), 1)
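The effect of the cp value on tree size can also be checked numerically rather than by eye. The sketch below is self-contained and makes two assumptions: it uses plain rpart instead of mvpart, and it uses the `kyphosis` example data shipped with rpart instead of the chest pain data. `n_leaves` is a helper defined here for illustration:

```r
library(rpart)  # assumption: plain rpart rather than mvpart

## kyphosis is an example data set distributed with rpart
fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis,
             method = "class", control = rpart.control(cp = 0.001))

## Count terminal nodes (leaves) of an rpart object
n_leaves <- function(tree) sum(tree$frame$var == "<leaf>")

## Larger cp values prune more aggressively, so the leaf count cannot grow
cps    <- c(0.001, 0.02, 0.1, 1)
leaves <- sapply(cps, function(cp) n_leaves(prune(fit, cp = cp)))
data.frame(cp = cps, leaves = leaves)
```

The same helper could be applied to `rpart.cpain` to compare the three pruned trees plotted above.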

MI example

## Create formula
form <- "mi ~ systol + diastol + qual1 + qual2 + qual3 + qual4 + qual5 + qual6 + qual7 + asso1 + asso2 + asso3 + asso4 + asso5 + past1 + past2 + htn + diab + hichol + famhx + smok_cur + smok_pst + smok_nev + precath + preptca + precabg + repro1 + repro2 + repro3 + age + radp + funnyp + should + remiowtb + begin + lengthp + stel + stdep + charlson + bcbs + hmo + medicare + medicaid + selfpay + white + black + hispan + male + begin12 + begin24 + begin48 + lengthp120 + lengthp60 + lengthp30 + lengthp15"

Same weight for both misclassifications

lmat <- matrix(c(0,1,1,0), ncol = 2)
lmat
     [,1] [,2]
[1,]    0    1
[2,]    1    0

## Perform rpart (defaults are omitted)
rpart.cpain <- rpart(
    formula = as.formula(form),
    data       = cpain,

    ## optional parameters for the splitting function
    parms = list(
        loss = lmat    # loss matrix. Penalize false positive or negative more heavily
    )
    )
## display cp table
printcp(rpart.cpain)

Classification tree:
rpart(formula = as.formula(form), data = cpain, parms = list(loss = lmat))

Variables actually used in tree construction:
[1] age     begin   black   lengthp past1   stel   

Root node error: 367/4373 = 0.083924

n= 4373 

        CP nsplit rel error  xerror     xstd
1 0.076294      0   1.00000 1.00000 0.049961
2 0.043597      2   0.84741 0.85559 0.046518
3 0.013624      3   0.80381 0.82289 0.045688
4 0.010899      5   0.77657 0.84196 0.046174
5 0.010000      7   0.75477 0.83651 0.046036
## Plot
plot.rpart.obj(rpart.cpain)

Level 2 to Level 1 (MI to no MI) misclassification is 15 times more serious

lmat <- matrix(c(0,1,15,0), ncol = 2)
lmat
     [,1] [,2]
[1,]    0   15
[2,]    1    0

## Perform rpart (defaults are omitted)
rpart.cpain <- rpart(
    formula = as.formula(form),
    data       = cpain,

    ## optional parameters for the splitting function
    parms = list(
        loss = lmat    # loss matrix. Penalize false positive or negative more heavily
    )
    )

## display cp table
printcp(rpart.cpain)

Classification tree:
rpart(formula = as.formula(form), data = cpain, parms = list(loss = lmat))

Variables actually used in tree construction:
[1] age     hichol  lengthp male    qual4   stdep   stel   

Root node error: 4006/4373 = 0.91608

n= 4373 

        CP nsplit rel error  xerror      xstd
1 0.216800      0   1.00000 1.00000 0.0045771
2 0.020220      2   0.56640 0.56640 0.0371025
3 0.014229      3   0.54618 0.57139 0.0367474
4 0.013729      6   0.50349 0.54843 0.0356652
5 0.010734      7   0.48977 0.53495 0.0347355
6 0.010000      8   0.47903 0.53470 0.0347355
## Plot
plot.rpart.obj(rpart.cpain)

The tree now avoids false negatives, so there are fewer MIs (the number to the right of the "/" in each node label of the plot) among the people classified as No.

Using partykit package

The partykit package gives more readable output by default.

library(partykit)
## Default output
rparty.cpain <- as.party(rpart.cpain)
rparty.cpain

Model formula:
mi ~ systol + diastol + qual1 + qual2 + qual3 + qual4 + qual5 + 
    qual6 + qual7 + asso1 + asso2 + asso3 + asso4 + asso5 + past1 + 
    past2 + htn + diab + hichol + famhx + smok_cur + smok_pst + 
    smok_nev + precath + preptca + precabg + repro1 + repro2 + 
    repro3 + age + radp + funnyp + should + remiowtb + begin + 
    lengthp + stel + stdep + charlson + bcbs + hmo + medicare + 
    medicaid + selfpay + white + black + hispan + male + begin12 + 
    begin24 + begin48 + lengthp120 + lengthp60 + lengthp30 + 
    lengthp15

Fitted party:
[1] root
|   [2] stel < 0.5
|   |   [3] stdep < 0.5
|   |   |   [4] age < 43.5: No (n = 857, err = 1.1%)
|   |   |   [5] age >= 43.5
|   |   |   |   [6] male < 0.5: No (n = 1358, err = 2.2%)
|   |   |   |   [7] male >= 0.5
|   |   |   |   |   [8] hichol < 0.5
|   |   |   |   |   |   [9] qual4 < 0.5: No (n = 664, err = 2.9%)
|   |   |   |   |   |   [10] qual4 >= 0.5: No (n = 101, err = 8.9%)
|   |   |   |   |   [11] hichol >= 0.5
|   |   |   |   |   |   [12] lengthp < 9: No (n = 55, err = 0.0%)
|   |   |   |   |   |   [13] lengthp >= 9: No (n = 302, err = 10.9%)
|   |   [14] stdep >= 0.5
|   |   |   [15] lengthp < 27.5: No (n = 193, err = 3.6%)
|   |   |   [16] lengthp >= 27.5: No (n = 571, err = 20.1%)
|   [17] stel >= 0.5: MI (n = 272, err = 46.7%)

Number of inner nodes:    8
Number of terminal nodes: 9
## Output by rpart
rpart.cpain
n= 4373 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

 1) root 4373 4006 MI (0.91607592 0.08392408)  
   2) stel< 0.5 4101 3330 No (0.94586686 0.05413314)  
     4) stdep< 0.5 3337 1500 No (0.97003296 0.02996704)  
       8) age< 43.5 857  135 No (0.98949825 0.01050175) *
       9) age>=43.5 2480 1365 No (0.96330645 0.03669355)  
        18) male< 0.5 1358  450 No (0.97790869 0.02209131) *
        19) male>=0.5 1122  915 No (0.94563280 0.05436720)  
          38) hichol< 0.5 765  420 No (0.96339869 0.03660131)  
            76) qual4< 0.5 664  285 No (0.97138554 0.02861446) *
            77) qual4>=0.5 101   92 MI (0.91089109 0.08910891) *
          39) hichol>=0.5 357  324 MI (0.90756303 0.09243697)  
            78) lengthp< 9 55    0 No (1.00000000 0.00000000) *
            79) lengthp>=9 302  269 MI (0.89072848 0.10927152) *
     5) stdep>=0.5 764  642 MI (0.84031414 0.15968586)  
      10) lengthp< 27.5 193  105 No (0.96373057 0.03626943) *
      11) lengthp>=27.5 571  456 MI (0.79859895 0.20140105) *
   3) stel>=0.5 272  127 MI (0.46691176 0.53308824) *
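The `loss` column in the rpart printout above can be reproduced by hand, which also shows which direction of the loss matrix was actually applied in this session. The sketch assumes costs of 15 for an MI classified as No and 1 for a No classified as MI; `node_losses` is a helper written for this check, and the node sizes and MI proportions are copied from the printout:

```r
## Cost of labelling a whole node "No" or "MI", given its size and MI proportion.
## Assumed costs: MI classified as No = 15, No classified as MI = 1
node_losses <- function(n, p_mi, cost_fn = 15, cost_fp = 1) {
    n_mi <- round(n * p_mi)  # number of true MIs in the node
    n_no <- n - n_mi         # number of true non-MIs in the node
    c(No = n_mi * cost_fn,   # cost if the whole node is called "No"
      MI = n_no * cost_fp)   # cost if the whole node is called "MI"
}

root  <- node_losses(4373, 0.08392408)  # node 1
node2 <- node_losses(4101, 0.05413314)  # node 2: stel < 0.5
node3 <- node_losses(272,  0.53308824)  # node 3: stel >= 0.5

## The smaller entry in each row is the label and loss chosen for that node
rbind(root, node2, node3)
```

The smaller value in each row matches the printed loss and label (4006 for MI at the root, 3330 for No at node 2, 127 for MI at node 3), which is consistent with the 15-fold penalty falling on MIs classified as No.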
## Default plot
plot(rparty.cpain)
