library(rpart) #classification and regression trees
library(partykit) #treeplots
## Loading required package: grid
library(randomForest) #random forests
## randomForest 4.6-12
## Type rfNews() to see new features/changes/bug fixes.
library(gbm) #gradient boosting
## Loading required package: survival
## Loading required package: lattice
## Loading required package: splines
## Loading required package: parallel
## Loaded gbm 2.1.3
library(caret) #tune hyper-parameters
## Loading required package: ggplot2
##
## Attaching package: 'ggplot2'
## The following object is masked from 'package:randomForest':
##
## margin
##
## Attaching package: 'caret'
## The following object is masked from 'package:survival':
##
## cluster
library(mlbench)
# Glass identification data (mlbench) ----
# Load the data set, then inspect its size, the response classes, the first
# rows, per-column summaries and the column types. Note that Type has no
# level "4": the 6 classes are labelled 1, 2, 3, 5, 6 and 7.
data("Glass")
dim(Glass)
## [1] 214 10
levels(Glass[["Type"]])
## [1] "1" "2" "3" "5" "6" "7"
head(Glass, n = 6L)
## RI Na Mg Al Si K Ca Ba Fe Type
## 1 1.52101 13.64 4.49 1.10 71.78 0.06 8.75 0 0.00 1
## 2 1.51761 13.89 3.60 1.36 72.73 0.48 7.83 0 0.00 1
## 3 1.51618 13.53 3.55 1.54 72.99 0.39 7.78 0 0.00 1
## 4 1.51766 13.21 3.69 1.29 72.61 0.57 8.22 0 0.00 1
## 5 1.51742 13.27 3.62 1.24 73.08 0.55 8.07 0 0.00 1
## 6 1.51596 12.79 3.61 1.62 72.97 0.64 8.07 0 0.26 1
summary(Glass)
## RI Na Mg Al
## Min. :1.511 Min. :10.73 Min. :0.000 Min. :0.290
## 1st Qu.:1.517 1st Qu.:12.91 1st Qu.:2.115 1st Qu.:1.190
## Median :1.518 Median :13.30 Median :3.480 Median :1.360
## Mean :1.518 Mean :13.41 Mean :2.685 Mean :1.445
## 3rd Qu.:1.519 3rd Qu.:13.82 3rd Qu.:3.600 3rd Qu.:1.630
## Max. :1.534 Max. :17.38 Max. :4.490 Max. :3.500
## Si K Ca Ba
## Min. :69.81 Min. :0.0000 Min. : 5.430 Min. :0.000
## 1st Qu.:72.28 1st Qu.:0.1225 1st Qu.: 8.240 1st Qu.:0.000
## Median :72.79 Median :0.5550 Median : 8.600 Median :0.000
## Mean :72.65 Mean :0.4971 Mean : 8.957 Mean :0.175
## 3rd Qu.:73.09 3rd Qu.:0.6100 3rd Qu.: 9.172 3rd Qu.:0.000
## Max. :75.41 Max. :6.2100 Max. :16.190 Max. :3.150
## Fe Type
## Min. :0.00000 1:70
## 1st Qu.:0.00000 2:76
## Median :0.00000 3:17
## Mean :0.05701 5:13
## 3rd Qu.:0.10000 6: 9
## Max. :0.51000 7:29
str(Glass)
## 'data.frame': 214 obs. of 10 variables:
## $ RI : num 1.52 1.52 1.52 1.52 1.52 ...
## $ Na : num 13.6 13.9 13.5 13.2 13.3 ...
## $ Mg : num 4.49 3.6 3.55 3.69 3.62 3.61 3.6 3.61 3.58 3.6 ...
## $ Al : num 1.1 1.36 1.54 1.29 1.24 1.62 1.14 1.05 1.37 1.36 ...
## $ Si : num 71.8 72.7 73 72.6 73.1 ...
## $ K : num 0.06 0.48 0.39 0.57 0.55 0.64 0.58 0.57 0.56 0.57 ...
## $ Ca : num 8.75 7.83 7.78 8.22 8.07 8.07 8.17 8.24 8.3 8.4 ...
## $ Ba : num 0 0 0 0 0 0 0 0 0 0 ...
## $ Fe : num 0 0 0 0 0 0.26 0 0 0 0.11 ...
## $ Type: Factor w/ 6 levels "1","2","3","5",..: 1 1 1 1 1 1 1 1 1 1 ...
# Train/test split ----
# Assign each row to partition 1 (train, p = 0.6) or 2 (test, p = 0.4).
# The seed makes the assignment reproducible. Partition sizes are random
# (not exactly 60/40), and the split is not stratified by Type, so rare
# classes (e.g. Type 6 with only 9 rows overall) can be very thin in one
# partition.
set.seed(123)
ind <- sample(2, nrow(Glass), replace = TRUE, prob = c(0.6, 0.4))
train <- Glass[ind == 1, ]
test <- Glass[ind == 2, ]
# Inspect the response by name instead of by column position, so the check
# keeps targeting Type even if the column order changes.
str(test[["Type"]])
## Factor w/ 6 levels "1","2","3","5",..: 1 1 1 1 1 1 1 1 1 1 ...
# Classification tree ----
# The xerror/xstd columns of the cp table come from rpart's internal
# cross-validation, which assigns folds at random. Seed it so the table
# (and hence the chosen cp) is reproducible.
set.seed(123)
tree.glass <- rpart(Type ~ ., data = train)
print(tree.glass$cptable)
## CP nsplit rel error xerror xstd
## 1 0.22159091 0 1.0000000 1.0909091 0.05814565
## 2 0.06818182 2 0.5568182 0.6022727 0.06400028
## 3 0.04545455 3 0.4886364 0.6136364 0.06419106
## 4 0.02272727 5 0.3977273 0.4886364 0.06118706
## 5 0.01000000 6 0.3750000 0.5454545 0.06280448
# Prune to the subtree with the lowest cross-validated error. The row is
# found from the xerror column rather than hard-coded, and the CP is read
# by column name; a row-wise min() would mix CP with the nsplit/error
# columns. which.min() returns the first minimum, i.e. the smallest tree
# among ties, because rows are ordered by increasing nsplit.
cp.table <- tree.glass$cptable
cp <- cp.table[which.min(cp.table[, "xerror"]), "CP"]
prune.tree.glass <- prune(tree.glass, cp = cp)
# Full tree vs. pruned tree, drawn via partykit.
plot(as.party(tree.glass))

plot(as.party(prune.tree.glass))

# Test-set evaluation ----
# Predict classes for the held-out rows and cross-tabulate them against the
# observed Type (rows = predicted, columns = observed). Both factors carry
# the same Type levels in the same order, so the table is square and its
# diagonal holds the correct predictions.
rparty.test <- predict(prune.tree.glass, newdata = test, type = "class")
rparty.tab <- table(rparty.test, test$Type)
rparty.tab
##
## rparty.test 1 2 3 5 6 7
## 1 24 9 1 0 0 0
## 2 3 19 1 1 1 2
## 3 1 1 1 0 0 0
## 5 0 2 0 5 2 0
## 6 0 0 0 0 0 0
## 7 0 1 0 0 0 8
# Accuracy = correct predictions (diagonal) / all test rows. Computed from
# the table rather than hand-copied counts, so it stays correct when the
# seed, split or model changes.
sum(diag(rparty.tab)) / sum(rparty.tab)
## [1] 0.695122