Модели на основе деревьев

Модели: деревья решений. Данные: Wage {ISLR}
high.medv = 1, если medv >= 128.68
high.medv = 0, если medv < 128.68
Объясняющие переменные: все остальные
Метод подгонки моделей: бустинг

Загрузка пакетов

library('tree')            
library('ISLR')             
library('GGally')           
library('MASS')             
library('randomForest')       
library('gbm')               

Дерево решений на основе категориальной зависимой переменной

##      crim zn indus chas   nox    rm  age    dis rad tax ptratio  black lstat
## 1 0.00632 18  2.31    0 0.538 6.575 65.2 4.0900   1 296    15.3 396.90  4.98
## 2 0.02731  0  7.07    0 0.469 6.421 78.9 4.9671   2 242    17.8 396.90  9.14
## 3 0.02729  0  7.07    0 0.469 7.185 61.1 4.9671   2 242    17.8 392.83  4.03
## 4 0.03237  0  2.18    0 0.458 6.998 45.8 6.0622   3 222    18.7 394.63  2.94
## 5 0.06905  0  2.18    0 0.458 7.147 54.2 6.0622   3 222    18.7 396.90  5.33
## 6 0.02985  0  2.18    0 0.458 6.430 58.7 6.0622   3 222    18.7 394.12  5.21
##   medv
## 1 24.0
## 2 21.6
## 3 34.7
## 4 33.4
## 5 36.2
## 6 28.7
# head(Wage)
# новая переменная
# High <- ifelse(Wage$medv < 25, "0", "1")
# присоединяем к таблице данных
# Wage <- cbind(Wage, High)
# матричные графики разброса переменных
#p <- ggpairs(Wage[, c(15, 1:4)], aes(color = High))
#suppressMessages(print(p))
#p <- ggpairs(Wage[, c(15, 5:8)], aes(color = High))
#suppressMessages(print(p))
#p <- ggpairs(Wage[, c(15, 9:14)], aes(color = High))
#suppressMessages(print(p))

Модель бинарного дерева

# Wage$High <- as.factor(Wage$High)
# tree.Wage <- tree(High ~ . -medv, Wage)
# summary(tree.Wage)
## 
## Classification tree:
## tree(formula = High ~ . - medv, data = Boston)
## Variables actually used in tree construction:
## [1] "rm"    "lstat" "crim"  "tax"   "age"   "indus" "nox"  
## Number of terminal nodes:  16 
## Residual mean deviance:  0.1975 = 96.77 / 490 
## Misclassification error rate: 0.03557 = 18 / 506
# график результата
# plot(tree.Wage)  
# text(tree.Wage, pretty = 0)   

Построение дерева на обучающей выборке

# обучающая выборка
# et.seed(4)
# train <- sample(1:nrow(Wage), nrow(Wage)/2) 
# строим дерево на обучающей выборке
# tree.Wage <- tree(High ~ . -medv, Wage, subset = train)

Прогнозирование

# tree.pred <- predict(tree.wage, Wage.test, type = "class")
# матрица неточностей
# tbl <- table(tree.pred, High.test)
# tbl
##          High.test
## tree.pred   0   1
##         0 175  16
##         1  10  52
# acc.test <- sum(diag(tbl))/sum(tbl)
# names(acc.test)[length(acc.test)] <- 'Wage.class.tree.all'
# acc.test
## Wage.class.tree.all 
##           0.8972332

Регрессионное дерево на основе непрерывной зависимой переменной

# Wage <- Wage[,-15]
# head(Wage)
##      crim zn indus chas   nox    rm  age    dis rad tax ptratio  black lstat
## 1 0.00632 18  2.31    0 0.538 6.575 65.2 4.0900   1 296    15.3 396.90  4.98
## 2 0.02731  0  7.07    0 0.469 6.421 78.9 4.9671   2 242    17.8 396.90  9.14
## 3 0.02729  0  7.07    0 0.469 7.185 61.1 4.9671   2 242    17.8 392.83  4.03
## 4 0.03237  0  2.18    0 0.458 6.998 45.8 6.0622   3 222    18.7 394.63  2.94
## 5 0.06905  0  2.18    0 0.458 7.147 54.2 6.0622   3 222    18.7 396.90  5.33
## 6 0.02985  0  2.18    0 0.458 6.430 58.7 6.0622   3 222    18.7 394.12  5.21
##   medv
## 1 24.0
## 2 21.6
## 3 34.7
## 4 33.4
## 5 36.2
## 6 28.7
# матричные графики разброса переменных
# p <- ggpairs(Wage[, c(14, 1:4)])
# suppressMessages(print(p))
# p <- ggpairs(Wage[, c(14, 5:8)])
# suppressMessages(print(p))
# p <- ggpairs(Wage[, c(14, 9:13)])
# suppressMessages(print(p))

Построение дерева регрессии для зависимой переменной

# обучаем модель
# tree.Wage <- tree(medv ~ ., Wage, subset = train)
# summary(tree.Wage)

# визуализация
# plot(tree.Wage)
# text(tree.Wage, pretty = 0)

# обрезка дерева
# cv.Wage <- cv.tree(tree.Wage)
# plot(cv.Wage$size, cv.Wage$dev, type = 'b')
# opt.size <- cv.Wage$size[cv.Wage$dev == min(cv.Wage$dev)]
# abline(v = opt.size, col = 'red', 'lwd' = 2)  
# mtext(opt.size, at = opt.size, side = 1, col = 'red', line = 1)

# прогноз по лучшей модели
# yhat <- predict(tree.Wage, newdata = Wage[-train, ])
# Wage.test <- Wage[-train, "medv"]

# график "прогноз-реализация"
# plot(yhat, Wage.test)

# линия идеального прогноза
# abline(0, 1)
## 
## Regression tree:
## tree(formula = medv ~ ., data = Boston, subset = train)
## Variables actually used in tree construction:
## [1] "lstat" "rm"    "age"   "nox"  
## Number of terminal nodes:  10 
## Residual mean deviance:  14.15 = 3438 / 243 
## Distribution of residuals:
##     Min.  1st Qu.   Median     Mean  3rd Qu.     Max. 
## -21.5000  -1.9220  -0.1065   0.0000   2.0000  16.5900

## Wage.regr.tree.8 
##         24.08221

Бустинг

# set.seed(4)
# boost.wage<- gbm(medv ~ ., data = wage[train, ], distribution = "gaussian",n.trees = 5000, interaction.depth = 4)
# summary(boost.wage) 

##             var      rel.inf
## lstat     lstat 42.477134548
## rm           rm 23.879970987
## dis         dis  5.866889022
## crim       crim  5.196153075
## ptratio ptratio  4.972160338
## black     black  4.878371547
## age         age  4.686835907
## nox         nox  4.122856333
## tax         tax  1.563061858
## indus     indus  1.411390208
## rad         rad  0.723067524
## zn           zn  0.219414595
## chas       chas  0.002694058
# графики частной зависимости для двух наиболее важных предикторов
# par(mfrow = c(1, 2))
# plot(boost.wage, i = "rm")
# plot(boost.wage, i = "lstat")

# прогноз
# yhat.boost <- predict(boost.wage, newdata = Wage[-train, ], n.trees = 5000)

## Wage.regr.tree.8   Wage.boost.opt 
##         24.08221         16.31184
## Wage.regr.tree.8   Wage.boost.opt   Wage.boost.0.1 
##         24.08221         16.31184         15.97801