This is an R Markdown Notebook. When you execute code within the notebook, the results appear beneath the code.
Try executing this chunk by clicking the Run button within the chunk or by placing your cursor inside it and pressing Ctrl+Shift+Enter.
# Start a timer so the total run time can be reported at the end
t <- proc.time()
# For manipulating the datasets
library(dplyr)
##
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
##
## filter, lag
## The following objects are masked from 'package:base':
##
## intersect, setdiff, setequal, union
# For reading the CSV and Excel data files
library(readr)
library(readxl)
# For plotting correlation matrix
library(ggcorrplot)
## Loading required package: ggplot2
# Machine learning libraries
library(caret)
## Loading required package: lattice
library(catboost)
# For Multi-core processing support
library(parallel)
library(doParallel)
## Loading required package: foreach
## Loading required package: iterators
# Spin up a two-worker cluster so caret can train resamples in parallel
cl <- makePSOCKcluster(2)
registerDoParallel(cl)
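Two workers is a conservative fixed choice. If portability across machines matters, the cluster can instead be sized from the detected core count; a minimal sketch (leaving one core free for the master process is a convention, not a requirement):

```r
# Alternative: size the cluster from the hardware
n_workers <- max(1, parallel::detectCores() - 1)
cl <- makePSOCKcluster(n_workers)
registerDoParallel(cl)
```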
# Numerical dataset
dataset_num <- read_excel("rice.xlsx")
# Categorical dataset
dataset_cat <- read.csv("mushrooms.csv")
# Mixed (numerical and categorical) dataset
dataset_mix <- read_excel("bank.xlsx")
dataset_cat %>% group_by(VEIL.TYPE) %>% summarise(total = n())
## `summarise()` ungrouping output (override with `.groups` argument)
# Drop VEIL.TYPE: it has only one level, so it carries no information
dataset_cat <- dataset_cat %>% select(-VEIL.TYPE)
dataset_cat %>% group_by(STALK.ROOT) %>% summarise(total = n())
## `summarise()` ungrouping output (override with `.groups` argument)
# Drop STALK.ROOT since it contains missing values
dataset_cat <- dataset_cat %>% select(-STALK.ROOT)
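Both drops above came from eyeballing the group counts. On a fresh load, the same screening could be done programmatically; a sketch, assuming (as in the UCI mushroom data) that missing values are coded as the literal string "?":

```r
# Columns with a single distinct value carry no information
constant_cols <- names(dataset_cat)[sapply(dataset_cat, function(x) length(unique(x)) == 1)]
# Columns containing the "?" placeholder have missing values
missing_cols <- names(dataset_cat)[sapply(dataset_cat, function(x) any(x == "?"))]
dataset_cat <- dataset_cat %>% select(-all_of(union(constant_cols, missing_cols)))
```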
# Encode the target and all character columns as factors
dataset_num$CLASS <- as.factor(dataset_num$CLASS)
dataset_cat <- mutate_if(dataset_cat, is.character, as.factor)
dataset_mix <- mutate_if(dataset_mix, is.character, as.factor)
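`mutate_if()` still works but is superseded in dplyr 1.0 and later; the equivalent in the current `across()` idiom would be:

```r
# Same conversion with the modern dplyr idiom
dataset_cat <- dataset_cat %>% mutate(across(where(is.character), as.factor))
dataset_mix <- dataset_mix %>% mutate(across(where(is.character), as.factor))
```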
# Choose which dataset to run the pipeline on
# dataset <- dataset_num
dataset <- dataset_cat
# dataset <- dataset_mix
dataset
# Hold out 20% of the data for testing, stratified on CLASS
trainIndex <- createDataPartition(dataset$CLASS, p = 0.80, list = FALSE)
data_train <- dataset[trainIndex, ]
data_test <- dataset[-trainIndex, ]
# Carve a validation set (15% of the training portion) out for CatBoost
valIndex <- createDataPartition(data_train$CLASS, p = 0.15, list = FALSE)
val <- data_train[valIndex, ]
train <- data_train[-valIndex, ]
train
val
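Both calls to `createDataPartition()` stratify on CLASS, so the class proportions should be nearly identical across the three subsets; a quick sanity check:

```r
# Stratified splits: class proportions should match across subsets
prop.table(table(train$CLASS))
prop.table(table(val$CLASS))
prop.table(table(data_test$CLASS))
```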
pool_train <- catboost.load_pool(train %>% select(-CLASS),
                                 label = as.numeric(as.factor(train$CLASS)) - 1)
pool_val   <- catboost.load_pool(val %>% select(-CLASS),
                                 label = as.numeric(as.factor(val$CLASS)) - 1)
pool_test  <- catboost.load_pool(data_test %>% select(-CLASS),
                                 label = as.numeric(as.factor(data_test$CLASS)) - 1)
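Two things are worth noting about the pools. First, the `as.numeric(as.factor(...)) - 1` idiom turns the factor into 0-based integer codes (first level 'e' becomes 0, second level 'p' becomes 1), the binary coding used here for Logloss. Second, no one-hot encoding is done: as I understand the catboost R package, factor columns in the data frame are picked up as categorical features automatically. A quick check of the coding:

```r
# The label idiom maps factor levels to 0-based integers
levels(train$CLASS)                           # "e" "p"
head(as.numeric(as.factor(train$CLASS)) - 1)  # e -> 0, p -> 1
```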
fit_params <- list(iterations = 100,
                   loss_function = 'Logloss',
                   task_type = 'CPU')
# fit_params <- list(iterations = 100,
#                    loss_function = 'MultiClass',
#                    border_count = 32,
#                    depth = 3,
#                    learning_rate = 0.03,
#                    l2_leaf_reg = 3.5,
#                    task_type = 'CPU',
#                    verbose = 10)
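Passing `pool_val` as the second argument to `catboost.train()` below makes CatBoost track the validation metric and, with default settings, shrink the final model to the best iteration (visible at the end of the log). The overfitting detector can also be configured explicitly via the standard CatBoost parameters `od_type` and `od_wait`; a sketch:

```r
# Alternative: stop training once the validation loss stops improving
fit_params <- list(iterations = 100,
                   loss_function = 'Logloss',
                   task_type = 'CPU',
                   od_type = 'Iter',  # iteration-based overfitting detector
                   od_wait = 20)      # tolerate 20 rounds without improvement
```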
model <- catboost.train(pool_train, pool_val, params = fit_params)
## Learning rate set to 0.131597
## 0: learn: 0.3972028 test: 0.3958658 best: 0.3958658 (0) total: 152ms remaining: 15.1s
## 1: learn: 0.2500332 test: 0.2497346 best: 0.2497346 (1) total: 162ms remaining: 7.94s
## 2: learn: 0.1733431 test: 0.1737737 best: 0.1737737 (2) total: 172ms remaining: 5.55s
## 3: learn: 0.1320916 test: 0.1330643 best: 0.1330643 (3) total: 178ms remaining: 4.26s
## 4: learn: 0.0936201 test: 0.0957138 best: 0.0957138 (4) total: 183ms remaining: 3.48s
## 5: learn: 0.0565956 test: 0.0568127 best: 0.0568127 (5) total: 192ms remaining: 3s
## 6: learn: 0.0376931 test: 0.0366143 best: 0.0366143 (6) total: 200ms remaining: 2.66s
## 7: learn: 0.0276443 test: 0.0257732 best: 0.0257732 (7) total: 209ms remaining: 2.4s
## 8: learn: 0.0186257 test: 0.0170421 best: 0.0170421 (8) total: 218ms remaining: 2.21s
## 9: learn: 0.0130215 test: 0.0116901 best: 0.0116901 (9) total: 227ms remaining: 2.04s
## 10: learn: 0.0095200 test: 0.0083664 best: 0.0083664 (10) total: 236ms remaining: 1.91s
## 11: learn: 0.0073754 test: 0.0064188 best: 0.0064188 (11) total: 245ms remaining: 1.8s
## 12: learn: 0.0059308 test: 0.0052418 best: 0.0052418 (12) total: 254ms remaining: 1.7s
## 13: learn: 0.0049761 test: 0.0044365 best: 0.0044365 (13) total: 263ms remaining: 1.61s
## 14: learn: 0.0047103 test: 0.0042742 best: 0.0042742 (14) total: 271ms remaining: 1.54s
## 15: learn: 0.0044472 test: 0.0039874 best: 0.0039874 (15) total: 281ms remaining: 1.47s
## 16: learn: 0.0043662 test: 0.0039505 best: 0.0039505 (16) total: 289ms remaining: 1.41s
## 17: learn: 0.0041185 test: 0.0036808 best: 0.0036808 (17) total: 298ms remaining: 1.35s
## 18: learn: 0.0040305 test: 0.0036419 best: 0.0036419 (18) total: 307ms remaining: 1.31s
## 19: learn: 0.0039878 test: 0.0036447 best: 0.0036419 (18) total: 316ms remaining: 1.26s
## 20: learn: 0.0037991 test: 0.0035087 best: 0.0035087 (20) total: 324ms remaining: 1.22s
## 21: learn: 0.0035600 test: 0.0032865 best: 0.0032865 (21) total: 332ms remaining: 1.18s
## 22: learn: 0.0029276 test: 0.0026653 best: 0.0026653 (22) total: 341ms remaining: 1.14s
## 23: learn: 0.0025308 test: 0.0022808 best: 0.0022808 (23) total: 350ms remaining: 1.11s
## 24: learn: 0.0023820 test: 0.0021370 best: 0.0021370 (24) total: 358ms remaining: 1.07s
## 25: learn: 0.0022477 test: 0.0020662 best: 0.0020662 (25) total: 367ms remaining: 1.04s
## 26: learn: 0.0021661 test: 0.0020098 best: 0.0020098 (26) total: 376ms remaining: 1.01s
## 27: learn: 0.0020036 test: 0.0018391 best: 0.0018391 (27) total: 385ms remaining: 989ms
## 28: learn: 0.0019627 test: 0.0017751 best: 0.0017751 (28) total: 393ms remaining: 962ms
## 29: learn: 0.0019468 test: 0.0017735 best: 0.0017735 (29) total: 402ms remaining: 937ms
## 30: learn: 0.0017656 test: 0.0016019 best: 0.0016019 (30) total: 410ms remaining: 912ms
## 31: learn: 0.0016748 test: 0.0015032 best: 0.0015032 (31) total: 418ms remaining: 888ms
## 32: learn: 0.0016425 test: 0.0014779 best: 0.0014779 (32) total: 428ms remaining: 868ms
## 33: learn: 0.0015599 test: 0.0014095 best: 0.0014095 (33) total: 438ms remaining: 851ms
## 34: learn: 0.0015599 test: 0.0014095 best: 0.0014095 (34) total: 449ms remaining: 834ms
## 35: learn: 0.0015599 test: 0.0014095 best: 0.0014095 (35) total: 456ms remaining: 811ms
## 36: learn: 0.0015222 test: 0.0013836 best: 0.0013836 (36) total: 465ms remaining: 791ms
## 37: learn: 0.0015222 test: 0.0013836 best: 0.0013836 (37) total: 473ms remaining: 771ms
## 38: learn: 0.0014578 test: 0.0013413 best: 0.0013413 (38) total: 482ms remaining: 754ms
## 39: learn: 0.0014577 test: 0.0013413 best: 0.0013413 (39) total: 489ms remaining: 734ms
## 40: learn: 0.0013903 test: 0.0012796 best: 0.0012796 (40) total: 498ms remaining: 717ms
## 41: learn: 0.0013147 test: 0.0011977 best: 0.0011977 (41) total: 507ms remaining: 700ms
## 42: learn: 0.0012023 test: 0.0011058 best: 0.0011058 (42) total: 516ms remaining: 684ms
## 43: learn: 0.0011501 test: 0.0010883 best: 0.0010883 (43) total: 524ms remaining: 667ms
## 44: learn: 0.0011501 test: 0.0010883 best: 0.0010883 (44) total: 532ms remaining: 650ms
## 45: learn: 0.0011501 test: 0.0010883 best: 0.0010883 (44) total: 540ms remaining: 634ms
## 46: learn: 0.0011143 test: 0.0010351 best: 0.0010351 (46) total: 549ms remaining: 619ms
## 47: learn: 0.0011143 test: 0.0010351 best: 0.0010351 (47) total: 557ms remaining: 604ms
## 48: learn: 0.0011082 test: 0.0010351 best: 0.0010351 (47) total: 566ms remaining: 589ms
## 49: learn: 0.0011082 test: 0.0010351 best: 0.0010351 (47) total: 575ms remaining: 575ms
## 50: learn: 0.0010930 test: 0.0010470 best: 0.0010351 (47) total: 583ms remaining: 560ms
## 51: learn: 0.0010710 test: 0.0010273 best: 0.0010273 (51) total: 592ms remaining: 546ms
## 52: learn: 0.0010710 test: 0.0010273 best: 0.0010273 (51) total: 603ms remaining: 534ms
## 53: learn: 0.0010710 test: 0.0010273 best: 0.0010273 (51) total: 615ms remaining: 524ms
## 54: learn: 0.0010710 test: 0.0010273 best: 0.0010273 (51) total: 624ms remaining: 510ms
## 55: learn: 0.0010710 test: 0.0010273 best: 0.0010273 (55) total: 633ms remaining: 498ms
## 56: learn: 0.0010710 test: 0.0010273 best: 0.0010273 (55) total: 642ms remaining: 484ms
## 57: learn: 0.0010710 test: 0.0010273 best: 0.0010273 (55) total: 651ms remaining: 471ms
## 58: learn: 0.0010153 test: 0.0009700 best: 0.0009700 (58) total: 660ms remaining: 459ms
## 59: learn: 0.0009510 test: 0.0009249 best: 0.0009249 (59) total: 669ms remaining: 446ms
## 60: learn: 0.0009510 test: 0.0009249 best: 0.0009249 (59) total: 677ms remaining: 433ms
## 61: learn: 0.0009510 test: 0.0009249 best: 0.0009249 (61) total: 685ms remaining: 420ms
## 62: learn: 0.0009509 test: 0.0009246 best: 0.0009246 (62) total: 694ms remaining: 408ms
## 63: learn: 0.0009509 test: 0.0009246 best: 0.0009246 (63) total: 702ms remaining: 395ms
## 64: learn: 0.0009508 test: 0.0009246 best: 0.0009246 (64) total: 711ms remaining: 383ms
## 65: learn: 0.0009508 test: 0.0009246 best: 0.0009246 (65) total: 719ms remaining: 370ms
## 66: learn: 0.0009508 test: 0.0009246 best: 0.0009246 (66) total: 728ms remaining: 358ms
## 67: learn: 0.0009508 test: 0.0009246 best: 0.0009246 (67) total: 736ms remaining: 346ms
## 68: learn: 0.0009508 test: 0.0009246 best: 0.0009246 (67) total: 746ms remaining: 335ms
## 69: learn: 0.0009508 test: 0.0009246 best: 0.0009246 (67) total: 754ms remaining: 323ms
## 70: learn: 0.0009293 test: 0.0009126 best: 0.0009126 (70) total: 766ms remaining: 313ms
## 71: learn: 0.0009142 test: 0.0009183 best: 0.0009126 (70) total: 775ms remaining: 301ms
## 72: learn: 0.0009142 test: 0.0009183 best: 0.0009126 (70) total: 783ms remaining: 290ms
## 73: learn: 0.0009142 test: 0.0009183 best: 0.0009126 (70) total: 793ms remaining: 278ms
## 74: learn: 0.0009142 test: 0.0009183 best: 0.0009126 (70) total: 801ms remaining: 267ms
## 75: learn: 0.0009142 test: 0.0009182 best: 0.0009126 (70) total: 810ms remaining: 256ms
## 76: learn: 0.0009142 test: 0.0009182 best: 0.0009126 (70) total: 818ms remaining: 244ms
## 77: learn: 0.0009142 test: 0.0009183 best: 0.0009126 (70) total: 826ms remaining: 233ms
## 78: learn: 0.0009141 test: 0.0009182 best: 0.0009126 (70) total: 834ms remaining: 222ms
## 79: learn: 0.0009141 test: 0.0009182 best: 0.0009126 (70) total: 843ms remaining: 211ms
## 80: learn: 0.0009141 test: 0.0009183 best: 0.0009126 (70) total: 851ms remaining: 200ms
## 81: learn: 0.0009141 test: 0.0009182 best: 0.0009126 (70) total: 860ms remaining: 189ms
## 82: learn: 0.0009141 test: 0.0009182 best: 0.0009126 (70) total: 868ms remaining: 178ms
## 83: learn: 0.0009141 test: 0.0009182 best: 0.0009126 (70) total: 877ms remaining: 167ms
## 84: learn: 0.0009141 test: 0.0009182 best: 0.0009126 (70) total: 886ms remaining: 156ms
## 85: learn: 0.0009140 test: 0.0009182 best: 0.0009126 (70) total: 894ms remaining: 146ms
## 86: learn: 0.0009140 test: 0.0009182 best: 0.0009126 (70) total: 902ms remaining: 135ms
## 87: learn: 0.0009139 test: 0.0009183 best: 0.0009126 (70) total: 911ms remaining: 124ms
## 88: learn: 0.0009139 test: 0.0009183 best: 0.0009126 (70) total: 919ms remaining: 114ms
## 89: learn: 0.0009139 test: 0.0009183 best: 0.0009126 (70) total: 929ms remaining: 103ms
## 90: learn: 0.0009138 test: 0.0009183 best: 0.0009126 (70) total: 939ms remaining: 92.9ms
## 91: learn: 0.0009138 test: 0.0009183 best: 0.0009126 (70) total: 950ms remaining: 82.7ms
## 92: learn: 0.0009138 test: 0.0009183 best: 0.0009126 (70) total: 962ms remaining: 72.4ms
## 93: learn: 0.0009138 test: 0.0009183 best: 0.0009126 (70) total: 971ms remaining: 62ms
## 94: learn: 0.0009137 test: 0.0009183 best: 0.0009126 (70) total: 979ms remaining: 51.5ms
## 95: learn: 0.0009137 test: 0.0009183 best: 0.0009126 (70) total: 988ms remaining: 41.2ms
## 96: learn: 0.0009137 test: 0.0009183 best: 0.0009126 (70) total: 996ms remaining: 30.8ms
## 97: learn: 0.0009137 test: 0.0009183 best: 0.0009126 (70) total: 1.01s remaining: 20.5ms
## 98: learn: 0.0009135 test: 0.0009184 best: 0.0009126 (70) total: 1.01s remaining: 10.2ms
## 99: learn: 0.0009135 test: 0.0009184 best: 0.0009126 (70) total: 1.02s remaining: 0us
##
## bestTest = 0.0009126144348
## bestIteration = 70
##
## Shrink model to first 71 iterations.
# Predict hard class labels (0/1) on the held-out test pool
# (optional arguments: verbose = TRUE, thread_count = 2)
test_prediction <- catboost.predict(model,
                                    pool_test,
                                    prediction_type = 'Class')
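`catboost.predict()` returns the 0/1 integer codes here; mapping them back to the original labels just reverses the encoding used when building the pools:

```r
# Map 0/1 codes back to the original CLASS labels
pred_labels <- levels(data_test$CLASS)[test_prediction + 1]
head(pred_labels)
```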
# Recode the test labels with the same 0/1 scheme used for the pools
cm <- caret::confusionMatrix(as.factor(test_prediction),
                             as.factor(as.numeric(as.factor(data_test$CLASS)) - 1),
                             mode = 'everything')
cm
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 841 0
## 1 0 783
##
## Accuracy : 1
## 95% CI : (0.9977, 1)
## No Information Rate : 0.5179
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 1
##
## Mcnemar's Test P-Value : NA
##
## Sensitivity : 1.0000
## Specificity : 1.0000
## Pos Pred Value : 1.0000
## Neg Pred Value : 1.0000
## Precision : 1.0000
## Recall : 1.0000
## F1 : 1.0000
## Prevalence : 0.5179
## Detection Rate : 0.5179
## Detection Prevalence : 0.5179
## Balanced Accuracy : 1.0000
##
## 'Positive' Class : 0
##
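The 'Class' predictions above use CatBoost's default 0.5 probability cutoff. If a different operating point is wanted (for example to trade precision for recall), `catboost.predict()` can return probabilities to threshold by hand; a minimal sketch:

```r
# Probability of the positive (1) class; threshold manually
test_prob <- catboost.predict(model, pool_test, prediction_type = 'Probability')
custom_pred <- as.integer(test_prob > 0.5)  # 0.5 reproduces the 'Class' output
```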
fitControl <- trainControl(method = "repeatedcv",
                           repeats = 2,
                           number = 5,
                           returnResamp = 'final',
                           savePredictions = 'final',
                           verboseIter = TRUE,
                           allowParallel = TRUE)
train_formula <- formula(CLASS ~ .)
rfFitupsam <- train(train_formula,
                    data = data_train,
                    method = "rf",
                    # tuneLength = 9,
                    # tuneGrid = svmGrid,
                    # preProcess = c("scale", "center"),
                    # metric = "ROC",
                    # weights = model_weights,
                    trControl = fitControl)
## Aggregating results
## Selecting tuning parameters
## Fitting mtry = 46 on full training set
rfFitupsam
## Random Forest
##
## 6500 samples
## 20 predictor
## 2 classes: 'e', 'p'
##
## No pre-processing
## Resampling: Cross-Validated (5 fold, repeated 2 times)
## Summary of sample sizes: 5200, 5199, 5200, 5201, 5200, 5199, ...
## Resampling results across tuning parameters:
##
## mtry Accuracy Kappa
## 2 0.9522274 0.9039895
## 46 1.0000000 1.0000000
## 91 0.9998461 0.9996918
##
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was mtry = 46.
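caret picked the mtry candidates (2, 46, 91) on its own: the formula interface dummy-codes the 20 factor predictors into roughly 91 columns, and the default tuneLength of 3 spreads three values across that range. The search can be pinned down explicitly with `tuneGrid`; a minimal sketch (the grid values are illustrative, not tuned):

```r
# Hypothetical explicit grid over mtry instead of caret's defaults
rfGrid <- expand.grid(mtry = c(2, 8, 16, 32, 46))
rfFitGrid <- train(train_formula,
                   data = data_train,
                   method = "rf",
                   tuneGrid = rfGrid,
                   trControl = fitControl)
```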
# Variable importance from the fitted random forest
importance <- varImp(rfFitupsam, scale = FALSE)
plot(importance)
# Evaluate on the held-out test set
predsrfprobsamp <- predict(rfFitupsam, data_test)
confusionMatrix(predsrfprobsamp, as.factor(data_test$CLASS))
## Confusion Matrix and Statistics
##
## Reference
## Prediction e p
## e 841 0
## p 0 783
##
## Accuracy : 1
## 95% CI : (0.9977, 1)
## No Information Rate : 0.5179
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 1
##
## Mcnemar's Test P-Value : NA
##
## Sensitivity : 1.0000
## Specificity : 1.0000
## Pos Pred Value : 1.0000
## Neg Pred Value : 1.0000
## Prevalence : 0.5179
## Detection Rate : 0.5179
## Detection Prevalence : 0.5179
## Balanced Accuracy : 1.0000
##
## 'Positive' Class : e
##
stopCluster(cl)
# Total run time for the notebook
proc.time() - t
## user system elapsed
## 20.66 0.66 248.09
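Note the gap between user (about 21 s) and elapsed (about 248 s) time: `proc.time()` counts CPU time in the master R session only, so the work done inside the two PSOCK workers, plus the time spent waiting on them, shows up only in the elapsed column.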