The specification of a loss matrix for the rpart package is poorly documented in the help. See the third link above for it. Loading mvpart may transpose the direction of the loss matrix (columns for the true class and rows for the (mis)classification). This is not confirmed, but the MI example at the end of this section behaves that way.
library(sas7bdat)
cpain <- read.sas7bdat("./cpdata06.sas7bdat")
names(cpain) <- tolower(names(cpain))
library(mvpart)
The mvpart package is an extension to the rpart package and is dependent on rpart.
## Variable manipulation
cpain <- within(cpain, {
admit <- factor(admit, 0:1, c("Not admitted", "Admitted"))
mi <- factor(mi, 0:1, c("No", "MI"))
})
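The level order set by factor() matters for the loss matrix below. rpart indexes the classes by factor level, so "Not admitted" (and "No" for mi) is level 1 and "Admitted" ("MI") is level 2. A toy vector, made up for illustration, shows the mapping:

```r
## Toy vector (made up for illustration) coded 0/1 like the SAS data
admit_raw <- c(0, 1, 1, 0, 1)
admit_toy <- factor(admit_raw, levels = 0:1,
                    labels = c("Not admitted", "Admitted"))

levels(admit_toy)       # "Not admitted" is level 1, "Admitted" is level 2
as.integer(admit_toy)   # integer codes that index the loss-matrix rows/columns
table(admit_toy)
```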
## Perform rpart
rpart.cpain <- rpart(
formula = admit ~ systol + diastol + qual1 + qual2 + qual3 + qual4 + qual5 + qual6 + qual7 + asso1 + asso2 + asso3 + asso4 + asso5 + past1 + past2 + htn + diab + hichol + famhx + smok_cur + smok_pst + smok_nev + precath + preptca + precabg + repro1 + repro2 + repro3 + age + radp + funnyp + should + remiowtb + begin + lengthp + stel + stdep + charlson + bcbs + hmo + medicare + medicaid + selfpay + white + black + hispan + male + begin12 + begin24 + begin48 + lengthp120 + lengthp60 + lengthp30 + lengthp15,
data = cpain,
## weights = , # optional case weights. Not yet supported in current version
na.action = na.rpart, # By default, deleted if outcome is missing, kept if predictors are missing
method = "class", # Classification for factor
model = FALSE,
x = FALSE,
y = TRUE,
## optional parameters for the splitting function
parms = list(
## prior = c(0.5, 0.5), # prior probabilities. Defaults to observed data frequencies.
loss = matrix(c(0,1,1,0), ncol = 2), # loss matrix (equal losses here). Raise an off-diagonal element to penalize one error type more heavily
split = "gini" # gini or information
),
## rpart algorithm options (These are defaults, thus the whole control argument can be omitted)
control = rpart.control(
minsplit = 20, # minimum number of observations required before split
minbucket = 20/3, # minimum number of observations in any terminal node. default = round(minsplit/3)
cp = 0.01, # complexity parameter used as the stopping rule
maxcompete = 4, # number of competitor splits retained in the output
maxsurrogate = 5, # number of surrogate splits retained in the output
usesurrogate = 2, # how to use surrogates in the splitting process
xval = 10, # number of cross-validations
surrogatestyle = 0, # controls the selection of a best surrogate
maxdepth = 30) # maximum depth of any node of the final tree
##,
## cost = c() # a vector of cost for each variable
)
Optional parameters that can be given to parms (from the rpart help page):
parms:
For classification splitting, the list can contain any of:
the vector of prior probabilities (component ‘prior’), the
loss matrix (component ‘loss’) or the splitting index
(component ‘split’). The priors must be positive and sum to
1. The loss matrix must have zeros on the diagonal and
positive off-diagonal elements. The splitting index can be
‘gini’ or ‘information’. The default priors are proportional
to the data counts, the losses default to 1, and the split
defaults to ‘gini’.
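As a minimal sketch of the parms list described above, the block below uses the kyphosis data that ships with rpart (an assumption for illustration, since the cpain data is not public) and plain rpart rather than mvpart. It sets all three components explicitly and shows that a loss matrix with a nonzero diagonal is rejected.

```r
library(rpart)

## All three parms components set explicitly on the built-in kyphosis data
fit_parms <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis,
                   method = "class",
                   parms = list(prior = c(0.5, 0.5),                    # positive, sums to 1
                                loss  = matrix(c(0, 1, 2, 0), ncol = 2), # zeros on the diagonal
                                split = "information"))                  # "gini" or "information"

## A nonzero diagonal violates the rules quoted above, so rpart stops with an error
bad_loss <- tryCatch(
  rpart(Kyphosis ~ Age + Number + Start, data = kyphosis, method = "class",
        parms = list(loss = matrix(c(1, 1, 1, 0), ncol = 2))),
  error = function(e) conditionMessage(e))
print(bad_loss)
```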
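Specification of a loss matrix
See the link below for loss matrix specification; it is poorly documented in the help.
The rows represent reality (the true class) and the columns represent the (mis)classification, so element [i, j] is the loss for classifying a true level-i observation as level j. Levels are numbered in the order given by levels() of the outcome factor.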
## level 1 (Not admit) to level 2 (Admit) misclassification 100 times more serious
matrix(c(0, 1, 100, 0), ncol = 2)
     [,1] [,2]
[1,]    0  100
[2,]    1    0
## level 2 (Admit) to level 1 (Not admit) misclassification 100 times more serious
matrix(c(0, 100, 1, 0), ncol = 2)
     [,1] [,2]
[1,]    0    1
[2,]  100    0
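To make the direction explicit, a small helper can build the matrix by level name, following the rows-are-reality convention stated above. make_loss() is hypothetical and not part of rpart or mvpart. The dimnames are only there for reading; wrap the result in unname() if you prefer to pass a bare matrix as parms$loss, and swap true_level and predicted_level if mvpart does read the matrix transposed.

```r
## Hypothetical helper (not part of rpart or mvpart): 2 x 2 loss matrix with
## rows = true class and columns = predicted class, as described above
make_loss <- function(levels, true_level, predicted_level, cost) {
  stopifnot(length(levels) == 2, true_level != predicted_level)
  L <- matrix(1, nrow = 2, ncol = 2,
              dimnames = list(truth = levels, predicted = levels))
  diag(L) <- 0                             # correct classifications cost nothing
  L[true_level, predicted_level] <- cost   # the error to penalize more heavily
  L
}

lv <- c("Not admitted", "Admitted")
## Same values as matrix(c(0, 1, 100, 0), ncol = 2)
make_loss(lv, "Not admitted", "Admitted", 100)
## Same values as matrix(c(0, 100, 1, 0), ncol = 2)
make_loss(lv, "Admitted", "Not admitted", 100)
```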
rpart.cpain
n= 4373
node), split, n, loss, yval, (yprob)
      * denotes terminal node
1) root 4373 1685 Admitted (0.38531900 0.61468100)
  2) remiowtb< 0.5 3209 1601 Admitted (0.49890932 0.50109068)
    4) stdep< 0.5 2751 1176 Not admitted (0.57251908 0.42748092)
      8) qual1< 0.5 1453 427 Not admitted (0.70612526 0.29387474)
        16) age< 55.5 864 170 Not admitted (0.80324074 0.19675926) *
        17) age>=55.5 589 257 Not admitted (0.56366723 0.43633277)
          34) charlson< 0.5 268 85 Not admitted (0.68283582 0.31716418) *
          35) charlson>=0.5 321 149 Admitted (0.46417445 0.53582555)
            70) lengthp>=475 86 27 Not admitted (0.68604651 0.31395349) *
            71) lengthp< 475 235 90 Admitted (0.38297872 0.61702128) *
      9) qual1>=0.5 1298 549 Admitted (0.42295840 0.57704160)
        18) age< 49.5 557 241 Not admitted (0.56732496 0.43267504)
          36) male< 0.5 312 99 Not admitted (0.68269231 0.31730769) *
          37) male>=0.5 245 103 Admitted (0.42040816 0.57959184) *
        19) age>=49.5 741 233 Admitted (0.31443995 0.68556005) *
    5) stdep>=0.5 458 26 Admitted (0.05676856 0.94323144) *
  3) remiowtb>=0.5 1164 84 Admitted (0.07216495 0.92783505) *
## display cp table
printcp(rpart.cpain)
Classification tree:
rpart(formula = admit ~ systol + diastol + qual1 + qual2 + qual3 +
qual4 + qual5 + qual6 + qual7 + asso1 + asso2 + asso3 + asso4 +
asso5 + past1 + past2 + htn + diab + hichol + famhx + smok_cur +
smok_pst + smok_nev + precath + preptca + precabg + repro1 +
repro2 + repro3 + age + radp + funnyp + should + remiowtb +
begin + lengthp + stel + stdep + charlson + bcbs + hmo +
medicare + medicaid + selfpay + white + black + hispan +
male + begin12 + begin24 + begin48 + lengthp120 + lengthp60 +
lengthp30 + lengthp15, data = cpain, na.action = na.rpart,
method = "class", model = FALSE, x = FALSE, y = TRUE, parms = list(loss = matrix(c(0,
1, 1, 0), ncol = 2), split = "gini"), control = rpart.control(minsplit = 20,
minbucket = 20/3, cp = 0.01, maxcompete = 4, maxsurrogate = 5,
usesurrogate = 2, xval = 10, surrogatestyle = 0, maxdepth = 30))
Variables actually used in tree construction:
[1] age charlson lengthp male qual1 remiowtb stdep
Root node error: 1685/4373 = 0.38532
n= 4373
        CP nsplit rel error  xerror     xstd
1 0.118398      0   1.00000 1.00000 0.019100
2 0.044510      3   0.64451 0.64748 0.016982
3 0.023145      4   0.60000 0.61187 0.016659
4 0.010880      5   0.57685 0.59644 0.016511
5 0.010000      8   0.54421 0.58457 0.016395
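The rel error and xerror columns are relative to the root node error, so multiplying by 1685/4373 converts them to misclassification rates. A quick check in R with numbers copied from the table above: at 8 splits the training error is about 21% and the cross-validated error about 22.5%. Because the losses are all 1 here, the loss column of the tree printed earlier counts misclassified observations, and its sum over the nine terminal nodes (917) reproduces the rel error at 8 splits.

```r
root_error <- 1685 / 4373   # root node error from printcp()

## Values copied from the printcp() table above
cp_tab <- data.frame(
  nsplit    = c(0, 3, 4, 5, 8),
  rel_error = c(1.00000, 0.64451, 0.60000, 0.57685, 0.54421),
  xerror    = c(1.00000, 0.64748, 0.61187, 0.59644, 0.58457)
)
## Convert relative errors to absolute misclassification rates
cp_tab$train_rate <- cp_tab$rel_error * root_error
cp_tab$cv_rate    <- cp_tab$xerror * root_error
print(round(cp_tab, 4))

## Loss column of the nine terminal nodes (16, 34, 70, 71, 36, 37, 19, 5, 3)
leaf_loss <- c(170, 85, 27, 90, 99, 103, 233, 26, 84)
sum(leaf_loss)          # 917 misclassified observations
sum(leaf_loss) / 1685   # about 0.5442, the rel error at 8 splits
```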
## plot cross-validation results
plotcp(rpart.cpain)
## plot approximate R-squared and relative error for different splits
layout(matrix(1:2, ncol = 2))
rsq.rpart(rpart.cpain)
Classification tree:
rpart(formula = admit ~ systol + diastol + qual1 + qual2 + qual3 +
qual4 + qual5 + qual6 + qual7 + asso1 + asso2 + asso3 + asso4 +
asso5 + past1 + past2 + htn + diab + hichol + famhx + smok_cur +
smok_pst + smok_nev + precath + preptca + precabg + repro1 +
repro2 + repro3 + age + radp + funnyp + should + remiowtb +
begin + lengthp + stel + stdep + charlson + bcbs + hmo +
medicare + medicaid + selfpay + white + black + hispan +
male + begin12 + begin24 + begin48 + lengthp120 + lengthp60 +
lengthp30 + lengthp15, data = cpain, na.action = na.rpart,
method = "class", model = FALSE, x = FALSE, y = TRUE, parms = list(loss = matrix(c(0,
1, 1, 0), ncol = 2), split = "gini"), control = rpart.control(minsplit = 20,
minbucket = 20/3, cp = 0.01, maxcompete = 4, maxsurrogate = 5,
usesurrogate = 2, xval = 10, surrogatestyle = 0, maxdepth = 30))
Variables actually used in tree construction:
[1] age charlson lengthp male qual1 remiowtb stdep
Root node error: 1685/4373 = 0.38532
n= 4373
        CP nsplit rel error  xerror     xstd
1 0.118398      0   1.00000 1.00000 0.019100
2 0.044510      3   0.64451 0.64748 0.016982
3 0.023145      4   0.60000 0.61187 0.016659
4 0.010880      5   0.57685 0.59644 0.016511
5 0.010000      8   0.54421 0.58457 0.016395
May not be applicable for this method
layout(1)
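## detailed results including surrogate splits
## Note: cp = 0.15 exceeds every CP value in the table (the largest is 0.118),
## so all splits are trimmed; this is the likely cause of the error below.
summary(rpart.cpain,
        cp = 0.15) # trim nodes with a complexity of less than ‘cp’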
Call:
rpart(formula = admit ~ systol + diastol + qual1 + qual2 + qual3 +
qual4 + qual5 + qual6 + qual7 + asso1 + asso2 + asso3 + asso4 +
asso5 + past1 + past2 + htn + diab + hichol + famhx + smok_cur +
smok_pst + smok_nev + precath + preptca + precabg + repro1 +
repro2 + repro3 + age + radp + funnyp + should + remiowtb +
begin + lengthp + stel + stdep + charlson + bcbs + hmo +
medicare + medicaid + selfpay + white + black + hispan +
male + begin12 + begin24 + begin48 + lengthp120 + lengthp60 +
lengthp30 + lengthp15, data = cpain, na.action = na.rpart,
method = "class", model = FALSE, x = FALSE, y = TRUE, parms = list(loss = matrix(c(0,
1, 1, 0), ncol = 2), split = "gini"), control = rpart.control(minsplit = 20,
minbucket = 20/3, cp = 0.01, maxcompete = 4, maxsurrogate = 5,
usesurrogate = 2, xval = 10, surrogatestyle = 0, maxdepth = 30))
n= 4373
          CP nsplit rel error    xerror       xstd
1 0.11839763      0 1.0000000 1.0000000 0.01909963
2 0.04451039      3 0.6445104 0.6474777 0.01698212
3 0.02314540      4 0.6000000 0.6118694 0.01665877
4 0.01088032      5 0.5768546 0.5964392 0.01651123
5 0.01000000      8 0.5442136 0.5845697 0.01639458
Error: incorrect number of dimensions
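The CP table shows that no split has a complexity above 0.118, so cp = 0.15 trims every split. A minimal sketch with plain rpart on its built-in kyphosis data (an assumption; mvpart's summary may behave differently) shows that pruning above the largest CP leaves only the root node, which is a plausible reason the summary call failed. A cp below the first row of the table keeps at least one split.

```r
library(rpart)

fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis, method = "class")
max_cp <- max(fit$cptable[, "CP"])   # complexity of the first split

## A cp above every value in the CP table trims the tree back to the root
root_only <- prune(fit, cp = max_cp + 0.01)
nrow(root_only$frame)   # 1: only the root node is left

## A cp below the largest value keeps at least the first split
kept <- prune(fit, cp = max_cp / 2)
nrow(kept$frame)        # at least 3 nodes
```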
## Define a plotting function with decent defaults
plot.rpart.obj <- function(rpart.obj, font.size = 0.8) {
  ## plot decision tree
  plot(rpart.obj,
       uniform = TRUE,   # if TRUE, uniform vertical spacing of the nodes is used
       branch = 1,       # controls the shape of the branches from parent to child node
       compress = FALSE, # if FALSE, leaf nodes are placed at horizontal coordinates 1:nleaves
       nspace = 0.1,     # extra space between a node with children and a leaf (compressed trees only)
       margin = 0.1,     # an extra fraction of white space to leave around the borders
       minbranch = 0.3)  # set the minimum length for a branch
  ## Add text
  text(x = rpart.obj,    # the fitted rpart object
       splits = TRUE,    # if TRUE, label the tree with the criterion for each split
       all = TRUE,       # if TRUE, label all nodes; otherwise just terminal nodes
       use.n = TRUE,     # if TRUE, add class counts (level1/level2) to the labels
       cex = font.size)  # font size
}
## Plot
plot.rpart.obj(rpart.cpain, 1)
## Pruned at complexity parameter of 0.02
plot.rpart.obj(prune(tree = rpart.cpain, cp = 0.02), 1)
## Pruned at complexity parameter of 0.03
plot.rpart.obj(prune(tree = rpart.cpain, cp = 0.03), 1)
## Pruned at complexity parameter of 0.1
plot.rpart.obj(prune(tree = rpart.cpain, cp = 0.1), 1)
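Instead of trying cp values by hand, the cp can be taken from the cross-validation table. A minimal sketch on the kyphosis data shipped with rpart (assumed for illustration): pick the row with the smallest xerror and prune there. Cross-validation folds are random, so xerror and the chosen cp vary with the seed.

```r
library(rpart)

set.seed(1)   # the cross-validation folds are random
fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis, method = "class")

cpt      <- fit$cptable
best_row <- which.min(cpt[, "xerror"])   # smallest cross-validated error
best_cp  <- cpt[best_row, "CP"]

pruned <- prune(fit, cp = best_cp)
printcp(pruned)
```

A more conservative choice is the one-standard-error rule: the smallest tree whose xerror is within one xstd of the minimum. This corresponds to the horizontal reference line drawn by plotcp().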
## Create formula
form <- "mi ~ systol + diastol + qual1 + qual2 + qual3 + qual4 + qual5 + qual6 + qual7 + asso1 + asso2 + asso3 + asso4 + asso5 + past1 + past2 + htn + diab + hichol + famhx + smok_cur + smok_pst + smok_nev + precath + preptca + precabg + repro1 + repro2 + repro3 + age + radp + funnyp + should + remiowtb + begin + lengthp + stel + stdep + charlson + bcbs + hmo + medicare + medicaid + selfpay + white + black + hispan + male + begin12 + begin24 + begin48 + lengthp120 + lengthp60 + lengthp30 + lengthp15"
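Building the formula from a character vector with base R's reformulate() avoids repeating the long predictor list in both models. This is an alternative, not the approach used here, and only three predictor names are shown to keep the sketch short.

```r
## First three predictors only; the full list is the one in form above
predictors <- c("systol", "diastol", "qual1")

form_mi    <- reformulate(predictors, response = "mi")
form_admit <- reformulate(predictors, response = "admit")   # same predictors, other outcome
form_mi
```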
Equal weights for both types of misclassification
lmat <- matrix(c(0,1,1,0), ncol = 2)
lmat
     [,1] [,2]
[1,]    0    1
[2,]    1    0
## Perform rpart (defaults are omitted)
rpart.cpain <- rpart(
formula = as.formula(form),
data = cpain,
## optional parameters for the splitting function
parms = list(
loss = lmat # loss matrix. Penalize false positive or negative more heavily
)
)
## display cp table
printcp(rpart.cpain)
Classification tree:
rpart(formula = as.formula(form), data = cpain, parms = list(loss = lmat))
Variables actually used in tree construction:
[1] age begin black lengthp past1 stel
Root node error: 367/4373 = 0.083924
n= 4373
        CP nsplit rel error  xerror     xstd
1 0.076294      0   1.00000 1.00000 0.049961
2 0.043597      2   0.84741 0.85559 0.046518
3 0.013624      3   0.80381 0.82289 0.045688
4 0.010899      5   0.77657 0.84196 0.046174
5 0.010000      7   0.75477 0.83651 0.046036
## Plot
plot.rpart.obj(rpart.cpain)
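Level 2 to level 1 (MI to No MI) misclassification is 15 times more serious. Read with rows as reality, as stated above, this matrix would instead penalize No-to-MI errors; the fitted tree below behaves as if the matrix is read transposed (see the note about mvpart at the top).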
lmat <- matrix(c(0,1,15,0), ncol = 2)
lmat
     [,1] [,2]
[1,]    0   15
[2,]    1    0
## Perform rpart (defaults are omitted)
rpart.cpain <- rpart(
formula = as.formula(form),
data = cpain,
## optional parameters for the splitting function
parms = list(
loss = lmat # loss matrix. Penalize false positive or negative more heavily
)
)
## display cp table
printcp(rpart.cpain)
Classification tree:
rpart(formula = as.formula(form), data = cpain, parms = list(loss = lmat))
Variables actually used in tree construction:
[1] age hichol lengthp male qual4 stdep stel
Root node error: 4006/4373 = 0.91608
n= 4373
        CP nsplit rel error  xerror      xstd
1 0.216800      0   1.00000 1.00000 0.0045771
2 0.020220      2   0.56640 0.56640 0.0371025
3 0.014229      3   0.54618 0.57139 0.0367474
4 0.013729      6   0.50349 0.54843 0.0356652
5 0.010734      7   0.48977 0.53495 0.0347355
6 0.010000      8   0.47903 0.53470 0.0347355
## Plot
plot.rpart.obj(rpart.cpain)
With this loss matrix the tree is more averse to false negatives: fewer MIs (the count to the right of the "/" in the node labels) end up among the people classified as No.
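The printed losses make it possible to check which reading of the matrix was applied. The worked check below uses only counts from the output: under reading A (rows = reality, the convention stated earlier) the root would be labeled No with loss 367; under reading B (rows = classification) it is labeled MI with loss 4006, which is what the printed tree shows ("1) root 4373 4006 MI"). Node 2 agrees (loss 3330 = 222 × 15). This supports the suspicion that the matrix is read transposed with mvpart loaded, at least in this run; it says nothing about which convention plain rpart uses.

```r
## Class counts at the root of the MI model, taken from the output:
## root node error 4006/4373, so 4006 "No" and 4373 - 4006 = 367 "MI"
n_no <- 4006
n_mi <- 4373 - n_no
lmat <- matrix(c(0, 1, 15, 0), ncol = 2)   # lmat[1, 2] = 15, lmat[2, 1] = 1

## Reading A: rows = true class, columns = predicted class
loss_A <- c(No = n_mi * lmat[2, 1],   # predict No: each true MI costs lmat[MI, No]
            MI = n_no * lmat[1, 2])   # predict MI: each true No costs lmat[No, MI]

## Reading B: rows = predicted class, columns = true class
loss_B <- c(No = n_mi * lmat[1, 2],   # predict No: each true MI costs lmat[No, MI]
            MI = n_no * lmat[2, 1])   # predict MI: each true No costs lmat[MI, No]

loss_A   # No = 367,  MI = 60090: the root would be labeled No
loss_B   # No = 5505, MI = 4006:  the root is labeled MI with loss 4006

## Node 2 (stel < 0.5): 4101 obs, MI share 0.05413314, printed as No with loss 3330
n_mi_node2 <- round(4101 * 0.05413314)   # 222
c(reading_A = n_mi_node2 * lmat[2, 1], reading_B = n_mi_node2 * lmat[1, 2])
```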
The partykit package gives nicer output by default.
library(partykit)
## Default output
rparty.cpain <- as.party(rpart.cpain)
rparty.cpain
Model formula:
mi ~ systol + diastol + qual1 + qual2 + qual3 + qual4 + qual5 +
qual6 + qual7 + asso1 + asso2 + asso3 + asso4 + asso5 + past1 +
past2 + htn + diab + hichol + famhx + smok_cur + smok_pst +
smok_nev + precath + preptca + precabg + repro1 + repro2 +
repro3 + age + radp + funnyp + should + remiowtb + begin +
lengthp + stel + stdep + charlson + bcbs + hmo + medicare +
medicaid + selfpay + white + black + hispan + male + begin12 +
begin24 + begin48 + lengthp120 + lengthp60 + lengthp30 +
lengthp15
Fitted party:
[1] root
| [2] stel < 0.5
| | [3] stdep < 0.5
| | | [4] age < 43.5: No (n = 857, err = 1.1%)
| | | [5] age >= 43.5
| | | | [6] male < 0.5: No (n = 1358, err = 2.2%)
| | | | [7] male >= 0.5
| | | | | [8] hichol < 0.5
| | | | | | [9] qual4 < 0.5: No (n = 664, err = 2.9%)
| | | | | | [10] qual4 >= 0.5: No (n = 101, err = 8.9%)
| | | | | [11] hichol >= 0.5
| | | | | | [12] lengthp < 9: No (n = 55, err = 0.0%)
| | | | | | [13] lengthp >= 9: No (n = 302, err = 10.9%)
| | [14] stdep >= 0.5
| | | [15] lengthp < 27.5: No (n = 193, err = 3.6%)
| | | [16] lengthp >= 27.5: No (n = 571, err = 20.1%)
| [17] stel >= 0.5: MI (n = 272, err = 46.7%)
Number of inner nodes: 8
Number of terminal nodes: 9
## Output by rpart
rpart.cpain
n= 4373
node), split, n, loss, yval, (yprob)
      * denotes terminal node
1) root 4373 4006 MI (0.91607592 0.08392408)
  2) stel< 0.5 4101 3330 No (0.94586686 0.05413314)
    4) stdep< 0.5 3337 1500 No (0.97003296 0.02996704)
      8) age< 43.5 857 135 No (0.98949825 0.01050175) *
      9) age>=43.5 2480 1365 No (0.96330645 0.03669355)
        18) male< 0.5 1358 450 No (0.97790869 0.02209131) *
        19) male>=0.5 1122 915 No (0.94563280 0.05436720)
          38) hichol< 0.5 765 420 No (0.96339869 0.03660131)
            76) qual4< 0.5 664 285 No (0.97138554 0.02861446) *
            77) qual4>=0.5 101 92 MI (0.91089109 0.08910891) *
          39) hichol>=0.5 357 324 MI (0.90756303 0.09243697)
            78) lengthp< 9 55 0 No (1.00000000 0.00000000) *
            79) lengthp>=9 302 269 MI (0.89072848 0.10927152) *
    5) stdep>=0.5 764 642 MI (0.84031414 0.15968586)
      10) lengthp< 27.5 193 105 No (0.96373057 0.03626943) *
      11) lengthp>=27.5 571 456 MI (0.79859895 0.20140105) *
  3) stel>=0.5 272 127 MI (0.46691176 0.53308824) *
## Default plot
plot(rparty.cpain)