Machine learning is a branch of artificial intelligence (AI) focused on building applications that learn from data and improve their accuracy over time without being explicitly programmed to do so.
In data science, an algorithm is a sequence of statistical processing steps. In machine learning, algorithms are ‘trained’ to find patterns and features in massive amounts of data in order to make decisions and predictions based on new data. The better the algorithm, the more accurate the decisions and predictions will become as it processes more data.
There are many Machine Learning algorithms used for different types of problems.
The big question is: How do we know which is the best algorithm for our problem?
For this project, I decided to analyze Random Forest and CatBoost.
The idea is to find out which of these algorithms works better on binary classification problems.
To find out, I applied the Machine Learning workflow to three datasets with different types of variables, comparing the resulting models to analyze how each algorithm behaves.
Random Forest and CatBoost are Machine Learning algorithms used for classification and regression problems.
Random Forest uses the bagging technique. It is an ensemble method that consists of building many small decision trees, each trained on a different bootstrap sample of the original dataset. Each decision tree makes its own prediction, and the predictions are combined, by majority vote, into a much more accurate one.
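To make the idea concrete, below is a minimal sketch of bagging in R. The data frame df and its factor target CLASS are illustrative names, and a real Random Forest additionally samples a random subset of variables at each split.

library(rpart)

# Grow 25 trees, each on its own bootstrap sample of the data
bagged_trees <- lapply(1:25, function(i) {
  boot <- df[sample(nrow(df), replace = TRUE), ]
  rpart(CLASS ~ ., data = boot)
})

# Combine the individual predictions by majority vote
votes <- sapply(bagged_trees, function(tree) as.character(predict(tree, df, type = "class")))
majority_vote <- apply(votes, 1, function(v) names(which.max(table(v))))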
CatBoost uses the boosting technique. It is also an ensemble method, but it builds the decision trees one after another, where the results of one tree are used to improve the next one, and so on.
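In the same spirit, here is a minimal sketch of the boosting idea for regression, assuming a data frame df with a numeric target y (again illustrative names; CatBoost itself uses a more sophisticated, ordered boosting scheme).

library(rpart)

pred <- rep(mean(df$y), nrow(df))            # start from the mean prediction
for (i in 1:50) {
  df$residual <- df$y - pred                 # errors left by the ensemble so far
  fit <- rpart(residual ~ . - y, data = df)  # next tree fits those errors
  pred <- pred + 0.1 * predict(fit, df)      # add it with a small learning rate
}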
Advantages of Random Forest over CatBoost:
- Its trees are built independently, so training is easy to parallelize.
- It has fewer hyperparameters, which makes it simpler to configure.
- Bagging tends to be more robust to noisy data, while boosting can overfit by chasing errors.
Advantages of CatBoost over Random Forest:
- It handles categorical variables natively, without manual encoding.
- Since each tree corrects the errors of the previous ones, it often reaches higher accuracy with fewer, shallower trees.
- Prediction with the final model is typically very fast.
In this section, I describe the phases of the Machine Learning workflow.
The first step was to choose three different datasets: one with categorical variables, one with numerical variables, and a third with both numerical and categorical variables, referred to as “mix”.
# read_excel() comes from the readxl package
library(readxl)

# Numerical dataset
dataset_num <- read_excel("rice.xlsx")
# Categorical dataset
dataset_cat <- read.csv("mushrooms.csv")
# Mix dataset
dataset_mix <- read_excel("bank.xlsx")
In this phase I didn't make big changes, since the objective of the project is to analyze how the algorithms work on the different datasets, not to obtain the best predictions. The one exception was the categorical dataset: to add more complexity to it, I removed twelve variables, including several of the most informative ones.
dataset_cat <- dataset_cat %>%
  select(-VEIL.TYPE, -STALK.ROOT, -ODOR, -SPORE.PRINT.COLOR,
         -GILL.COLOR, -GILL.SIZE, -HABITAT, -POPULATION,
         -STALK.SURFACE.ABOVE.RING, -CAP.COLOR, -RING.TYPE,
         -STALK.SURFACE.BELOW.RING)
In addition, the attributes of type “character” were converted to “factor” so that they could be used by CatBoost.
dataset_num$CLASS <- as.factor(dataset_num$CLASS)
dataset_cat <- mutate_if(dataset_cat, is.character, as.factor)
dataset_mix <- mutate_if(dataset_mix, is.character, as.factor)
To train the models on the different datasets I defined two functions, one for each algorithm. Both use the functions of the caret package, so the training procedure is identical for the two algorithms. For the same reason, I didn't tune the hyperparameters beyond caret's defaults. I applied five-fold cross-validation repeated two times and saved the results for later analysis.
Before training, I split each dataset into two parts, leaving 80% for training and the other 20% for testing. This was done to keep unseen data for evaluating the final models.
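A minimal sketch of the split for the numerical dataset; the seed is a hypothetical choice, but the object names match the dim() calls shown further down.

library(caret)

set.seed(1)  # hypothetical seed, for reproducibility
idx <- createDataPartition(dataset_num$CLASS, p = 0.8, list = FALSE)
data_train_num <- dataset_num[idx, ]   # 80% for training
data_test_num  <- dataset_num[-idx, ]  # 20% held out for testing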
Once the models were trained, I compared them with the metrics “Accuracy” and “Kappa”. These are the default metrics used to evaluate algorithms on binary and multi-class classification datasets in caret. Accuracy is the percentage of correctly classified instances out of all instances, and Kappa, or Cohen's Kappa, is like classification accuracy, except that it is normalized against the baseline of random chance on the dataset.
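For reference, Cohen's Kappa is computed as

Kappa = (p_o - p_e) / (1 - p_e)

where p_o is the observed accuracy and p_e is the accuracy expected by chance given the class frequencies; 0 means no better than chance and 1 means perfect agreement.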
Then, I applied a statistical test on the resampling results, which returns a matrix with two kinds of values. The value above the diagonal is the difference between the mean metric of the two models, and the value below the diagonal is the p-value, which is a probability and therefore lies between 0 and 1. The p-value tells us the probability of having obtained the observed result assuming that the null hypothesis H0 is true. In this case, the hypothesis H0 is that there is no difference (difference = 0) between the models. High p-values do not allow us to reject H0, while low p-values do.
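A sketch of the comparison code for the numerical dataset; the object names are not shown in the original chunk, but they match the resamples() and diff() calls visible in the output below.

resamps_num <- resamples(list(cb_num = catboost_model_num, rf_num = rf_model_num))
summary(resamps_num)

difValues_num <- diff(resamps_num)  # paired differences with Bonferroni-adjusted p-values
summary(difValues_num)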
#CATBOOST
train_cb_model <- function(data_train){
  # Five-fold cross-validation repeated twice; keep the final resamples
  fitControl <- trainControl(method = "repeatedcv",
                             repeats = 2,
                             number = 5,
                             returnResamp = 'final',
                             savePredictions = 'final',
                             verboseIter = TRUE,
                             allowParallel = TRUE)
  # catboost.caret is the caret interface shipped with the catboost package
  catboost_model <- train(x = data_train[, !(names(data_train) %in% c("CLASS"))],
                          y = data_train$CLASS,
                          method = catboost.caret,
                          trControl = fitControl)
  return(catboost_model)
}
#RANDOM FOREST
train_rf_model <- function(data_train){
  # Same resampling scheme as for CatBoost
  fitControl <- trainControl(method = "repeatedcv",
                             repeats = 2,
                             number = 5,
                             returnResamp = 'final',
                             savePredictions = 'final',
                             verboseIter = TRUE,
                             allowParallel = TRUE)
  train_formula <- formula(CLASS ~ .)
  rf_model <- train(train_formula,
                    data = data_train,
                    method = "rf",
                    trControl = fitControl)
  return(rf_model)
}
dim(data_train_num)
## [1] 3048 8
dim(data_test_num)
## [1] 762 8
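The two functions were then applied to each training set. A hypothetical invocation for the numerical dataset is shown below; these object names do not appear in the original code, but they match the resamples() calls later on, and the verbose training output that follows comes from these calls. Note that allowParallel = TRUE only takes effect if a parallel backend (for example, via the doParallel package) has been registered beforehand.

catboost_model_num <- train_cb_model(data_train_num)
rf_model_num <- train_rf_model(data_train_num)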
## Aggregating results
## Selecting tuning parameters
## Fitting depth = 2, learning_rate = 0.0498, iterations = 100, l2_leaf_reg = 1e-06, rsm = 0.9, border_count = 255 on full training set
## 0: learn: 0.6569683 total: 157ms remaining: 15.6s
## ... (iterations 1 to 98 omitted) ...
## 99: learn: 0.1917510 total: 309ms remaining: 0us
## Catboost
##
## 3048 samples
## 7 predictor
## 2 classes: 'Cammeo', 'Osmancik'
##
## No pre-processing
## Resampling: Cross-Validated (5 fold, repeated 2 times)
## Summary of sample sizes: 2438, 2439, 2438, 2438, 2439, 2438, ...
## Resampling results across tuning parameters:
##
## depth learning_rate Accuracy Kappa
## 2 0.04978707 0.9271682 0.8509679
## 2 0.13533528 0.9256919 0.8479776
## 2 0.36787944 0.9206078 0.8376340
## 2 1.00000000 0.9081364 0.8121578
## 4 0.04978707 0.9263490 0.8492475
## 4 0.13533528 0.9255277 0.8475072
## 4 0.36787944 0.9087951 0.8134313
## 4 1.00000000 0.9005906 0.7967269
## 6 0.04978707 0.9261848 0.8489097
## 6 0.13533528 0.9229016 0.8422417
## 6 0.36787944 0.9081402 0.8122341
## 6 1.00000000 0.8925605 0.7804570
##
## Tuning parameter 'iterations' was held constant at a value of 100
##
## Tuning parameter 'rsm' was held constant at a value of 0.9
## Tuning parameter 'border_count' was held constant at a value of 255
## Accuracy was used to select the optimal model using the largest value.
## The final values used for the model were depth = 2, learning_rate =
## 0.04978707, iterations = 100, l2_leaf_reg = 1e-06, rsm = 0.9 and
## border_count = 255.
## user system elapsed
## 1.47 0.14 59.97
## Aggregating results
## Selecting tuning parameters
## Fitting mtry = 2 on full training set
## Random Forest
##
## 3048 samples
## 7 predictor
## 2 classes: 'Cammeo', 'Osmancik'
##
## No pre-processing
## Resampling: Cross-Validated (5 fold, repeated 2 times)
## Summary of sample sizes: 2438, 2440, 2438, 2438, 2438, 2439, ...
## Resampling results across tuning parameters:
##
## mtry Accuracy Kappa
## 2 0.9210980 0.8385808
## 4 0.9194573 0.8351830
## 7 0.9206027 0.8375078
##
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was mtry = 2.
## user system elapsed
## 1.96 0.04 21.56
##
## Call:
## resamples.default(x = list(cb_num = catboost_model_num, rf_num = rf_model_num))
##
## Models: cb_num, rf_num
## Number of resamples: 10
## Performance metrics: Accuracy, Kappa
## Time estimates for: everything, final model fit
##
## Call:
## summary.resamples(object = resamps_num)
##
## Models: cb_num, rf_num
## Number of resamples: 10
##
## Accuracy
## Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
## cb_num 0.9081967 0.9249697 0.9261689 0.9271682 0.9336066 0.9376026 0
## rf_num 0.9098361 0.9147541 0.9154363 0.9210980 0.9298332 0.9377049 0
##
## Kappa
## Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
## cb_num 0.8113978 0.8459288 0.8493793 0.8509679 0.8638352 0.8727315 0
## rf_num 0.8155722 0.8250819 0.8278092 0.8385808 0.8559395 0.8716458 0
##
## Call:
## diff.resamples(x = resamps_num)
##
## Models: cb_num, rf_num
## Metrics: Accuracy, Kappa
## Number of differences: 1
## p-value adjustment: bonferroni
##
## Call:
## summary.diff.resamples(object = difValues_num)
##
## p-value adjustment: bonferroni
## Upper diagonal: estimates of the difference
## Lower diagonal: p-value for H0: difference = 0
##
## Accuracy
## cb_num rf_num
## cb_num 0.00607
## rf_num 0.2773
##
## Kappa
## cb_num rf_num
## cb_num 0.01239
## rf_num 0.2761
dim(data_train_cat)
## [1] 6500 11
dim(data_test_cat)
## [1] 1624 11
## Aggregating results
## Selecting tuning parameters
## Fitting depth = 6, learning_rate = 0.368, iterations = 100, l2_leaf_reg = 1e-06, rsm = 0.9, border_count = 255 on full training set
## 0: learn: 0.4994800 total: 7.54ms remaining: 746ms
## ... (iterations 1 to 98 omitted) ...
## 99: learn: 0.0573906 total: 631ms remaining: 0us
## Catboost
##
## 6500 samples
## 10 predictor
## 2 classes: 'e', 'p'
##
## No pre-processing
## Resampling: Cross-Validated (5 fold, repeated 2 times)
## Summary of sample sizes: 5200, 5199, 5201, 5200, 5200, 5200, ...
## Resampling results across tuning parameters:
##
## depth learning_rate Accuracy Kappa
## 2 0.04978707 0.8929995 0.7846000
## 2 0.13533528 0.9309997 0.8617534
## 2 0.36787944 0.9438448 0.8875377
## 2 1.00000000 0.9572300 0.9143279
## 4 0.04978707 0.9406928 0.8810092
## 4 0.13533528 0.9622300 0.9242760
## 4 0.36787944 0.9667692 0.9333714
## 4 1.00000000 0.9175499 0.8355423
## 6 0.04978707 0.9641527 0.9280915
## 6 0.13533528 0.9669232 0.9336809
## 6 0.36787944 0.9676154 0.9350553
## 6 1.00000000 0.9173735 0.8358280
##
## Tuning parameter 'iterations' was held constant at a value of 100
##
## Tuning parameter 'rsm' was held constant at a value of 0.9
## Tuning parameter 'border_count' was held constant at a value of 255
## Accuracy was used to select the optimal model using the largest value.
## The final values used for the model were depth = 6, learning_rate =
## 0.3678794, iterations = 100, l2_leaf_reg = 1e-06, rsm = 0.9 and border_count
## = 255.
## user system elapsed
## 3.75 0.31 63.82
## Aggregating results
## Selecting tuning parameters
## Fitting mtry = 17 on full training set
## Random Forest
##
## 6500 samples
## 10 predictor
## 2 classes: 'e', 'p'
##
## No pre-processing
## Resampling: Cross-Validated (5 fold, repeated 2 times)
## Summary of sample sizes: 5199, 5199, 5201, 5200, 5201, 5199, ...
## Resampling results across tuning parameters:
##
## mtry Accuracy Kappa
## 2 0.8514578 0.7003363
## 17 0.9688446 0.9375317
## 33 0.9687678 0.9373769
##
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was mtry = 17.
## user system elapsed
## 7.08 0.05 102.06
##
## Call:
## resamples.default(x = list(cb_cat = catboost_model_cat, rf_cat = rf_model_cat))
##
## Models: cb_cat, rf_cat
## Number of resamples: 10
## Performance metrics: Accuracy, Kappa
## Time estimates for: everything, final model fit
##
## Call:
## summary.resamples(object = resamps_cat)
##
## Models: cb_cat, rf_cat
## Number of resamples: 10
##
## Accuracy
## Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
## cb_cat 0.9638462 0.9663657 0.9680523 0.9676154 0.9690268 0.9700231 0
## rf_cat 0.9615089 0.9678846 0.9696035 0.9688446 0.9705938 0.9746349 0
##
## Kappa
## Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
## cb_cat 0.9274870 0.9325299 0.9359293 0.9350553 0.9378695 0.9399110 0
## rf_cat 0.9227876 0.9355937 0.9390468 0.9375317 0.9410406 0.9491611 0
##
## Call:
## diff.resamples(x = resamps_cat)
##
## Models: cb_cat, rf_cat
## Metrics: Accuracy, Kappa
## Number of differences: 1
## p-value adjustment: bonferroni
##
## Call:
## summary.diff.resamples(object = difValues_cat)
##
## p-value adjustment: bonferroni
## Upper diagonal: estimates of the difference
## Lower diagonal: p-value for H0: difference = 0
##
## Accuracy
## cb_cat rf_cat
## cb_cat -0.001229
## rf_cat 0.404
##
## Kappa
## cb_cat rf_cat
## cb_cat -0.002476
## rf_cat 0.4026
dim(data_train_mix)
## [1] 3617 17
dim(data_test_mix)
## [1] 904 17
## Aggregating results
## Selecting tuning parameters
## Fitting depth = 4, learning_rate = 0.0498, iterations = 100, l2_leaf_reg = 1e-06, rsm = 0.9, border_count = 255 on full training set
## 0: learn: 0.6619145 total: 6.79ms remaining: 672ms
## ... (iterations 1 to 98 omitted) ...
## 99: learn: 0.2203946 total: 361ms remaining: 0us
## Catboost
##
## 3617 samples
## 16 predictor
## 2 classes: 'no', 'yes'
##
## No pre-processing
## Resampling: Cross-Validated (5 fold, repeated 2 times)
## Summary of sample sizes: 2894, 2893, 2894, 2893, 2894, 2894, ...
## Resampling results across tuning parameters:
##
## depth learning_rate Accuracy Kappa
## 2 0.04978707 0.8950767 0.2871657
## 2 0.13533528 0.8992261 0.3917631
## 2 0.36787944 0.8992244 0.4207772
## 2 1.00000000 0.8927262 0.4115878
## 4 0.04978707 0.8993631 0.3597537
## 4 0.13533528 0.8971508 0.4049038
## 4 0.36787944 0.8957685 0.4228357
## 4 1.00000000 0.8829142 0.3720022
## 6 0.04978707 0.8990872 0.3792960
## 6 0.13533528 0.8979813 0.4131577
## 6 0.36787944 0.8939706 0.4162280
## 6 1.00000000 0.8661843 0.3251594
##
## Tuning parameter 'iterations' was held constant at a value of 100
##
## Tuning parameter 'rsm' was held constant at a value of 0.9
## Tuning parameter 'border_count' was held constant at a value of 255
## Accuracy was used to select the optimal model using the largest value.
## The final values used for the model were depth = 4, learning_rate =
## 0.04978707, iterations = 100, l2_leaf_reg = 1e-06, rsm = 0.9 and
## border_count = 255.
## user system elapsed
## 2.52 0.06 58.14
## Aggregating results
## Selecting tuning parameters
## Fitting mtry = 22 on full training set
## Random Forest
##
## 3617 samples
## 16 predictor
## 2 classes: 'no', 'yes'
##
## No pre-processing
## Resampling: Cross-Validated (5 fold, repeated 2 times)
## Summary of sample sizes: 2894, 2894, 2893, 2893, 2894, 2894, ...
## Resampling results across tuning parameters:
##
## mtry Accuracy Kappa
## 2 0.8860940 0.03645111
## 22 0.8981215 0.41655765
## 42 0.8964627 0.41371675
##
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was mtry = 22.
## user system elapsed
## 7.80 0.06 117.17
##
## Call:
## resamples.default(x = list(cb_mix = catboost_model_mix, rf_mix = rf_model_mix))
##
## Models: cb_mix, rf_mix
## Number of resamples: 10
## Performance metrics: Accuracy, Kappa
## Time estimates for: everything, final model fit
##
## Call:
## summary.resamples(object = resamps_mix)
##
## Models: cb_mix, rf_mix
## Number of resamples: 10
##
## Accuracy
## Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
## cb_mix 0.8865837 0.8977901 0.9024896 0.8993631 0.9031812 0.9060773 0
## rf_mix 0.8879668 0.8907331 0.8964088 0.8981215 0.9025240 0.9170124 0
##
## Kappa
## Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
## cb_mix 0.2746293 0.3545218 0.3638693 0.3597537 0.3835842 0.4072433 0
## rf_mix 0.3415528 0.3876035 0.4102525 0.4165577 0.4480168 0.5272347 0
##
## Call:
## diff.resamples(x = resamps_mix)
##
## Models: cb_mix, rf_mix
## Metrics: Accuracy, Kappa
## Number of differences: 1
## p-value adjustment: bonferroni
##
## Call:
## summary.diff.resamples(object = difValues_mix)
##
## p-value adjustment: bonferroni
## Upper diagonal: estimates of the difference
## Lower diagonal: p-value for H0: difference = 0
##
## Accuracy
## cb_mix rf_mix
## cb_mix 0.001242
## rf_mix 0.7074
##
## Kappa
## cb_mix rf_mix
## cb_mix -0.0568
## rf_mix 0.01279
The last step is to test the models. Here, I used the two models to predict the classes of the unseen data that I'd saved for testing. Then, I built a confusion matrix for each model to examine the results.
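A minimal sketch of this step for the numerical dataset, assuming the models and the held-out set defined above; confusionMatrix() comes from caret.

pred_cb_num <- predict(catboost_model_num, newdata = data_test_num)
confusionMatrix(pred_cb_num, data_test_num$CLASS)

pred_rf_num <- predict(rf_model_num, newdata = data_test_num)
confusionMatrix(pred_rf_num, data_test_num$CLASS)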
## Confusion Matrix and Statistics
##
## Reference
## Prediction Cammeo Osmancik
## Cammeo 301 27
## Osmancik 25 409
##
## Accuracy : 0.9318
## 95% CI : (0.9115, 0.9486)
## No Information Rate : 0.5722
## P-Value [Acc > NIR] : <2e-16
##
## Kappa : 0.8607
##
## Mcnemar's Test P-Value : 0.8897
##
## Sensitivity : 0.9233
## Specificity : 0.9381
## Pos Pred Value : 0.9177
## Neg Pred Value : 0.9424
## Prevalence : 0.4278
## Detection Rate : 0.3950
## Detection Prevalence : 0.4304
## Balanced Accuracy : 0.9307
##
## 'Positive' Class : Cammeo
##
## Confusion Matrix and Statistics
##
## Reference
## Prediction e p
## e 824 39
## p 17 744
##
## Accuracy : 0.9655
## 95% CI : (0.9555, 0.9738)
## No Information Rate : 0.5179
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.9309
##
## Mcnemar's Test P-Value : 0.005012
##
## Sensitivity : 0.9798
## Specificity : 0.9502
## Pos Pred Value : 0.9548
## Neg Pred Value : 0.9777
## Prevalence : 0.5179
## Detection Rate : 0.5074
## Detection Prevalence : 0.5314
## Balanced Accuracy : 0.9650
##
## 'Positive' Class : e
##
## Confusion Matrix and Statistics
##
## Reference
## Prediction no yes
## no 789 83
## yes 11 21
##
## Accuracy : 0.896
## 95% CI : (0.8743, 0.9152)
## No Information Rate : 0.885
## P-Value [Acc > NIR] : 0.161
##
## Kappa : 0.2693
##
## Mcnemar's Test P-Value : 2.423e-13
##
## Sensitivity : 0.9862
## Specificity : 0.2019
## Pos Pred Value : 0.9048
## Neg Pred Value : 0.6562
## Prevalence : 0.8850
## Detection Rate : 0.8728
## Detection Prevalence : 0.9646
## Balanced Accuracy : 0.5941
##
## 'Positive' Class : no
##
## Confusion Matrix and Statistics
##
## Reference
## Prediction Cammeo Osmancik
## Cammeo 298 27
## Osmancik 28 409
##
## Accuracy : 0.9278
## 95% CI : (0.9071, 0.9452)
## No Information Rate : 0.5722
## P-Value [Acc > NIR] : <2e-16
##
## Kappa : 0.8525
##
## Mcnemar's Test P-Value : 1
##
## Sensitivity : 0.9141
## Specificity : 0.9381
## Pos Pred Value : 0.9169
## Neg Pred Value : 0.9359
## Prevalence : 0.4278
## Detection Rate : 0.3911
## Detection Prevalence : 0.4265
## Balanced Accuracy : 0.9261
##
## 'Positive' Class : Cammeo
##
## Confusion Matrix and Statistics
##
## Reference
## Prediction e p
## e 824 39
## p 17 744
##
## Accuracy : 0.9655
## 95% CI : (0.9555, 0.9738)
## No Information Rate : 0.5179
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.9309
##
## Mcnemar's Test P-Value : 0.005012
##
## Sensitivity : 0.9798
## Specificity : 0.9502
## Pos Pred Value : 0.9548
## Neg Pred Value : 0.9777
## Prevalence : 0.5179
## Detection Rate : 0.5074
## Detection Prevalence : 0.5314
## Balanced Accuracy : 0.9650
##
## 'Positive' Class : e
##
## Confusion Matrix and Statistics
##
## Reference
## Prediction no yes
## no 780 70
## yes 20 34
##
## Accuracy : 0.9004
## 95% CI : (0.879, 0.9192)
## No Information Rate : 0.885
## P-Value [Acc > NIR] : 0.07758
##
## Kappa : 0.3818
##
## Mcnemar's Test P-Value : 2.404e-07
##
## Sensitivity : 0.9750
## Specificity : 0.3269
## Pos Pred Value : 0.9176
## Neg Pred Value : 0.6296
## Prevalence : 0.8850
## Detection Rate : 0.8628
## Detection Prevalence : 0.9403
## Balanced Accuracy : 0.6510
##
## 'Positive' Class : no
##
The outputs above show the results of the last execution of the project.
After evaluating the different models many times, I noticed that there wasn't a big difference between the two algorithms, although CatBoost got better results most of the time. The statistical tests point the same way: on the numerical dataset, for example, the mean accuracy difference was only about 0.006 with a Bonferroni-adjusted p-value of 0.28, so the hypothesis of no difference between the models cannot be rejected.
Also, both algorithms performed best on the categorical dataset, followed by the numerical one, with the mix dataset last.
While some differences between CatBoost and Random Forest could be observed, this is just a starting point: we would have to test both algorithms on many more datasets to show that one is actually better than the other. As a continuation of this project, different comparison metrics could be used, as well as different datasets. Each algorithm could also be evaluated without caret, or with its hyperparameters tuned.