STEP 1. Follow Along

5.5.1 Decision trees

library(mlbench)
## Warning: package 'mlbench' was built under R version 4.5.1
data(Sonar)
head(Sonar)
##       V1     V2     V3     V4     V5     V6     V7     V8     V9    V10    V11
## 1 0.0200 0.0371 0.0428 0.0207 0.0954 0.0986 0.1539 0.1601 0.3109 0.2111 0.1609
## 2 0.0453 0.0523 0.0843 0.0689 0.1183 0.2583 0.2156 0.3481 0.3337 0.2872 0.4918
## 3 0.0262 0.0582 0.1099 0.1083 0.0974 0.2280 0.2431 0.3771 0.5598 0.6194 0.6333
## 4 0.0100 0.0171 0.0623 0.0205 0.0205 0.0368 0.1098 0.1276 0.0598 0.1264 0.0881
## 5 0.0762 0.0666 0.0481 0.0394 0.0590 0.0649 0.1209 0.2467 0.3564 0.4459 0.4152
## 6 0.0286 0.0453 0.0277 0.0174 0.0384 0.0990 0.1201 0.1833 0.2105 0.3039 0.2988
##      V12    V13    V14    V15    V16    V17    V18    V19    V20    V21    V22
## 1 0.1582 0.2238 0.0645 0.0660 0.2273 0.3100 0.2999 0.5078 0.4797 0.5783 0.5071
## 2 0.6552 0.6919 0.7797 0.7464 0.9444 1.0000 0.8874 0.8024 0.7818 0.5212 0.4052
## 3 0.7060 0.5544 0.5320 0.6479 0.6931 0.6759 0.7551 0.8929 0.8619 0.7974 0.6737
## 4 0.1992 0.0184 0.2261 0.1729 0.2131 0.0693 0.2281 0.4060 0.3973 0.2741 0.3690
## 5 0.3952 0.4256 0.4135 0.4528 0.5326 0.7306 0.6193 0.2032 0.4636 0.4148 0.4292
## 6 0.4250 0.6343 0.8198 1.0000 0.9988 0.9508 0.9025 0.7234 0.5122 0.2074 0.3985
##      V23    V24    V25    V26    V27    V28    V29    V30    V31    V32    V33
## 1 0.4328 0.5550 0.6711 0.6415 0.7104 0.8080 0.6791 0.3857 0.1307 0.2604 0.5121
## 2 0.3957 0.3914 0.3250 0.3200 0.3271 0.2767 0.4423 0.2028 0.3788 0.2947 0.1984
## 3 0.4293 0.3648 0.5331 0.2413 0.5070 0.8533 0.6036 0.8514 0.8512 0.5045 0.1862
## 4 0.5556 0.4846 0.3140 0.5334 0.5256 0.2520 0.2090 0.3559 0.6260 0.7340 0.6120
## 5 0.5730 0.5399 0.3161 0.2285 0.6995 1.0000 0.7262 0.4724 0.5103 0.5459 0.2881
## 6 0.5890 0.2872 0.2043 0.5782 0.5389 0.3750 0.3411 0.5067 0.5580 0.4778 0.3299
##      V34    V35    V36    V37    V38    V39    V40    V41    V42    V43    V44
## 1 0.7547 0.8537 0.8507 0.6692 0.6097 0.4943 0.2744 0.0510 0.2834 0.2825 0.4256
## 2 0.2341 0.1306 0.4182 0.3835 0.1057 0.1840 0.1970 0.1674 0.0583 0.1401 0.1628
## 3 0.2709 0.4232 0.3043 0.6116 0.6756 0.5375 0.4719 0.4647 0.2587 0.2129 0.2222
## 4 0.3497 0.3953 0.3012 0.5408 0.8814 0.9857 0.9167 0.6121 0.5006 0.3210 0.3202
## 5 0.0981 0.1951 0.4181 0.4604 0.3217 0.2828 0.2430 0.1979 0.2444 0.1847 0.0841
## 6 0.2198 0.1407 0.2856 0.3807 0.4158 0.4054 0.3296 0.2707 0.2650 0.0723 0.1238
##      V45    V46    V47    V48    V49    V50    V51    V52    V53    V54    V55
## 1 0.2641 0.1386 0.1051 0.1343 0.0383 0.0324 0.0232 0.0027 0.0065 0.0159 0.0072
## 2 0.0621 0.0203 0.0530 0.0742 0.0409 0.0061 0.0125 0.0084 0.0089 0.0048 0.0094
## 3 0.2111 0.0176 0.1348 0.0744 0.0130 0.0106 0.0033 0.0232 0.0166 0.0095 0.0180
## 4 0.4295 0.3654 0.2655 0.1576 0.0681 0.0294 0.0241 0.0121 0.0036 0.0150 0.0085
## 5 0.0692 0.0528 0.0357 0.0085 0.0230 0.0046 0.0156 0.0031 0.0054 0.0105 0.0110
## 6 0.1192 0.1089 0.0623 0.0494 0.0264 0.0081 0.0104 0.0045 0.0014 0.0038 0.0013
##      V56    V57    V58    V59    V60 Class
## 1 0.0167 0.0180 0.0084 0.0090 0.0032     R
## 2 0.0191 0.0140 0.0049 0.0052 0.0044     R
## 3 0.0244 0.0316 0.0164 0.0095 0.0078     R
## 4 0.0073 0.0050 0.0044 0.0040 0.0117     R
## 5 0.0015 0.0072 0.0048 0.0107 0.0094     R
## 6 0.0089 0.0057 0.0027 0.0051 0.0062     R
library("rpart")
m <- rpart(Class ~ ., data = Sonar, method = "class")

library("rpart.plot")
## Warning: package 'rpart.plot' was built under R version 4.5.1
rpart.plot(m)
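Beyond the plot, the fitted tree can also be inspected numerically. A minimal sketch using rpart's own printcp() to show the complexity-parameter table (number of splits against cross-validated error):

# inspect the complexity parameter table of the fitted tree
printcp(m)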

p <- predict(m, Sonar, type = "class")
table(p, Sonar$Class)
##    
## p    M  R
##   M 95 10
##   R 16 87
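The raw table can be summarised further. A sketch using caret's confusionMatrix() (caret is loaded later in this section) to get accuracy, kappa, and sensitivity from the same predictions; note these are resubstitution (training-set) predictions, so the numbers are optimistic:

library(caret)
# accuracy, kappa, sensitivity etc. from the same predictions
confusionMatrix(p, Sonar$Class)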

Training a random forest

library(caret)
## Loading required package: ggplot2
## Loading required package: lattice
set.seed(12)
model <- train(Class ~ .,
               data = Sonar,
               method = "ranger")
print(model)
## Random Forest 
## 
## 208 samples
##  60 predictor
##   2 classes: 'M', 'R' 
## 
## No pre-processing
## Resampling: Bootstrapped (25 reps) 
## Summary of sample sizes: 208, 208, 208, 208, 208, 208, ... 
## Resampling results across tuning parameters:
## 
##   mtry  splitrule   Accuracy   Kappa    
##    2    gini        0.8090731  0.6131571
##    2    extratrees  0.8136902  0.6234492
##   31    gini        0.7736954  0.5423516
##   31    extratrees  0.8285153  0.6521921
##   60    gini        0.7597299  0.5140905
##   60    extratrees  0.8157646  0.6255929
## 
## Tuning parameter 'min.node.size' was held constant at a value of 1
## Accuracy was used to select the optimal model using the largest value.
## The final values used for the model were mtry = 31, splitrule = extratrees
##  and min.node.size = 1.
plot(model)

# let caret choose 5 candidate values per tuning parameter automatically
model <- train(Class ~ ., data = Sonar, method = "ranger", tuneLength = 5)

# alternatively, supply an explicit tuning grid
set.seed(42)
myGrid <- expand.grid(mtry = c(5, 10, 20, 40, 60),
                      splitrule = c("gini", "extratrees"),
                      min.node.size = 1)
model <- train(Class ~ .,
               data = Sonar,
               method = "ranger",
               tuneGrid = myGrid,
               trControl = trainControl(method = "cv",
                                        number = 5,
                                        verboseIter = FALSE))
print(model)
## Random Forest 
## 
## 208 samples
##  60 predictor
##   2 classes: 'M', 'R' 
## 
## No pre-processing
## Resampling: Cross-Validated (5 fold) 
## Summary of sample sizes: 166, 167, 167, 167, 165 
## Resampling results across tuning parameters:
## 
##   mtry  splitrule   Accuracy   Kappa    
##    5    gini        0.8076277  0.6098253
##    5    extratrees  0.8416579  0.6784745
##   10    gini        0.7927667  0.5799348
##   10    extratrees  0.8418848  0.6791453
##   20    gini        0.7882316  0.5718852
##   20    extratrees  0.8516355  0.6991879
##   40    gini        0.7880048  0.5716461
##   40    extratrees  0.8371229  0.6695638
##   60    gini        0.7833482  0.5613525
##   60    extratrees  0.8322448  0.6599318
## 
## Tuning parameter 'min.node.size' was held constant at a value of 1
## Accuracy was used to select the optimal model using the largest value.
## The final values used for the model were mtry = 20, splitrule = extratrees
##  and min.node.size = 1.
plot(model)
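The chosen hyperparameters can also be read off the train object directly, without scanning the printed summary; a small sketch:

# best hyperparameter combination found by the grid search
model$bestTune
# full resampling results behind the plot
model$results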

Challenge

set.seed(42)
model <- train(Class ~ .,
               data = Sonar,
               method = "ranger",
               tuneLength = 5,
               trControl = trainControl(method = "cv",
                                        number = 5,
                                        verboseIter = FALSE))
plot(model)
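caret can also summarise the cross-validated confusion matrix for a train object; a sketch (cell values are percentages averaged over the resamples):

# average cross-validated confusion matrix (in percent)
confusionMatrix(model)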

STEP 2.

- Randomly split your data into a training set (80%) and a test set (20%)

# check the data
chocolate <- read.csv("chocolate_tibble.csv")
head(chocolate)
##   final_grade review_date cocoa_percent company_location bean_type
## 1        3.75        2016          0.63           France          
## 2        2.75        2015          0.70           France          
## 3        3.00        2015          0.70           France          
## 4        3.50        2015          0.70           France          
## 5        3.50        2015          0.70           France          
## 6        2.75        2014          0.70           France   Criollo
##   broad_bean_origin
## 1          Sao Tome
## 2              Togo
## 3              Togo
## 4              Togo
## 5              Peru
## 6         Venezuela
library(caret)

set.seed(1234)
trainIndex <- createDataPartition(chocolate$final_grade, p = 0.8, list = FALSE)

chocolate_train <- chocolate[trainIndex, ]
chocolate_test <- chocolate[-trainIndex, ]
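A quick check that the partition behaves as intended; nrow() counts are enough:

# sanity check: roughly an 80/20 split
nrow(chocolate_train)
nrow(chocolate_test)
nrow(chocolate_train) / nrow(chocolate)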

Decision tree

library(parsnip)
## Warning: package 'parsnip' was built under R version 4.5.1
spec <- decision_tree() %>%
  set_mode("regression") %>%
  set_engine("rpart")
print(spec)
## Decision Tree Model Specification (regression)
## 
## Computational engine: rpart
model <- spec %>%
  parsnip::fit(formula = final_grade ~ ., data = chocolate_train)
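To see how the parsnip tree generalises, a sketch evaluating it on the held-out test set (predict() on a parsnip fit returns a tibble with a .pred column):

# test-set predictions and RMSE for the parsnip decision tree
preds <- predict(model, new_data = chocolate_test)
sqrt(mean((chocolate_test$final_grade - preds$.pred)^2))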

library(rpart)
model2 <- rpart(final_grade ~ cocoa_percent + company_location,
                data = chocolate_train, method = "anova")

library(party)
## Warning: package 'party' was built under R version 4.5.1
## Loading required package: grid
## Loading required package: mvtnorm
## Loading required package: modeltools
## Loading required package: stats4
## 
## Attaching package: 'modeltools'
## The following object is masked from 'package:parsnip':
## 
##     fit
## Loading required package: strucchange
## Warning: package 'strucchange' was built under R version 4.5.1
## Loading required package: zoo
## 
## Attaching package: 'zoo'
## The following objects are masked from 'package:base':
## 
##     as.Date, as.Date.numeric
## Loading required package: sandwich
## Warning: package 'sandwich' was built under R version 4.5.1
model3 <- ctree(
  final_grade ~ cocoa_percent + as.factor(company_location),
  data = chocolate_train
)
library(rpart.plot)
rpart.plot(model2, box.palette = "RdBu", shadow.col = "grey", nn = TRUE)

plot(model3)
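The trees can also be compared on held-out data; a sketch for the rpart model (predict.rpart with method = "anova" returns a numeric vector), with the ctree model checkable the same way:

# test-set RMSE for the rpart regression tree
pred2 <- predict(model2, newdata = chocolate_test)
sqrt(mean((chocolate_test$final_grade - pred2)^2))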

Hyperparameters

decision_tree(tree_depth = 1) %>%
  set_mode("regression") %>%
  set_engine("rpart") %>%
  parsnip::fit(formula = final_grade ~ ., data = chocolate_train)
## parsnip model object
## 
## n= 1438 
## 
## node), split, n, deviance, yval
##       * denotes terminal node
## 
## 1) root 1438 323.17940 3.187065  
##   2) cocoa_percent>=0.885 28  10.66741 2.401786 *
##   3) cocoa_percent< 0.885 1410 294.90250 3.202660 *
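With tree_depth = 1 the tree is a single split (a "stump"). A sketch contrasting it with a deeper tree; the depth value 5 here is an arbitrary illustration, not a tuned choice:

# the same specification with a deeper tree, for comparison
decision_tree(tree_depth = 5) %>%
  set_mode("regression") %>%
  set_engine("rpart") %>%
  parsnip::fit(formula = final_grade ~ ., data = chocolate_train)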

Random Forest

library(MASS)
# note: data(package = "MASS") only lists available datasets; load Boston explicitly
data("Boston", package = "MASS")
boston <- Boston
dim(boston)
## [1] 506  14
names(boston)
##  [1] "crim"    "zn"      "indus"   "chas"    "nox"     "rm"      "age"    
##  [8] "dis"     "rad"     "tax"     "ptratio" "black"   "lstat"   "medv"
# training sample with 300 observations
library(randomForest)
## randomForest 4.7-1.2
## Type rfNews() to see new features/changes/bug fixes.
## 
## Attaching package: 'randomForest'
## The following object is masked from 'package:ggplot2':
## 
##     margin
# note: without set.seed(), the sample (and the results below) will vary between runs
train <- sample(1:nrow(Boston), 300)
Boston.rf <- randomForest(medv ~ ., data = Boston, subset = train)
plot(Boston.rf)
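Since the forest was fit on the 300-row training subset, the remaining rows serve as a test set; a sketch computing held-out MSE:

# evaluate on the observations not used for training
yhat <- predict(Boston.rf, newdata = Boston[-train, ])
mean((Boston[-train, "medv"] - yhat)^2)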

# evaluating variable importance
importance(Boston.rf)
##         IncNodePurity
## crim        1835.3270
## zn           289.1031
## indus       1687.7424
## chas         188.9720
## nox         1297.0798
## rm          8042.0495
## age          628.4363
## dis         1019.2070
## rad          210.7702
## tax          875.1236
## ptratio     1734.5538
## black        471.1634
## lstat       6523.7743
varImpPlot(Boston.rf)
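Variable importance aside, the forest's main tuning knob is mtry. A sketch trying each possible value of mtry and recording test-set MSE (the ntree value of 200 is an arbitrary choice to keep the loop fast):

# test-set MSE as a function of mtry (13 predictors in total)
test.err <- numeric(13)
for (mt in 1:13) {
  fit <- randomForest(medv ~ ., data = Boston, subset = train,
                      mtry = mt, ntree = 200)
  pred <- predict(fit, newdata = Boston[-train, ])
  test.err[mt] <- mean((Boston[-train, "medv"] - pred)^2)
}
plot(1:13, test.err, type = "b", xlab = "mtry", ylab = "test MSE")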