# Модели: деревья решений. Данные: Wage {ISLR}
# Зависимая переменная:
#   high.wage = 1, если wage >= 128.68
#   high.wage = 0, если wage < 128.68
#   (128.68 — третий квартиль Wage$wage)
# Объясняющие переменные: все остальные
# Метод подгонки моделей: бустинг
library('tree')
library('ISLR')
library('GGally')
library('MASS')
library('randomForest')
library('gbm')
## crim zn indus chas nox rm age dis rad tax ptratio black lstat
## 1 0.00632 18 2.31 0 0.538 6.575 65.2 4.0900 1 296 15.3 396.90 4.98
## 2 0.02731 0 7.07 0 0.469 6.421 78.9 4.9671 2 242 17.8 396.90 9.14
## 3 0.02729 0 7.07 0 0.469 7.185 61.1 4.9671 2 242 17.8 392.83 4.03
## 4 0.03237 0 2.18 0 0.458 6.998 45.8 6.0622 3 222 18.7 394.63 2.94
## 5 0.06905 0 2.18 0 0.458 7.147 54.2 6.0622 3 222 18.7 396.90 5.33
## 6 0.02985 0 2.18 0 0.458 6.430 58.7 6.0622 3 222 18.7 394.12 5.21
## medv
## 1 24.0
## 2 21.6
## 3 34.7
## 4 33.4
## 5 36.2
## 6 28.7
# head(Wage)
# новая переменная
# High <- ifelse(Wage$medv < 25, "0", "1")
# присоединяем к таблице данных
# Wage <- cbind(Wage, High)
# матричные графики разброса переменных
#p <- ggpairs(Wage[, c(15, 1:4)], aes(color = High))
#suppressMessages(print(p))
#p <- ggpairs(Wage[, c(15, 5:8)], aes(color = High))
#suppressMessages(print(p))
#p <- ggpairs(Wage[, c(15, 9:14)], aes(color = High))
#suppressMessages(print(p))
# Wage$High <- as.factor(Wage$High)
# tree.Wage <- tree(High ~ . -medv, Wage)
# summary(tree.Wage)
##
## Classification tree:
## tree(formula = High ~ . - medv, data = Boston)
## Variables actually used in tree construction:
## [1] "rm" "lstat" "crim" "tax" "age" "indus" "nox"
## Number of terminal nodes: 16
## Residual mean deviance: 0.1975 = 96.77 / 490
## Misclassification error rate: 0.03557 = 18 / 506
# график результата
# plot(tree.Wage)
# text(tree.Wage, pretty = 0)
# обучающая выборка
# set.seed(4)
# train <- sample(1:nrow(Wage), nrow(Wage)/2)
# строим дерево на обучающей выборке
# tree.Wage <- tree(High ~ . -medv, Wage, subset = train)
# Wage.test <- Wage[-train, ]
# High.test <- Wage$High[-train]
# tree.pred <- predict(tree.Wage, Wage.test, type = "class")
# матрица неточностей
# tbl <- table(tree.pred, High.test)
# tbl
## High.test
## tree.pred 0 1
## 0 175 16
## 1 10 52
# acc.test <- sum(diag(tbl))/sum(tbl)
# names(acc.test)[length(acc.test)] <- 'Wage.class.tree.all'
# acc.test
## Wage.class.tree.all
## 0.8972332
# Wage <- Wage[,-15]
# head(Wage)
## crim zn indus chas nox rm age dis rad tax ptratio black lstat
## 1 0.00632 18 2.31 0 0.538 6.575 65.2 4.0900 1 296 15.3 396.90 4.98
## 2 0.02731 0 7.07 0 0.469 6.421 78.9 4.9671 2 242 17.8 396.90 9.14
## 3 0.02729 0 7.07 0 0.469 7.185 61.1 4.9671 2 242 17.8 392.83 4.03
## 4 0.03237 0 2.18 0 0.458 6.998 45.8 6.0622 3 222 18.7 394.63 2.94
## 5 0.06905 0 2.18 0 0.458 7.147 54.2 6.0622 3 222 18.7 396.90 5.33
## 6 0.02985 0 2.18 0 0.458 6.430 58.7 6.0622 3 222 18.7 394.12 5.21
## medv
## 1 24.0
## 2 21.6
## 3 34.7
## 4 33.4
## 5 36.2
## 6 28.7
# матричные графики разброса переменных
# p <- ggpairs(Wage[, c(14, 1:4)])
# suppressMessages(print(p))
# p <- ggpairs(Wage[, c(14, 5:8)])
# suppressMessages(print(p))
# p <- ggpairs(Wage[, c(14, 9:13)])
# suppressMessages(print(p))
# обучаем модель
# tree.Wage <- tree(medv ~ ., Wage, subset = train)
# summary(tree.Wage)
# визуализация
# plot(tree.Wage)
# text(tree.Wage, pretty = 0)
# обрезка дерева
# cv.Wage <- cv.tree(tree.Wage)
# plot(cv.Wage$size, cv.Wage$dev, type = 'b')
# opt.size <- cv.Wage$size[cv.Wage$dev == min(cv.Wage$dev)]
# abline(v = opt.size, col = 'red', 'lwd' = 2)
# mtext(opt.size, at = opt.size, side = 1, col = 'red', line = 1)
# прогноз по лучшей модели
# yhat <- predict(tree.Wage, newdata = Wage[-train, ])
# Wage.test <- Wage[-train, "medv"]
# график "прогноз-реализация"
# plot(yhat, Wage.test)
# линия идеального прогноза
# abline(0, 1)
##
## Regression tree:
## tree(formula = medv ~ ., data = Boston, subset = train)
## Variables actually used in tree construction:
## [1] "lstat" "rm" "age" "nox"
## Number of terminal nodes: 10
## Residual mean deviance: 14.15 = 3438 / 243
## Distribution of residuals:
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## -21.5000 -1.9220 -0.1065 0.0000 2.0000 16.5900
## Wage.regr.tree.8
## 24.08221
# set.seed(4)
# boost.wage <- gbm(medv ~ ., data = Wage[train, ], distribution = "gaussian",
#                   n.trees = 5000, interaction.depth = 4)
# summary(boost.wage)
## var rel.inf
## lstat lstat 42.477134548
## rm rm 23.879970987
## dis dis 5.866889022
## crim crim 5.196153075
## ptratio ptratio 4.972160338
## black black 4.878371547
## age age 4.686835907
## nox nox 4.122856333
## tax tax 1.563061858
## indus indus 1.411390208
## rad rad 0.723067524
## zn zn 0.219414595
## chas chas 0.002694058
# графики частной зависимости для двух наиболее важных предикторов
# par(mfrow = c(1, 2))
# plot(boost.wage, i = "rm")
# plot(boost.wage, i = "lstat")
# прогноз
# yhat.boost <- predict(boost.wage, newdata = Wage[-train, ], n.trees = 5000)
## Wage.regr.tree.8 Wage.boost.opt
## 24.08221 16.31184
## Wage.regr.tree.8 Wage.boost.opt Wage.boost.0.1
## 24.08221 16.31184 15.97801