library(mlbench)
library(randomForest)
## randomForest 4.7-1.2
## Type rfNews() to see new features/changes/bug fixes.
library(caret)
## Loading required package: ggplot2
##
## Attaching package: 'ggplot2'
## The following object is masked from 'package:randomForest':
##
## margin
## Loading required package: lattice
library(party)
## Loading required package: grid
## Loading required package: mvtnorm
## Loading required package: modeltools
## Loading required package: stats4
## Loading required package: strucchange
## Loading required package: zoo
##
## Attaching package: 'zoo'
## The following objects are masked from 'package:base':
##
## as.Date, as.Date.numeric
## Loading required package: sandwich
library(gbm)
## Loaded gbm 2.2.2
## This version of gbm is no longer under development. Consider transitioning to gbm3, https://github.com/gbm-developers/gbm3
library(Cubist)
# Simulate data with mlbench.friedman1() and collect predictors and response
# in a single data frame whose last column is named "y".
set.seed(200)
friedman <- mlbench.friedman1(200, sd = 1)
simulated <- as.data.frame(cbind(friedman$x, friedman$y))
names(simulated)[ncol(simulated)] <- "y"

# Fit a 1000-tree random forest on all predictors, keeping importance measures.
model1 <- randomForest(
  y ~ .,
  data = simulated,
  importance = TRUE,
  ntree = 1000
)

# Unscaled variable importance (stored, not printed) and the importance plot.
rfImp1 <- varImp(model1, scale = FALSE)
rplot <- varImpPlot(model1)
The random forest model did not make meaningful use of the uninformative
predictors (V6–V10); the importance scores for V8, V9 and V10 are
essentially nil. This indicates that these predictors were not relevant to
the results and carry little weight in the ensemble, which supports
generalization performance and model interpretability.
# Conditional inference forest on the simulated data, followed by a bar plot
# of its conditional variable importance.
set.seed(200)
cforest_model <- cforest(
  y ~ .,
  data = simulated
)
cforest_imp <- varimp(
  cforest_model,
  conditional = TRUE
)
barplot(
  cforest_imp,
  las = 2,
  main = "Conditional Variable Importance"
)
No, they do not. A clearer distinction can be seen in the
conditional inference forest: noise variables (V6–V10) are appropriately
down-weighted, while genuine predictors (V1–V5) maintain their
importance. When predictors are correlated, traditional random forests
have a tendency to exaggerate the importance of noise variables.
# Boosted trees (gbm) with squared-error loss on the simulated data.
set.seed(200)
boost_model <- gbm(
  y ~ .,
  data = simulated,
  distribution = "gaussian",
  n.trees = 1000,
  interaction.depth = 3,
  shrinkage = 0.01
)
# Relative influence of each predictor.
summary(boost_model)
## var rel.inf
## V4 V4 27.5155033
## V2 V2 21.6588784
## V1 V1 16.4556919
## duplicate1 duplicate1 11.6925395
## V5 V5 10.9613910
## V3 V3 7.8450574
## V7 V7 1.2692211
## V6 V6 0.9850953
## V9 V9 0.6608039
## V10 V10 0.5200250
## V8 V8 0.4357932
# Small simulation contrasting a 5-level factor ("coarse") with a continuous
# uniform predictor ("fine"); only level "A" of the factor shifts the response.
set.seed(123)
coarse <- sample(LETTERS[1:5], size = 500, replace = TRUE)
fine <- runif(n = 500)
y <- ifelse(
  coarse == "A",
  rnorm(500, 5),
  rnorm(500)
)
data2 <- data.frame(
  coarse = as.factor(coarse),
  fine,
  y
)

# Random forest on both predictors, then plot the importance measures.
model_bias <- randomForest(
  y ~ .,
  data = data2,
  importance = TRUE
)
varImpPlot(model_bias)
Interpretation of Figure 8.24: Lower values of the bagging fraction and learning rate spread importance across more predictors, whereas higher values run the danger of overfitting by concentrating importance on a small number of predictors. Because a larger interaction depth captures more intricate predictor interactions, increasing it would flatten the slope of the importance profile.
library(earth)
## Loading required package: Formula
## Loading required package: plotmo
## Loading required package: plotrix
library(plotrix)
library(readr)
library(dplyr)
##
## Attaching package: 'dplyr'
## The following object is masked from 'package:party':
##
## where
## The following object is masked from 'package:randomForest':
##
## combine
## The following objects are masked from 'package:stats':
##
## filter, lag
## The following objects are masked from 'package:base':
##
## intersect, setdiff, setequal, union
library(AppliedPredictiveModeling)
library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ forcats 1.0.0 ✔ stringr 1.5.1
## ✔ lubridate 1.9.4 ✔ tibble 3.2.1
## ✔ purrr 1.0.4 ✔ tidyr 1.3.1
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ stringr::boundary() masks strucchange::boundary()
## ✖ dplyr::combine() masks randomForest::combine()
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
## ✖ purrr::lift() masks caret::lift()
## ✖ ggplot2::margin() masks randomForest::margin()
## ✖ dplyr::where() masks party::where()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(rpart)
library(rpart.plot)
library(randomForest)
# Load the ChemicalManufacturingProcess data set and preview its first rows.
# The preview shows the Yield response plus the BiologicalMaterial* and
# ManufacturingProcess* predictors, several of which contain NA values.
data("ChemicalManufacturingProcess")
head(ChemicalManufacturingProcess)
## Yield BiologicalMaterial01 BiologicalMaterial02 BiologicalMaterial03
## 1 38.00 6.25 49.58 56.97
## 2 42.44 8.01 60.97 67.48
## 3 42.03 8.01 60.97 67.48
## 4 41.42 8.01 60.97 67.48
## 5 42.49 7.47 63.33 72.25
## 6 43.57 6.12 58.36 65.31
## BiologicalMaterial04 BiologicalMaterial05 BiologicalMaterial06
## 1 12.74 19.51 43.73
## 2 14.65 19.36 53.14
## 3 14.65 19.36 53.14
## 4 14.65 19.36 53.14
## 5 14.02 17.91 54.66
## 6 15.17 21.79 51.23
## BiologicalMaterial07 BiologicalMaterial08 BiologicalMaterial09
## 1 100 16.66 11.44
## 2 100 19.04 12.55
## 3 100 19.04 12.55
## 4 100 19.04 12.55
## 5 100 18.22 12.80
## 6 100 18.30 12.13
## BiologicalMaterial10 BiologicalMaterial11 BiologicalMaterial12
## 1 3.46 138.09 18.83
## 2 3.46 153.67 21.05
## 3 3.46 153.67 21.05
## 4 3.46 153.67 21.05
## 5 3.05 147.61 21.05
## 6 3.78 151.88 20.76
## ManufacturingProcess01 ManufacturingProcess02 ManufacturingProcess03
## 1 NA NA NA
## 2 0.0 0 NA
## 3 0.0 0 NA
## 4 0.0 0 NA
## 5 10.7 0 NA
## 6 12.0 0 NA
## ManufacturingProcess04 ManufacturingProcess05 ManufacturingProcess06
## 1 NA NA NA
## 2 917 1032.2 210.0
## 3 912 1003.6 207.1
## 4 911 1014.6 213.3
## 5 918 1027.5 205.7
## 6 924 1016.8 208.9
## ManufacturingProcess07 ManufacturingProcess08 ManufacturingProcess09
## 1 NA NA 43.00
## 2 177 178 46.57
## 3 178 178 45.07
## 4 177 177 44.92
## 5 178 178 44.96
## 6 178 178 45.32
## ManufacturingProcess10 ManufacturingProcess11 ManufacturingProcess12
## 1 NA NA NA
## 2 NA NA 0
## 3 NA NA 0
## 4 NA NA 0
## 5 NA NA 0
## 6 NA NA 0
## ManufacturingProcess13 ManufacturingProcess14 ManufacturingProcess15
## 1 35.5 4898 6108
## 2 34.0 4869 6095
## 3 34.8 4878 6087
## 4 34.8 4897 6102
## 5 34.6 4992 6233
## 6 34.0 4985 6222
## ManufacturingProcess16 ManufacturingProcess17 ManufacturingProcess18
## 1 4682 35.5 4865
## 2 4617 34.0 4867
## 3 4617 34.8 4877
## 4 4635 34.8 4872
## 5 4733 33.9 4886
## 6 4786 33.4 4862
## ManufacturingProcess19 ManufacturingProcess20 ManufacturingProcess21
## 1 6049 4665 0.0
## 2 6097 4621 0.0
## 3 6078 4621 0.0
## 4 6073 4611 0.0
## 5 6102 4659 -0.7
## 6 6115 4696 -0.6
## ManufacturingProcess22 ManufacturingProcess23 ManufacturingProcess24
## 1 NA NA NA
## 2 3 0 3
## 3 4 1 4
## 4 5 2 5
## 5 8 4 18
## 6 9 1 1
## ManufacturingProcess25 ManufacturingProcess26 ManufacturingProcess27
## 1 4873 6074 4685
## 2 4869 6107 4630
## 3 4897 6116 4637
## 4 4892 6111 4630
## 5 4930 6151 4684
## 6 4871 6128 4687
## ManufacturingProcess28 ManufacturingProcess29 ManufacturingProcess30
## 1 10.7 21.0 9.9
## 2 11.2 21.4 9.9
## 3 11.1 21.3 9.4
## 4 11.1 21.3 9.4
## 5 11.3 21.6 9.0
## 6 11.4 21.7 10.1
## ManufacturingProcess31 ManufacturingProcess32 ManufacturingProcess33
## 1 69.1 156 66
## 2 68.7 169 66
## 3 69.3 173 66
## 4 69.3 171 68
## 5 69.4 171 70
## 6 68.2 173 70
## ManufacturingProcess34 ManufacturingProcess35 ManufacturingProcess36
## 1 2.4 486 0.019
## 2 2.6 508 0.019
## 3 2.6 509 0.018
## 4 2.5 496 0.018
## 5 2.5 468 0.017
## 6 2.5 490 0.018
## ManufacturingProcess37 ManufacturingProcess38 ManufacturingProcess39
## 1 0.5 3 7.2
## 2 2.0 2 7.2
## 3 0.7 2 7.2
## 4 1.2 2 7.2
## 5 0.2 2 7.3
## 6 0.4 2 7.2
## ManufacturingProcess40 ManufacturingProcess41 ManufacturingProcess42
## 1 NA NA 11.6
## 2 0.1 0.15 11.1
## 3 0.0 0.00 12.0
## 4 0.0 0.00 10.6
## 5 0.0 0.00 11.0
## 6 0.0 0.00 11.5
## ManufacturingProcess43 ManufacturingProcess44 ManufacturingProcess45
## 1 3.0 1.8 2.4
## 2 0.9 1.9 2.2
## 3 1.0 1.8 2.3
## 4 1.1 1.8 2.1
## 5 1.1 1.7 2.1
## 6 2.2 1.8 2.0
# Hold out roughly 20% of the rows as a test set, partitioning on Yield.
set.seed(100)
train_idx <- createDataPartition(
  ChemicalManufacturingProcess$Yield,
  p = 0.8,
  list = FALSE
)
train_data <- ChemicalManufacturingProcess[train_idx, ]
test_data <- ChemicalManufacturingProcess[-train_idx, ]
set.seed(100)
# Resampling scheme passed to every caret::train() call that uses `control`:
# 10-fold cross-validation repeated 3 times.
control <- trainControl(method = "repeatedcv", number = 10, repeats = 3)

# Impute missing predictor values with KNN imputation, estimated on the
# training set only and then applied to both sets.
#
# caret's "knnImpute" also centers and scales every numeric column it is
# given. The outcome (Yield) is therefore kept out of preProcess(): including
# it would (a) standardize the response, so RMSE/MAE would no longer be in
# Yield units, and (b) let the test-set Yield influence the neighbours used to
# impute test-set predictors. Yield is re-attached untouched as the first
# column so that `Yield ~ .` formulas keep working.
# NOTE: any performance figures recorded with the outcome standardized are on
# a different scale than the ones this version produces.
predictor_names <- setdiff(names(train_data), "Yield")
preproc <- preProcess(
  train_data[, predictor_names, drop = FALSE],
  method = "knnImpute"
)
train_data_imputed <- cbind(
  Yield = train_data$Yield,
  predict(preproc, newdata = train_data[, predictor_names, drop = FALSE])
)
test_data_imputed <- cbind(
  Yield = test_data$Yield,
  predict(preproc, newdata = test_data[, predictor_names, drop = FALSE])
)
# Fit four tree-based regressors on the imputed training data, all with the
# same trainControl object. The calls are kept in this order because each one
# consumes the random number stream.
# NOTE(review): the seed is not reset between train() calls, so each model is
# tuned on its own set of resamples.

# Single regression tree (CART)
cart_model <- train(
  Yield ~ .,
  data = train_data_imputed,
  method = "rpart",
  trControl = control
)
## Warning in nominalTrainWorkflow(x = x, y = y, wts = weights, info = trainInfo,
## : There were missing values in resampled performance measures.

# Random forest, keeping importance measures
rf_model <- train(
  Yield ~ .,
  data = train_data_imputed,
  method = "rf",
  trControl = control,
  importance = TRUE
)

# Gradient boosting machine (progress output silenced)
gbm_model <- train(
  Yield ~ .,
  data = train_data_imputed,
  method = "gbm",
  trControl = control,
  verbose = FALSE
)

# Cubist
cubist_model <- train(
  Yield ~ .,
  data = train_data_imputed,
  method = "cubist",
  trControl = control
)
# Predict the held-out set with each tuned model.
cart_pred <- predict(cart_model, newdata = test_data_imputed)
rf_pred <- predict(rf_model, newdata = test_data_imputed)
gbm_pred <- predict(gbm_model, newdata = test_data_imputed)
cubist_pred <- predict(cubist_model, newdata = test_data_imputed)

# Test-set RMSE, R-squared and MAE for each model, in the same order.
postResample(pred = cart_pred, obs = test_data_imputed$Yield)
## RMSE Rsquared MAE
## 0.6541459 0.4424792 0.5289075
postResample(pred = rf_pred, obs = test_data_imputed$Yield)
## RMSE Rsquared MAE
## 0.5295761 0.6545200 0.4503856
postResample(pred = gbm_pred, obs = test_data_imputed$Yield)
## RMSE Rsquared MAE
## 0.4095376 0.7822934 0.3380655
postResample(pred = cubist_pred, obs = test_data_imputed$Yield)
## RMSE Rsquared MAE
## 0.3747895 0.8202070 0.3007707
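Based on the test-set results, Cubist has the lowest RMSE and MAE (and the highest R-squared) of the four models, so it gives the best predictive performance.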
# Print caret's variable importance for the tuned Cubist model.
varImp(cubist_model)
## cubist variable importance
##
## only 20 most important variables shown (out of 57)
##
## Overall
## ManufacturingProcess32 100.000
## ManufacturingProcess17 60.825
## ManufacturingProcess13 50.515
## BiologicalMaterial06 46.392
## ManufacturingProcess09 38.144
## BiologicalMaterial02 31.959
## BiologicalMaterial03 22.680
## BiologicalMaterial12 22.680
## ManufacturingProcess39 20.619
## ManufacturingProcess29 13.402
## BiologicalMaterial09 13.402
## BiologicalMaterial04 11.340
## ManufacturingProcess14 10.309
## ManufacturingProcess25 9.278
## ManufacturingProcess28 9.278
## ManufacturingProcess02 9.278
## ManufacturingProcess26 9.278
## ManufacturingProcess27 8.247
## ManufacturingProcess07 8.247
## ManufacturingProcess21 8.247
# Ten most important predictors (by the Overall score) for the random forest
# and Cubist models.
rf_importance <- varImp(rf_model)$importance
top10_rf <- head(arrange(rf_importance, desc(Overall)), 10)

cubist_importance <- varImp(cubist_model)$importance
top10_cubist <- head(arrange(cubist_importance, desc(Overall)), 10)

# Draw the final CART tree selected by caret.
rpart.plot(
  cart_model$finalModel,
  type = 3,
  extra = 101,
  fallen.leaves = TRUE
)
***