Random Forest (RF) is an ensemble learner. It uses Bootstrap resampling to draw multiple samples from the original data, fits a decision tree to each sample, and then combines the trees: the final regression or classification prediction is obtained by averaging (Mean) the tree outputs for regression or by voting (Vote) for classification.
A large body of theoretical and empirical research has demonstrated the strengths of random forests.
Studying random forests necessarily involves decision trees, because the base learner of a random forest is an unpruned decision tree. Decision trees are described in a separate document.
Random forests were developed to overcome the many weaknesses of a single decision tree in regression and classification. Following the idea of combining individual tree learners into a larger ensemble, many decision trees are grown; no single tree needs to be highly accurate, and the final decision is reached by having all trees vote.
The main steps in building a Random Forest:
(1) Each decision tree is trained on its own training set. To build M decision trees, M training sets must be generated, and producing these M training subsets from the original training set requires statistical sampling. Many sampling techniques exist; they fall into two main kinds according to whether sampling is done with or without replacement.
(2) Bagging and Boosting are both sampling-with-replacement methods, but they differ in important ways:
in Boosting the learners are built in a serial relationship, which is a major challenge for execution, because each round has to wait for the result of the previous one before it can proceed. Bagging has no such dependency, which makes it well suited to parallel processing. (3) When growing the forest, the random forest algorithm mainly uses the bagging method, i.e. Bootstrap sampling.
The random forest algorithm builds one decision tree on each Bootstrap training subset, generating M decision trees that together form the "forest". Each tree is grown to full size without pruning. Two main processes are involved:
(1) Node splitting
(2) Random selection of the candidate feature variables at each split (see the sketch after this list)
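The two steps above can be illustrated with a minimal hand-rolled sketch: bootstrap resampling of the rows, a random subset of mtry features per tree (drawn per tree rather than per split, which is a simplification), and a majority vote over the trees. This is an illustrative sketch only, using the rpart package as the base tree learner; it is not how the randomForest package is implemented.

library(rpart)   # base tree learner, used purely for illustration

set.seed(1)
M    <- 25                                    # number of trees in the "forest"
mtry <- floor(sqrt(ncol(iris) - 1))           # number of candidate features (here 2)

trees <- vector("list", M)

for (m in seq_len(M)) {
  boot  <- sample(nrow(iris), replace = TRUE)            # (a) Bootstrap sample of rows
  feats <- sample(names(iris)[1:4], mtry)                # (b) random feature subset (per tree, a simplification)
  fml   <- reformulate(feats, response = "Species")
  trees[[m]] <- rpart(fml, data = iris[boot, ],
                      control = rpart.control(cp = 0))   # grow fully, no pruning
}

# Combine the trees by majority vote
votes <- sapply(trees, function(tr) as.character(predict(tr, iris, type = "class")))
pred  <- apply(votes, 1, function(v) names(which.max(table(v))))
mean(pred == iris$Species)   # resubstitution accuracy of the hand-rolled ensemble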
The randomForest package is loaded with library(randomForest); its key tuning arguments are ntree and mtry. Load randomForest together with the other packages used below:
library(caret)
library(ggplot2)
library(randomForest)
## Implements Breiman's random forest algorithm (based on Breiman and Cutler's original Fortran code) for classification and regression. It can also be used in unsupervised mode for assessing proximities among data points.
randomForest(formula, data = NULL,
#x, y = NULL,
#xtest = NULL, ytest = NULL,
#subset,
#na.action = na.fail,
ntree = 500,
mtry = if (!is.null(y) && !is.factor(y))
max(floor(ncol(x)/3), 1) else floor(sqrt(ncol(x))),
#replace = TRUE,
#classwt = NULL,
#cutoff,
#strata,
sampsize = if (replace) nrow(x) else ceiling(.632*nrow(x)),
nodesize = if (!is.null(y) && !is.factor(y)) 5 else 1,
maxnodes = NULL,
importance = FALSE,
localImp = FALSE,
nPerm = 1,
proximity,
oob.prox = proximity,
norm.votes = TRUE,
do.trace = FALSE,
keep.forest = !is.null(y) && is.null(xtest),
corr.bias = FALSE,
keep.inbag = FALSE,
...)
## Print method for 'randomForest' objects
print(rf_Model, ...)
# Plot method for randomForest objects
plot(rf_Model, type = "l", main = "")
# Extract variable importance measure
importance(rf_Model, type = NULL, class = NULL, scale = TRUE)  # type = 1: mean decrease in accuracy; type = 2: mean decrease in node impurity (Gini)
# predict method for random forest objects
predict(rf_Model,
test,
type = "response",
norm.votes = TRUE,
predict.all = FALSE,
proximity = FALSE,
nodes = FALSE,
cutoff)
# Random Forest Cross-Validation for feature selection
rfcv(trainx, trainy,
cv.fold = 5,
scale = "log",
step = 0.5,
mtry = function(p) max(1, floor(sqrt(p))),
recursive = FALSE,
...)
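rfcv is not demonstrated in the worked examples below; a minimal usage sketch on the built-in iris data (the seed and plot styling are illustrative choices) could look like this:

# Cross-validated error for nested subsets of predictors, on the iris data
set.seed(123)
iris.cv <- rfcv(trainx = iris[, 1:4], trainy = iris$Species, cv.fold = 5)
iris.cv$error.cv    # CV error rate for each number of retained variables
with(iris.cv, plot(n.var, error.cv, type = "b", log = "x",
                   xlab = "Number of variables", ylab = "CV error rate"))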
# Size of trees in an ensemble
treesize(x, terminal = TRUE)
# Tune randomForest for the optimal mtry parameter
tuneRF(credit[, -21],
credit[, 21],
mtryStart,
ntreeTry = 50,
stepFactor = 2,
improve = 0.05,
trace = TRUE,
plot = TRUE,
doBest = FALSE, ...)
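The call above refers to the credit data prepared further down. tuneRF starts from mtryStart and repeatedly multiplies or divides mtry by stepFactor, keeping a new value only if the OOB error improves by at least improve. As a self-contained sketch, a minimal run on the iris data (the argument values are illustrative choices, not recommendations):

set.seed(42)
best.rf <- tuneRF(iris[, -5], iris[, 5],
                  mtryStart  = 2,      # starting value of mtry
                  ntreeTry   = 100,    # trees per trial forest
                  stepFactor = 1.5,    # multiply/divide mtry by this factor each step
                  improve    = 0.01,   # required relative OOB-error improvement
                  doBest     = TRUE)   # refit and return the forest with the best mtry
print(best.rf)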
# Variable Importance Plot
varImpPlot(rf_Model,
sort = TRUE,
n.var = min(30, nrow(rf_Model$importance)),
type=NULL,
class=NULL,
scale=TRUE,
main=deparse(substitute(x)))
# Variables used in a random forest
varUsed(rf_Model, by.tree = FALSE, count = TRUE)
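Neither varImpPlot nor varUsed appears in the worked examples below; a minimal sketch using a forest fitted with importance = TRUE (the seed and plot title are illustrative):

set.seed(7)
rf_iris <- randomForest(Species ~ ., data = iris, importance = TRUE)
varImpPlot(rf_iris, sort = TRUE, main = "Variable importance (iris)")
varUsed(rf_iris, by.tree = FALSE, count = TRUE)   # number of times each variable is split on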
## German Credit data
data = read.csv("http://archive.ics.uci.edu/ml/machine-learning-databases/statlog/german/german.data",
header = FALSE,
sep = "")
names(data) = c(paste0(rep("x", 20), 1:20), "y")
data$y = as.integer(data$y) - 1
library(magrittr)
data = sapply(data, as.integer) %>%
as.data.frame() %>%
sapply(function(x) x/max(x)) %>%
as.data.frame()
head(data)
## x1 x2 x3 x4 x5 x6 x7 x8 x9 x10 x11 x12
## 1 0.25 0.08333333 1.0 0.5 0.06344985 1.0 1.0 1.00 0.75 0.3333333 1.00 0.25
## 2 0.50 0.66666667 0.6 0.5 0.32300261 0.2 0.6 0.50 0.50 0.3333333 0.50 0.25
## 3 1.00 0.16666667 1.0 0.8 0.11376465 0.2 0.8 0.50 0.75 0.3333333 0.75 0.25
## 4 0.25 0.58333333 0.6 0.4 0.42781155 0.2 0.8 0.50 0.75 1.0000000 1.00 0.50
## 5 0.25 0.33333333 0.8 0.1 0.26432914 0.2 0.6 0.75 0.75 0.3333333 1.00 1.00
## 6 1.00 0.50000000 0.6 0.8 0.49147851 1.0 0.6 0.50 0.75 0.3333333 1.00 1.00
## x13 x14 x15 x16 x17 x18 x19 x20 y
## 1 0.8933333 1 0.6666667 0.50 0.75 0.5 1.0 0.5 0
## 2 0.2933333 1 0.6666667 0.25 0.75 0.5 0.5 0.5 1
## 3 0.6533333 1 0.6666667 0.25 0.50 1.0 0.5 0.5 0
## 4 0.6000000 1 1.0000000 0.25 0.75 1.0 0.5 0.5 0
## 5 0.7066667 1 1.0000000 0.50 0.75 1.0 0.5 0.5 1
## 6 0.4666667 1 1.0000000 0.25 0.50 1.0 1.0 0.5 0
# rf_Model = randomForest(y ~ ., data = data, ntree = 500)
Method: randomForest Data: credit
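The model fit above is left commented out. A minimal sketch of actually fitting and inspecting it: note that after the rescaling step y is numeric (0/1), so it is converted back to a factor here so that randomForest performs classification rather than regression (this conversion is an addition, not part of the original code):

data$y <- as.factor(data$y)   # assumption: treat y as a class label for classification
rf_Model <- randomForest(y ~ ., data = data, ntree = 500, importance = TRUE)
print(rf_Model)               # OOB error estimate and confusion matrix
varImpPlot(rf_Model)          # which of x1..x20 matter most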
## Classification:
## data(iris)
set.seed(71)
iris.rf <- randomForest(Species ~ ., data=iris, importance=TRUE, proximity=TRUE)
print(iris.rf)
##
## Call:
## randomForest(formula = Species ~ ., data = iris, importance = TRUE, proximity = TRUE)
## Type of random forest: classification
## Number of trees: 500
## No. of variables tried at each split: 2
##
## OOB estimate of error rate: 5.33%
## Confusion matrix:
## setosa versicolor virginica class.error
## setosa 50 0 0 0.00
## versicolor 0 46 4 0.08
## virginica 0 4 46 0.08
## Look at variable importance:
round(importance(iris.rf), 2)
## setosa versicolor virginica MeanDecreaseAccuracy
## Sepal.Length 6.04 7.85 7.93 11.51
## Sepal.Width 4.40 1.03 5.44 5.40
## Petal.Length 21.76 31.33 29.64 32.94
## Petal.Width 22.84 32.67 31.68 34.50
## MeanDecreaseGini
## Sepal.Length 8.77
## Sepal.Width 2.19
## Petal.Length 42.54
## Petal.Width 45.77
## Do MDS on 1 - proximity:
iris.mds <- cmdscale(1 - iris.rf$proximity, eig=TRUE)
op <- par(pty="s")
pairs(cbind(iris[,1:4], iris.mds$points), cex=0.6, gap=0,
col=c("red", "green", "blue")[as.numeric(iris$Species)],
main="Iris Data: Predictors and MDS of Proximity Based on RandomForest")
par(op)
print(iris.mds$GOF)
## [1] 0.7282700 0.7903363
## The `unsupervised' case:
set.seed(17)
iris.urf <- randomForest(iris[, -5])
MDSplot(iris.urf, iris$Species)
## Loading required package: RColorBrewer
## stratified sampling: draw 20, 30, and 20 of the species to grow each tree.
(iris.rf2 <- randomForest(iris[1:4], iris$Species,
sampsize=c(20, 30, 20)))
##
## Call:
## randomForest(x = iris[1:4], y = iris$Species, sampsize = c(20, 30, 20))
## Type of random forest: classification
## Number of trees: 500
## No. of variables tried at each split: 2
##
## OOB estimate of error rate: 5.33%
## Confusion matrix:
## setosa versicolor virginica class.error
## setosa 50 0 0 0.00
## versicolor 0 47 3 0.06
## virginica 0 5 45 0.10
## Regression:
## data(airquality)
set.seed(131)
ozone.rf <- randomForest(Ozone ~ ., data=airquality, mtry=3,
importance=TRUE, na.action=na.omit)
print(ozone.rf)
##
## Call:
## randomForest(formula = Ozone ~ ., data = airquality, mtry = 3, importance = TRUE, na.action = na.omit)
## Type of random forest: regression
## Number of trees: 500
## No. of variables tried at each split: 3
##
## Mean of squared residuals: 303.8304
## % Var explained: 72.31
## Show "importance" of variables: higher value mean more important:
round(importance(ozone.rf), 2)
## %IncMSE IncNodePurity
## Solar.R 11.09 10534.24
## Wind 23.50 43833.13
## Temp 42.03 55218.05
## Month 4.07 2032.65
## Day 2.63 7173.19
## "x" can be a matrix instead of a data frame:
set.seed(17)
x <- matrix(runif(5e2), 100)
y <- gl(2, 50)
(myrf <- randomForest(x, y))
##
## Call:
## randomForest(x = x, y = y)
## Type of random forest: classification
## Number of trees: 500
## No. of variables tried at each split: 2
##
## OOB estimate of error rate: 45%
## Confusion matrix:
## 1 2 class.error
## 1 30 20 0.4
## 2 25 25 0.5
(predict(myrf, x))
## 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
## 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
## 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36
## 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
## 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
## 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2
## 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
## 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
## 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90
## 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
## 91 92 93 94 95 96 97 98 99 100
## 2 2 2 2 2 2 2 2 2 2
## Levels: 1 2
## "complicated" formula:
(swiss.rf <- randomForest(sqrt(Fertility) ~ . - Catholic + I(Catholic < 50),
data=swiss))
##
## Call:
## randomForest(formula = sqrt(Fertility) ~ . - Catholic + I(Catholic < 50), data = swiss)
## Type of random forest: regression
## Number of trees: 500
## No. of variables tried at each split: 1
##
## Mean of squared residuals: 0.3207372
## % Var explained: 45.54
(predict(swiss.rf, swiss))
## Courtelary Delemont Franches-Mnt Moutier Neuveville
## 8.544219 8.977287 9.118769 8.781467 8.522875
## Porrentruy Broye Glane Gruyere Sarine
## 8.924009 8.888588 9.140677 8.952968 8.875341
## Veveyse Aigle Aubonne Avenches Cossonay
## 9.114092 7.936236 8.328221 8.229574 8.023034
## Echallens Grandson Lausanne La Vallee Lavaux
## 8.274981 8.434141 7.742152 7.598992 8.180066
## Morges Moudon Nyone Orbe Oron
## 7.987572 8.354154 7.860249 7.873372 8.483563
## Payerne Paysd'enhaut Rolle Vevey Yverdon
## 8.551495 8.410619 8.035697 7.846781 8.363983
## Conthey Entremont Herens Martigwy Monthey
## 8.844938 8.657802 8.776533 8.570825 8.892461
## St Maurice Sierre Sion Boudry La Chauxdfnd
## 8.477949 8.919503 8.664517 8.277097 8.033753
## Le Locle Neuchatel Val de Ruz ValdeTravers V. De Geneve
## 8.170280 7.760338 8.553581 8.152936 6.781376
## Rive Droite Rive Gauche
## 7.500274 7.194497
## Test use of a 53-level factor as a predictor:
set.seed(1)
x <- data.frame(x1=gl(53, 10), x2=runif(530), y=rnorm(530))
(rf1 <- randomForest(x[-3], x[[3]], ntree=10))
##
## Call:
## randomForest(x = x[-3], y = x[[3]], ntree = 10)
## Type of random forest: regression
## Number of trees: 10
## No. of variables tried at each split: 1
##
## Mean of squared residuals: 1.49581
## % Var explained: -34.99
## Grow no more than 4 nodes per tree:
(treesize(randomForest(Species ~ ., data=iris, maxnodes=4, ntree=30)))
## [1] 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4
## test proximity in regression
iris.rrf <- randomForest(iris[-1], iris[[1]], ntree=101, proximity=TRUE, oob.prox=FALSE)
str(iris.rrf$proximity)
## num [1:150, 1:150] 1 0.337 0.327 0.356 0.891 ...
## - attr(*, "dimnames")=List of 2
## ..$ : chr [1:150] "1" "2" "3" "4" ...
## ..$ : chr [1:150] "1" "2" "3" "4" ...
Method: caret Data: iris
data(iris)
inTrain = createDataPartition(y = iris$Species, p = 0.8, list = FALSE)
training = iris[inTrain, ]
testing = iris[-inTrain, ]
modFit = train(Species ~ ., data = training, method = "rf", prox = TRUE)
modFit
## Random Forest
##
## 120 samples
## 4 predictor
## 3 classes: 'setosa', 'versicolor', 'virginica'
##
## No pre-processing
## Resampling: Bootstrapped (25 reps)
##
## Summary of sample sizes: 120, 120, 120, 120, 120, 120, ...
##
## Resampling results across tuning parameters:
##
## mtry Accuracy Kappa Accuracy SD Kappa SD
## 2 0.9444936 0.9155944 0.03712069 0.05657357
## 3 0.9470468 0.9194600 0.03640573 0.05556211
## 4 0.9425767 0.9127153 0.03709553 0.05650571
##
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was mtry = 3.
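Here train chose mtry by bootstrap resampling over a small default grid. A sketch of making that search explicit is shown below; the trainControl and tuneGrid values are illustrative choices, not the defaults used above:

ctrl <- trainControl(method = "cv", number = 5)   # 5-fold CV instead of the default bootstrap
grid <- expand.grid(mtry = 1:4)                   # candidate mtry values for method = "rf"
modFit_cv <- train(Species ~ ., data = training, method = "rf",
                   trControl = ctrl, tuneGrid = grid, ntree = 500)
modFit_cv$bestTune                                # selected mtry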
getTree(modFit$finalModel, k = 2)
## left daughter right daughter split var split point status prediction
## 1 2 3 4 0.80 1 0
## 2 0 0 0 0.00 -1 1
## 3 4 5 3 4.85 1 0
## 4 6 7 4 1.65 1 0
## 5 8 9 4 1.65 1 0
## 6 0 0 0 0.00 -1 2
## 7 10 11 2 3.10 1 0
## 8 12 13 3 4.95 1 0
## 9 0 0 0 0.00 -1 3
## 10 0 0 0 0.00 -1 3
## 11 0 0 0 0.00 -1 2
## 12 0 0 0 0.00 -1 2
## 13 0 0 0 0.00 -1 3
irisP = classCenter(training[, c(3, 4)], training$Species, modFit$finalModel$prox)
irisP = as.data.frame(irisP)
irisP$Species = rownames(irisP)
p = qplot(Petal.Width, Petal.Length, data = training, col = Species)
p + geom_point(aes(x = Petal.Width, y = Petal.Length, col = Species), size = 5,
shape = 4, data = irisP)
pred = predict(modFit, testing)
testing$predRight = pred == testing$Species
confusionMatrix(pred, testing$Species)
## Confusion Matrix and Statistics
##
## Reference
## Prediction setosa versicolor virginica
## setosa 10 0 0
## versicolor 0 9 0
## virginica 0 1 10
##
## Overall Statistics
##
## Accuracy : 0.9667
## 95% CI : (0.8278, 0.9992)
## No Information Rate : 0.3333
## P-Value [Acc > NIR] : 2.963e-13
##
## Kappa : 0.95
## Mcnemar's Test P-Value : NA
##
## Statistics by Class:
##
## Class: setosa Class: versicolor Class: virginica
## Sensitivity 1.0000 0.9000 1.0000
## Specificity 1.0000 1.0000 0.9500
## Pos Pred Value 1.0000 1.0000 0.9091
## Neg Pred Value 1.0000 0.9524 1.0000
## Prevalence 0.3333 0.3333 0.3333
## Detection Rate 0.3333 0.3000 0.3333
## Detection Prevalence 0.3333 0.3000 0.3667
## Balanced Accuracy 1.0000 0.9500 0.9750
p1 = ggplot(data = testing, aes(x = Petal.Width, y = Petal.Length,
color = predRight))
p1 +
geom_point() +
ggtitle("newdata Predictions") +
theme_bw()