This article illustrates that an Rmd version of the code in my article Improving the Performance of caret::train(), which is posted on github.com, works correctly with knitr, other than the fact that I didn’t take the time to clean up the HTML tables from the .md version.
During the December 2015 run of the Practical Machine Learning course within the Johns Hopkins University Data Science Specialization offered via coursera.org, many students struggled with the slow performance of some of the machine learning models, especially Random Forest.
Although the Community Teaching Assistants provided pointers on how to improve the performance of the caret::train() function by using the parallel package in conjunction with the trainControl() function in caret, many students were not able to construct a series of function calls that enabled caret::train() to run fast enough to be considered “usable” by the students. Consequently, students used the randomForest::randomForest() function to develop predictions for the course project.
This approach takes away one of the key advantages of the caret package: its ability to estimate an out-of-sample error by aggregating the accuracy analysis across a series of training runs. caret does this by automating the process of fitting multiple versions of a given model, varying its parameters and/or the folds within a resampling / cross-validation process.
To improve processing time of the multiple executions of the train() function, caret supports the parallel processing capabilities of the parallel package. Unfortunately, the documentation of parallel processing with caret uses a technique, the doMC package, which is not available for Microsoft Windows versions of R.
Fortunately, the parallel package works on R across all major operating system platforms: Linux, Mac OSX, and Windows. One’s ability to run these models in parallel is often the difference between using a highly effective algorithm like random forest versus a less effective but more computationally efficient algorithm (such as linear discriminant analysis).
One other tradeoff that we made in this analysis was changing the resampling method from the default of bootstrapping to k-fold cross-validation. A cheaper resampling technique can buy faster processing at the cost of reduced model accuracy. However, our analysis shows that 5-fold cross-validation delivered the same accuracy as the more computationally expensive bootstrapping technique.
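To see why the resampling method matters for run time, it helps to count model fits. The sketch below assumes caret's defaults of 25 bootstrap resamples and a tuning grid of three candidate values, plus one final fit on the full training set. `countFits` is a hypothetical helper written for this illustration, not a caret function.

```r
# Hypothetical helper (not part of caret): number of model fits train()
# performs for a given number of resamples and tuning-grid size,
# assuming one extra fit of the final model on the full training set.
countFits <- function(resamples, gridSize = 3) {
  resamples * gridSize + 1
}

bootFits <- countFits(resamples = 25)  # assumed default: 25 bootstrap resamples
cvFits   <- countFits(resamples = 5)   # 5-fold cross-validation

print(c(bootstrap = bootFits, cv5 = cvFits))
print(round(bootFits / cvFits, 2))
```

Under these assumptions bootstrapping needs 76 fits against 16 for 5-fold cross-validation. Each bootstrap fit also uses a sample as large as the full training set, rather than four fifths of it.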
Finally, we note that caret::train() supports a wide variety of tuning parameters that vary by model type. For the purposes of this analysis, we chose only to vary the resampling method for train(x,y,method="rf",...), leaving other parameters such as mtry constant.
Once a person works through the varied sources of documentation on the machine learning models and supporting R packages, the process for executing a random forest model (or any other model) in caret::train() is relatively straightforward, and includes the following steps.
For the purpose of illustrating the syntax required for parallel processing, we’ll use the Sonar data set that is also used as the example in the caret model training documentation.
intervalStart <- Sys.time()
library(mlbench)
data(Sonar)
library(caret)
## Loading required package: lattice
## Loading required package: ggplot2
set.seed(95014)
inTraining <- createDataPartition(Sonar$Class, p = .75, list=FALSE)
training <- Sonar[inTraining,]
testing <- Sonar[-inTraining,]
# set up x and y to avoid slowness of caret::train() with the formula interface
y <- training[,61]
x <- training[,-61]
Parallel processing in caret can be accomplished with the parallel and doParallel packages. The following code loads the required packages (which in turn depend on the iterators and foreach packages), creates a cluster, and registers it as the parallel backend.
library(parallel)
library(doParallel)
## Loading required package: foreach
## Loading required package: iterators
cluster <- makeCluster(detectCores() - 1) # convention to leave 1 core for OS
registerDoParallel(cluster)
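On a single-core machine, or where `detectCores()` cannot determine the core count and returns `NA`, the expression `detectCores() - 1` is not a usable cluster size. A minimal sketch of a guard follows. `chooseWorkers` is a hypothetical helper name, not part of the parallel package.

```r
library(parallel)

# Hypothetical helper: leave one core for the OS, but never request
# fewer than one worker, and fall back to one worker if the core
# count is unknown (detectCores() can return NA).
chooseWorkers <- function(cores = detectCores()) {
  if (is.na(cores)) {
    return(1L)
  }
  max(1L, as.integer(cores) - 1L)
}

print(chooseWorkers(8))   # 8 cores available
print(chooseWorkers(1))   # single-core machine
print(chooseWorkers(NA))  # core count unknown
```

The result can be passed to `makeCluster()` in place of `detectCores() - 1`.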
The most critical arguments for the trainControl() function are method, which sets the resampling method; number, which specifies the number of folds for k-fold cross-validation; and allowParallel, which tells caret to use the cluster that we registered in the previous step.
fitControl <- trainControl(method = "cv",
number = 5,
allowParallel = TRUE)
Next, we use caret::train() to train the model, using the trainControl() object that we just created.
system.time(fit <- train(x,y, method="rf",trControl = fitControl))
## Loading required package: randomForest
## randomForest 4.6-12
## Type rfNews() to see new features/changes/bug fixes.
##
## Attaching package: 'randomForest'
## The following object is masked from 'package:ggplot2':
##
## margin
## user system elapsed
## 0.88 0.08 7.17
system.time(fit <- train(Class ~ ., method="rf",data=Sonar,trControl = fitControl))
## user system elapsed
## 0.70 0.00 2.52
After processing the data, we explicitly shut down the cluster by calling the stopCluster() function.
stopCluster(cluster)
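If a step between creating and stopping the cluster fails with an error, the `stopCluster()` call is never reached and the worker processes stay alive. One way to guard against that, sketched here with only the parallel package, is to create the cluster inside a function and register the cleanup with `on.exit()`. `runOnCluster` is a hypothetical helper, and the squaring task merely stands in for a model fit.

```r
library(parallel)

# Hypothetical helper: create a cluster, hand it to `task`, and shut the
# cluster down when the function exits, whether normally or via an error.
runOnCluster <- function(workers, task) {
  cluster <- makeCluster(workers)
  on.exit(stopCluster(cluster), add = TRUE)
  task(cluster)
}

# Stand-in workload: square four numbers on two workers.
squares <- runOnCluster(2, function(cl) {
  parLapply(cl, 1:4, function(i) i^2)
})
print(unlist(squares))
```

When the cluster has been registered with `registerDoParallel()`, calling `registerDoSEQ()` from the foreach package after `stopCluster()` returns foreach to sequential execution.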
At this point we have a trained model in the fit object, and can take a number of steps to evaluate its suitability, including the resampled accuracy and a confusion matrix based on comparing the model's predictions to the held-out folds.
fit
## Random Forest
##
## 208 samples
## 60 predictor
## 2 classes: 'M', 'R'
##
## No pre-processing
## Resampling: Cross-Validated (5 fold)
## Summary of sample sizes: 167, 166, 167, 165, 167
## Resampling results across tuning parameters:
##
## mtry Accuracy Kappa
## 2 0.8174837 0.6300874
## 31 0.8081814 0.6110130
## 60 0.7889069 0.5724638
##
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was mtry = 2.
fit$resample
## Accuracy Kappa Resample
## 1 0.9512195 0.9012048 Fold1
## 2 0.8809524 0.7597254 Fold2
## 3 0.7804878 0.5505481 Fold5
## 4 0.7674419 0.5232816 Fold4
## 5 0.7073171 0.4156770 Fold3
confusionMatrix.train(fit)
## Cross-Validated (5 fold) Confusion Matrix
##
## (entries are percentual average cell counts across resamples)
##
## Reference
## Prediction M R
## M 47.6 12.5
## R 5.8 34.1
##
## Accuracy (average) : 0.8173
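The figures above are tied together by simple arithmetic, which the following sketch reproduces in base R using values copied from the output. The variable names are illustrative only.

```r
# Per-fold accuracies copied from the fit$resample output above.
foldAccuracy <- c(Fold1 = 0.9512195, Fold2 = 0.8809524, Fold5 = 0.7804878,
                  Fold4 = 0.7674419, Fold3 = 0.7073171)

# The accuracy reported for mtry = 2 is the mean across the five folds.
meanAccuracy <- mean(foldAccuracy)
print(round(meanAccuracy, 7))

# The diagonal of the percentage confusion matrix gives nearly the same figure.
diagonalAccuracy <- (47.6 + 34.1) / 100
print(diagonalAccuracy)
```

The mean of the fold accuracies matches the 0.8174837 reported for mtry = 2. The confusion-matrix figure differs in the fourth decimal place, plausibly because the folds are not all the same size and the displayed cell percentages are rounded.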
If desired, at this point one can make a prediction on the held-out testing data partition. Since the primary purpose of this article is to illustrate the syntax required for parallel processing and to discuss its impact on the course project for Practical Machine Learning, we will not predict on the testing data or evaluate the model accuracy here.
Returning our attention to the Practical Machine Learning course project data set, we compared the performance of parallel versus single-threaded processing on a random forest model with caret, and evaluated the performance of the parallel-processing version across four different computers.
As illustrated in the following table, multi-threading has a significant, positive impact on the performance of the caret::train() function. As expected, the difference in processing times for the linear discriminant analysis model was negligible. However, for the random forest, the multi-threaded version took 58% less time than the single-threaded version (as measured on the HP Omen laptop with Intel® Core™ i7-4720HQ processor).
| Machine | Algorithm | Threading Model | Result |
|---|---|---|---|
| HP Omen laptop | Linear Discriminant Analysis | Multi-threaded | 2.38 seconds |
| HP Omen laptop | Linear Discriminant Analysis | Single-threaded | 2.41 seconds |
| HP Omen laptop | Random Forest | Multi-threaded | 193.2 seconds |
| HP Omen laptop | Random Forest | Single-threaded | 462.6 seconds |
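The 58% figure can be checked directly from the table. The sketch below, with illustrative column names, computes both the percentage of run time saved and the equivalent speedup factor.

```r
# Elapsed times in seconds, copied from the table above.
timings <- data.frame(
  algorithm = c("Linear Discriminant Analysis", "Random Forest"),
  multi     = c(2.38, 193.2),
  single    = c(2.41, 462.6)
)

# Percentage of run time saved and speedup factor from multi-threading.
timings$pctSaved <- round(100 * (1 - timings$multi / timings$single), 1)
timings$speedup  <- round(timings$single / timings$multi, 2)
print(timings)
```

For the random forest this works out to roughly 58% of the run time saved, or a speedup of about 2.4 times, while linear discriminant analysis gains only about 1%.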
This section of the analysis used four different computers to assess the performance of caret::train(). CPU speed, number of processor cores, and (to a lesser extent) disk speed all impact runtime performance. All four machines have Intel-based processors with multiple cores, and each core provides two processing threads that can be assigned to execute instructions in parallel. As expected, the machine with the largest number of cores and fastest disk speed returns the fastest response time, completing the 5-fold cross-validation model in 3.22 minutes.
Since most students in Practical Machine Learning have older (and slower) hardware than the machines I typically use to complete the work in the Data Science Specialization courses, I also ran the tests on a Windows-based tablet: the HP Envy X2.
Surprisingly, the random forest algorithm for the Practical Machine Learning course project runs flawlessly on the tablet. The runtime performance is very slow compared to the other machines, requiring 1 hour 15 minutes to complete the random forest, using all 4 threads across the two cores in its Intel Atom-based processor.
Finally, to illustrate the impact that the resampling technique has on the runtime performance, we fit the training data for the Practical Machine Learning course project on the HP Omen laptop with bootstrapping as the resampling method. Bootstrapping caused a significant increase in processing time, requiring 17 minutes instead of 3.22 minutes to train the model. Since cross-validation already resulted in an accuracy of 0.9945, the more expensive bootstrapping method had no positive impact on model accuracy.
| Machine | Model | Resampling Technique | Result |
|---|---|---|---|
| HP Omen laptop | Random Forest | CV | 3.22 minutes |
| HP Spectre x360 laptop | Random Forest | CV | 4.65 minutes |
| Macbook Pro laptop | Random Forest | CV | 6.56 minutes |
| HP Omen laptop | Random Forest | Bootstrap | 17.00 minutes |
| HP Envy X2 tablet | Random Forest | CV | 74.97 minutes |
Hardware specifications for the computers used in the performance timings in this article are listed below.
| Computer | Configuration |
|---|---|
| Apple Macbook Pro | |
| HP Envy X2 tablet | |
| HP Omen laptop | |
| HP Spectre X360 laptop | |
sessionInfo()
## R version 3.3.0 (2016-05-03)
## Platform: x86_64-w64-mingw32/x64 (64-bit)
## Running under: Windows 10 x64 (build 10586)
##
## locale:
## [1] LC_COLLATE=English_United States.1252
## [2] LC_CTYPE=English_United States.1252
## [3] LC_MONETARY=English_United States.1252
## [4] LC_NUMERIC=C
## [5] LC_TIME=English_United States.1252
##
## attached base packages:
## [1] parallel stats graphics grDevices utils datasets methods
## [8] base
##
## other attached packages:
## [1] randomForest_4.6-12 doParallel_1.0.10 iterators_1.0.8
## [4] foreach_1.4.3 caret_6.0-68 ggplot2_2.1.0
## [7] lattice_0.20-33 mlbench_2.1-1
##
## loaded via a namespace (and not attached):
## [1] Rcpp_0.12.5 compiler_3.3.0 formatR_1.4
## [4] nloptr_1.0.4 plyr_1.8.3 class_7.3-14
## [7] tools_3.3.0 digest_0.6.9 lme4_1.1-12
## [10] evaluate_0.9 nlme_3.1-128 gtable_0.2.0
## [13] mgcv_1.8-12 Matrix_1.2-6 yaml_2.1.13
## [16] SparseM_1.7 e1071_1.6-7 stringr_1.0.0
## [19] knitr_1.13 MatrixModels_0.4-1 stats4_3.3.0
## [22] grid_3.3.0 nnet_7.3-12 rmarkdown_0.9.6
## [25] minqa_1.2.4 reshape2_1.4.1 car_2.1-2
## [28] magrittr_1.5 scales_0.4.0 codetools_0.2-14
## [31] htmltools_0.3.5 MASS_7.3-45 splines_3.3.0
## [34] pbkrtest_0.4-6 colorspace_1.2-6 quantreg_5.26
## [37] stringi_1.1.1 munsell_0.4.3
intervalEnd <- Sys.time()
paste("Random Forest analysis of Sonar data took",intervalEnd - intervalStart,attr(intervalEnd - intervalStart,"units"))
## [1] "Random Forest analysis of Sonar data took 12.7603738307953 secs"