Goal

Heart Disease

Let’s do an example with the heartdisease dataset in the FFTrees package:

library(FFTrees)  # For the heartdisease data
library(partykit) # For as.party()
library(rpart)    # For rpart()

Here is how the heartdisease data look:

str(heartdisease)
## 'data.frame':    303 obs. of  14 variables:
##  $ age      : num  63 67 67 37 41 56 62 57 63 53 ...
##  $ sex      : num  1 1 1 1 0 1 0 0 1 1 ...
##  $ cp       : chr  "ta" "a" "a" "np" ...
##  $ trestbps : num  145 160 120 130 130 120 140 120 130 140 ...
##  $ chol     : num  233 286 229 250 204 236 268 354 254 203 ...
##  $ fbs      : num  1 0 0 0 0 0 0 0 0 1 ...
##  $ restecg  : chr  "hypertrophy" "hypertrophy" "hypertrophy" "normal" ...
##  $ thalach  : num  150 108 129 187 172 178 160 163 147 155 ...
##  $ exang    : num  0 1 1 0 0 0 0 1 0 1 ...
##  $ oldpeak  : num  2.3 1.5 2.6 3.5 1.4 0.8 3.6 0.6 1.4 3.1 ...
##  $ slope    : chr  "down" "flat" "flat" "down" ...
##  $ ca       : num  0 3 2 0 0 0 2 0 1 0 ...
##  $ thal     : chr  "fd" "normal" "rd" "normal" ...
##  $ diagnosis: num  0 1 1 0 0 0 1 0 1 1 ...
head(heartdisease)
##   age sex cp trestbps chol fbs     restecg thalach exang oldpeak slope ca
## 1  63   1 ta      145  233   1 hypertrophy     150     0     2.3  down  0
## 2  67   1  a      160  286   0 hypertrophy     108     1     1.5  flat  3
## 3  67   1  a      120  229   0 hypertrophy     129     1     2.6  flat  2
## 4  37   1 np      130  250   0      normal     187     0     3.5  down  0
## 5  41   0 aa      130  204   0 hypertrophy     172     0     1.4    up  0
## 6  56   1 aa      120  236   0      normal     178     0     0.8    up  0
##     thal diagnosis
## 1     fd         0
## 2 normal         1
## 3     rd         1
## 4 normal         0
## 5 normal         0
## 6 normal         0

rpart

First I’ll create an rpart object:

# Create an rpart model for heartdisease diagnosis

heart.rpart <- rpart::rpart(diagnosis ~., 
                            data = heartdisease, 
                            method = "class")

Here is the resulting tree:

plot(heart.rpart)
text(heart.rpart)

Here is the frame, which describes the structure of the tree. Its row names are rpart's node numbers:

heart.rpart$frame[,1:5]
##        var   n  wt dev yval
## 1     thal 303 303 139    1
## 2       cp 168 168  38    1
## 4   <leaf> 101 101   9    1
## 5       ca  67  67  29    1
## 10  <leaf>  40  40   9    1
## 11  <leaf>  27  27   7    2
## 3       cp 135 135  34    2
## 6  thalach  45  45  21    1
## 12  <leaf>  33  33  11    1
## 13  <leaf>  12  12   2    2
## 7   <leaf>  90  90  10    2
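The frame shows rpart's binary node numbering: node 2 splits into nodes 4 and 5, and node 5 into nodes 10 and 11, so the children of node n are 2n and 2n + 1. Walking up to the root is therefore repeated integer division by 2. A minimal base-R sketch, using the node numbers copied from the frame above; ancestors() is a hypothetical helper, not an rpart function:

```r
# Hypothetical helper: list the ancestors of an rpart node number, assuming
# rpart's scheme where node n has children 2n (left) and 2n + 1 (right).
ancestors <- function(node) {
  out <- integer(0)
  while (node > 1) {
    node <- node %/% 2      # parent of node n is n %/% 2
    out <- c(out, node)
  }
  out
}

# Node numbers copied from the row names of heart.rpart$frame shown above
frame_nodes <- c(1, 2, 4, 5, 10, 11, 3, 6, 12, 13, 7)

# Sanity check: every non-root node's parent is also a node in the frame
all((frame_nodes[-1] %/% 2) %in% frame_nodes)

ancestors(10)   # path from node 10 up to the root: 5, 2, 1
```

The number of nodes above node n is then length(ancestors(n)), which equals floor(log2(n)); for node 10 that is 3.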

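Here is a vector giving the terminal node into which each case is classified. Note that where stores the row number of that node in the frame, not its rpart node number (for example, 3 is the third row of the frame, i.e. node 4):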

heart.rpart$where[1:25]
##  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 
##  9  6 11  3  3  3  6  5 11 11 11  3 10  9  9  3  9  5  3  3  5  5  3  9 11
# Number of cases in each terminal node (indexed by frame row)
table(heart.rpart$where)
## 
##   3   5   6   9  10  11 
## 101  40  27  33  12  90
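Because where holds frame rows, the frame's row names translate it into rpart node numbers. A small sketch with the values printed above hard-coded so it runs on its own; the equivalent one-liner for the fitted model is in the comment:

```r
# Row names of heart.rpart$frame (rpart node numbers), in the order printed above
frame_rownames <- c("1", "2", "4", "5", "10", "11", "3", "6", "12", "13", "7")

# First 25 values of heart.rpart$where, as printed above (row indices into the frame)
where_rows <- c(9, 6, 11, 3, 3, 3, 6, 5, 11, 11, 11, 3, 10,
                9, 9, 3, 9, 5, 3, 3, 5, 5, 3, 9, 11)

# Translate frame rows into node numbers.
# With the fitted model: as.integer(rownames(heart.rpart$frame))[heart.rpart$where]
where_nodes <- as.integer(frame_rownames)[where_rows]
table(where_nodes)
```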

Convert to party object

I can also display the result as a party object (thanks, Heidi Seibold):

# Convert to party object
heart.party <- as.party(heart.rpart)
plot(heart.party)

Here is the printout:

heart.party
## 
## Model formula:
## diagnosis ~ age + sex + cp + trestbps + chol + fbs + restecg + 
##     thalach + exang + oldpeak + slope + ca + thal
## 
## Fitted party:
## [1] root
## |   [2] thal in normal
## |   |   [3] cp in aa, np: 0.089 (n = 101, err = 8.2)
## |   |   [4] cp in a, ta
## |   |   |   [5] ca < 0.5: 0.225 (n = 40, err = 7.0)
## |   |   |   [6] ca >= 0.5: 0.741 (n = 27, err = 5.2)
## |   [7] thal in fd, rd
## |   |   [8] cp in aa, np, ta
## |   |   |   [9] thalach >= 143: 0.333 (n = 33, err = 7.3)
## |   |   |   [10] thalach < 143: 0.833 (n = 12, err = 1.7)
## |   |   [11] cp in a: 0.889 (n = 90, err = 8.9)
## 
## Number of inner nodes:    5
## Number of terminal nodes: 6
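The nesting in this printout can also be read off the party object programmatically. One option is partykit's nodeids(): as I understand its documentation, nodeids(x, from = i) returns node i and its descendants, and nodeids(x, terminal = TRUE) returns only the terminal nodes. A node lies above terminal node t exactly when t appears in that node's subtree. A sketch under that assumption; nodes_above() is a hypothetical helper, and its double loop is quadratic in the number of nodes, which is fine for a small tree like this one:

```r
library(FFTrees)   # heartdisease data
library(rpart)     # rpart()
library(partykit)  # as.party(), nodeids()

heart.rpart <- rpart(diagnosis ~ ., data = heartdisease, method = "class")
heart.party <- as.party(heart.rpart)

# Hypothetical helper: for each terminal node t, count the other nodes
# whose subtree contains t (i.e. the nodes above t).
nodes_above <- function(party_obj) {
  all_ids  <- nodeids(party_obj)
  term_ids <- nodeids(party_obj, terminal = TRUE)
  above <- vapply(term_ids, function(t) {
    sum(vapply(setdiff(all_ids, t),
               function(i) t %in% nodeids(party_obj, from = i),
               logical(1)))
  }, integer(1))
  data.frame(Node.Terminal = term_ids, Above = above)
}

nodes_above(heart.party)
```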

From these outputs I can easily see how many nodes lie above each terminal node. For example, node 3 has two nodes above it (2 and 1), while node 5 has three (4, 2 and 1). Going through all the terminal nodes gives the following table:

Nodes above each terminal node

Node.Terminal = index of a terminal node (as numbered in the party printout). Above = number of nodes above that terminal node.

Node.Terminal   Above
            3       2
            5       3
            6       3
            9       3
           10       3
           11       2

Question: How can I directly extract this information from an rpart or party object?
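For the rpart object, one answer follows from the node numbering: the number of nodes above node n is its depth, floor(log2(n)). In this example the frame rows of the leaves (3, 5, 6, 9, 10, 11) also coincide with the party node ids, as the where table and the party printout show (e.g. 101 cases in frame row 3 and in party node 3). Applied to the leaf node numbers in the frame above (4, 10, 11, 12, 13, 7), the formula gives 2, 3, 3, 3, 3, 2, matching the table. A sketch; the row-to-party-id correspondence is observed in this example, not guaranteed here for every as.party() conversion:

```r
library(FFTrees)  # heartdisease data
library(rpart)    # rpart()

heart.rpart <- rpart(diagnosis ~ ., data = heartdisease, method = "class")

fr       <- heart.rpart$frame
is_leaf  <- fr$var == "<leaf>"
node_num <- as.integer(rownames(fr))  # rpart node numbers

# Nodes above node n = depth of n = floor(log2(n));
# the small offset guards against floating-point error in log2().
above <- data.frame(
  Node.Terminal = which(is_leaf),                      # frame row (party id in this example)
  rpart.node    = node_num[is_leaf],                   # rpart's own node number
  Above         = floor(log2(node_num[is_leaf]) + 1e-7)
)
above
```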