Para instalar catboost no se utiliza install.packages(),
sino que hay que instalarlo vía devtools. Para esto,
primero hay que instalar devtools
vía install.packages().
# Install devtools (from CRAN) and catboost (from its GitHub release
# tarball) only when they are not already available, so re-running this
# script does not reinstall them every time.
if (!requireNamespace("devtools", quietly = TRUE)) {
  install.packages("devtools")
}
if (!requireNamespace("catboost", quietly = TRUE)) {
  # NOTE: this URL is the Linux build of catboost 0.24.1; other platforms
  # need the matching release asset.
  devtools::install_url(
    "https://github.com/catboost/catboost/releases/download/v0.24.1/catboost-R-Linux-0.24.1.tgz",
    INSTALL_opts = c("--no-multiarch")
  )
}

library(catboost)
library(randomForest)
library(caret)
library(dplyr)
# Build a binary classification problem from iris (setosa vs versicolor)
# and split it into train / validation / test sets.
data(iris)
set.seed(42)  # make the random splits (and later model fits) reproducible

# droplevels() removes the now-unused "virginica" level so Species is a
# genuine two-level factor.
iris_local <- iris %>%
  filter(Species != "virginica") %>%
  droplevels()

# 70% of the rows go to train, the rest to test.
train_idx <- sample(seq_len(nrow(iris_local)),
                    floor(nrow(iris_local) * 70 / 100))
train <- iris_local[train_idx, ]
test <- iris_local[-train_idx, ]

# Hold out 15% of the training rows as a validation set. floor() makes
# the truncation of a fractional size explicit.
val_idx <- sample(seq_len(nrow(train)), floor(nrow(train) * 15 / 100))
val <- train[val_idx, ]
# Row subset: the trailing comma is required; `train[-val_idx]` would
# drop *columns* instead of removing the validation rows.
train <- train[-val_idx, ]
train
val

# Features: the four numeric measurement columns.
# Labels: Species encoded as 0/1 following the factor level order
# (setosa = 0, versicolor = 1).
train_x <- train[, 1:4]
train_y <- as.numeric(as.factor(train$Species)) - 1
test_x <- test[, 1:4]
test_y <- as.numeric(as.factor(test$Species)) - 1
# Baseline: a random forest classifier trained on the same split.
model_rf <- randomForest(x = train_x, y = as.factor(train_y))
test_prediction <- predict(model_rf, test_x)

# Build the reference labels with the same levels as the predictions so
# confusionMatrix() does not fail on mismatched factor levels.
pred_factor <- as.factor(test_prediction)
cm <- caret::confusionMatrix(
  pred_factor,
  factor(test_y, levels = levels(pred_factor)),
  mode = "everything"
)
cm
Confusion Matrix and Statistics
Reference
Prediction 0 1
0 16 0
1 0 14
Accuracy : 1
95% CI : (0.8843, 1)
No Information Rate : 0.5333
P-Value [Acc > NIR] : 6.456e-09
Kappa : 1
Mcnemar's Test P-Value : NA
Sensitivity : 1.0000
Specificity : 1.0000
Pos Pred Value : 1.0000
Neg Pred Value : 1.0000
Precision : 1.0000
Recall : 1.0000
F1 : 1.0000
Prevalence : 0.5333
Detection Rate : 0.5333
Detection Prevalence : 0.5333
Balanced Accuracy : 1.0000
'Positive' Class : 0
# Wrap each split in a CatBoost pool: the four measurement columns are the
# features and the 0/1-encoded Species is the label.
pool_train <- catboost.load_pool(train_x, label = train_y)
pool_val <- catboost.load_pool(
  val[, 1:4],
  label = as.numeric(as.factor(val$Species)) - 1
)
pool_test <- catboost.load_pool(test_x, label = test_y)
head(pool_test)
[,1] [,2] [,3] [,4] [,5] [,6]
[1,] 0 1 5.1 3.5 1.4 0.2
[2,] 0 1 4.7 3.2 1.3 0.2
[3,] 0 1 5.4 3.9 1.7 0.4
[4,] 0 1 4.6 3.4 1.4 0.3
[5,] 0 1 5.0 3.4 1.5 0.2
[6,] 0 1 4.9 3.1 1.5 0.1
[7,] 0 1 5.1 3.5 1.4 0.3
[8,] 0 1 5.2 3.5 1.5 0.2
[9,] 0 1 4.8 3.1 1.6 0.2
[10,] 0 1 5.0 3.2 1.2 0.2
head(pool_train)
[,1] [,2] [,3] [,4] [,5] [,6]
[1,] 1 1 6.0 2.2 4.0 1.0
[2,] 1 1 6.3 2.5 4.9 1.5
[3,] 1 1 5.6 3.0 4.5 1.5
[4,] 0 1 5.2 4.1 1.5 0.1
[5,] 1 1 5.6 2.7 4.2 1.3
[6,] 0 1 5.1 3.3 1.7 0.5
[7,] 0 1 5.4 3.7 1.5 0.2
[8,] 1 1 5.5 2.4 3.8 1.1
[9,] 0 1 4.8 3.4 1.6 0.2
[10,] 0 1 5.4 3.4 1.5 0.4
Definimos los parámetros de entrenamiento, entre ellos el número de veces que va a iterar el algoritmo y la función de pérdida (loss), que es la que se utiliza para medir el rendimiento del algoritmo en cada iteración, de forma similar a lo que ocurre en las redes neuronales.
# CatBoost training parameters for the binary (setosa vs versicolor) task:
# - iterations: maximum number of boosting rounds.
# - loss_function: "Logloss", the binary loss optimised and reported on the
#   learn and validation pools at each iteration.
# - od_type / od_wait: enable the overfitting detector so training stops
#   once the validation loss has not improved for `od_wait` iterations.
# - use_best_model: keep the model from the best validation iteration.
# - task_type: train on CPU.
fit_params <- list(
  iterations = 100,
  loss_function = "Logloss",
  od_type = "Iter",
  od_wait = 20,
  use_best_model = TRUE,
  task_type = "CPU"
)
# Alternative configuration for the three-class problem (all iris species),
# kept for reference:
# fit_params <- list(iterations = 100,
#                    loss_function = 'MultiClass',
#                    border_count = 32,
#                    depth = 3,
#                    learning_rate = 0.03,
#                    l2_leaf_reg = 3.5,
#                    task_type = 'CPU',
#                    verbose = 10)
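Para entrenar el modelo catboost, le pasamos el conjunto de train y el conjunto de validación. Catboost irá iterando (hasta 100 veces) y en cada iteración evaluará la función de pérdida sobre el conjunto de validación. Al proporcionarle un conjunto de validación, por defecto se queda con el modelo de la mejor iteración (`use_best_model`). Además, si se activa el detector de overfitting (parámetros `od_type` y `od_wait`), detiene el entrenamiento cuando los resultados en validación dejan de mejorar. De esta forma se reduce el riesgo de overfitting.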
# Train on pool_train while tracking the loss on pool_val at every
# iteration (reported as "test" in the training log).
model <- catboost.train(
  learn_pool = pool_train,
  test_pool = pool_val,
  params = fit_params
)
Learning rate set to 0.044735
0: learn: 0.6677680 test: 0.6719587 best: 0.6719587 (0) total: 439us remaining: 43.5ms
1: learn: 0.6440598 test: 0.6461715 best: 0.6461715 (1) total: 708us remaining: 34.7ms
2: learn: 0.6262159 test: 0.6279214 best: 0.6279214 (2) total: 1.8ms remaining: 58.1ms
3: learn: 0.6052709 test: 0.6088039 best: 0.6088039 (3) total: 2.35ms remaining: 56.5ms
4: learn: 0.5842307 test: 0.5891116 best: 0.5891116 (4) total: 2.58ms remaining: 49.1ms
5: learn: 0.5628501 test: 0.5701187 best: 0.5701187 (5) total: 3.2ms remaining: 50.1ms
6: learn: 0.5505292 test: 0.5566190 best: 0.5566190 (6) total: 3.73ms remaining: 49.6ms
7: learn: 0.5337013 test: 0.5357614 best: 0.5357614 (7) total: 4.74ms remaining: 54.5ms
8: learn: 0.5151230 test: 0.5167764 best: 0.5167764 (8) total: 5.19ms remaining: 52.4ms
9: learn: 0.4932531 test: 0.4949896 best: 0.4949896 (9) total: 5.5ms remaining: 49.6ms
10: learn: 0.4826190 test: 0.4835703 best: 0.4835703 (10) total: 6.04ms remaining: 48.9ms
11: learn: 0.4711824 test: 0.4726001 best: 0.4726001 (11) total: 7.11ms remaining: 52.2ms
12: learn: 0.4610435 test: 0.4608775 best: 0.4608775 (12) total: 9.72ms remaining: 65.1ms
13: learn: 0.4447077 test: 0.4436523 best: 0.4436523 (13) total: 10.1ms remaining: 62ms
14: learn: 0.4312251 test: 0.4320395 best: 0.4320395 (14) total: 10.7ms remaining: 60.7ms
15: learn: 0.4198664 test: 0.4205179 best: 0.4205179 (15) total: 10.9ms remaining: 57.4ms
16: learn: 0.4078940 test: 0.4070262 best: 0.4070262 (16) total: 11.2ms remaining: 54.6ms
17: learn: 0.3974511 test: 0.3962852 best: 0.3962852 (17) total: 11.5ms remaining: 52.6ms
18: learn: 0.3870156 test: 0.3867940 best: 0.3867940 (18) total: 12ms remaining: 51.1ms
19: learn: 0.3791088 test: 0.3790127 best: 0.3790127 (19) total: 12.6ms remaining: 50.5ms
20: learn: 0.3699627 test: 0.3685060 best: 0.3685060 (20) total: 12.9ms remaining: 48.4ms
21: learn: 0.3625590 test: 0.3605576 best: 0.3605576 (21) total: 13.1ms remaining: 46.4ms
22: learn: 0.3533490 test: 0.3495862 best: 0.3495862 (22) total: 13.3ms remaining: 44.6ms
23: learn: 0.3446101 test: 0.3422866 best: 0.3422866 (23) total: 15.1ms remaining: 47.9ms
24: learn: 0.3368033 test: 0.3344620 best: 0.3344620 (24) total: 15.5ms remaining: 46.6ms
25: learn: 0.3243744 test: 0.3222113 best: 0.3222113 (25) total: 15.7ms remaining: 44.5ms
26: learn: 0.3161011 test: 0.3128968 best: 0.3128968 (26) total: 16ms remaining: 43.3ms
27: learn: 0.3079743 test: 0.3055185 best: 0.3055185 (27) total: 16.5ms remaining: 42.4ms
28: learn: 0.3012161 test: 0.2982751 best: 0.2982751 (28) total: 17.1ms remaining: 42ms
29: learn: 0.2947653 test: 0.2919401 best: 0.2919401 (29) total: 17.7ms remaining: 41.3ms
30: learn: 0.2889531 test: 0.2858075 best: 0.2858075 (30) total: 18.2ms remaining: 40.6ms
31: learn: 0.2828909 test: 0.2790942 best: 0.2790942 (31) total: 18.8ms remaining: 40ms
32: learn: 0.2775582 test: 0.2723133 best: 0.2723133 (32) total: 19.7ms remaining: 39.9ms
33: learn: 0.2714516 test: 0.2655629 best: 0.2655629 (33) total: 20.4ms remaining: 39.6ms
34: learn: 0.2659471 test: 0.2591049 best: 0.2591049 (34) total: 20.8ms remaining: 38.7ms
35: learn: 0.2586950 test: 0.2529405 best: 0.2529405 (35) total: 21.3ms remaining: 37.8ms
36: learn: 0.2536740 test: 0.2494636 best: 0.2494636 (36) total: 22ms remaining: 37.4ms
37: learn: 0.2491356 test: 0.2446644 best: 0.2446644 (37) total: 22.9ms remaining: 37.3ms
38: learn: 0.2433963 test: 0.2400398 best: 0.2400398 (38) total: 23.6ms remaining: 36.8ms
39: learn: 0.2378033 test: 0.2343015 best: 0.2343015 (39) total: 24.3ms remaining: 36.4ms
40: learn: 0.2306063 test: 0.2282880 best: 0.2282880 (40) total: 24.6ms remaining: 35.3ms
41: learn: 0.2248869 test: 0.2236178 best: 0.2236178 (41) total: 25.2ms remaining: 34.8ms
42: learn: 0.2193865 test: 0.2174538 best: 0.2174538 (42) total: 25.6ms remaining: 33.9ms
43: learn: 0.2145509 test: 0.2128820 best: 0.2128820 (43) total: 26.2ms remaining: 33.3ms
44: learn: 0.2103889 test: 0.2083639 best: 0.2083639 (44) total: 26.9ms remaining: 32.8ms
45: learn: 0.2059360 test: 0.2031905 best: 0.2031905 (45) total: 27.7ms remaining: 32.6ms
46: learn: 0.2024548 test: 0.2008645 best: 0.2008645 (46) total: 28.4ms remaining: 32.1ms
47: learn: 0.1978823 test: 0.1957624 best: 0.1957624 (47) total: 28.6ms remaining: 31ms
48: learn: 0.1949683 test: 0.1923570 best: 0.1923570 (48) total: 29.3ms remaining: 30.5ms
49: learn: 0.1919014 test: 0.1886565 best: 0.1886565 (49) total: 29.9ms remaining: 29.9ms
50: learn: 0.1886655 test: 0.1857901 best: 0.1857901 (50) total: 30.4ms remaining: 29.2ms
51: learn: 0.1848620 test: 0.1814206 best: 0.1814206 (51) total: 30.9ms remaining: 28.5ms
52: learn: 0.1802670 test: 0.1766237 best: 0.1766237 (52) total: 31.3ms remaining: 27.8ms
53: learn: 0.1771283 test: 0.1741695 best: 0.1741695 (53) total: 31.7ms remaining: 27ms
54: learn: 0.1742215 test: 0.1717928 best: 0.1717928 (54) total: 32.3ms remaining: 26.4ms
55: learn: 0.1709966 test: 0.1699449 best: 0.1699449 (55) total: 33ms remaining: 25.9ms
56: learn: 0.1680527 test: 0.1673195 best: 0.1673195 (56) total: 33.6ms remaining: 25.3ms
57: learn: 0.1649909 test: 0.1644754 best: 0.1644754 (57) total: 34ms remaining: 24.6ms
58: learn: 0.1614578 test: 0.1616845 best: 0.1616845 (58) total: 34.5ms remaining: 24ms
59: learn: 0.1580744 test: 0.1580950 best: 0.1580950 (59) total: 34.7ms remaining: 23.1ms
60: learn: 0.1547989 test: 0.1550995 best: 0.1550995 (60) total: 35.3ms remaining: 22.6ms
61: learn: 0.1518242 test: 0.1519290 best: 0.1519290 (61) total: 36ms remaining: 22.1ms
62: learn: 0.1489520 test: 0.1486387 best: 0.1486387 (62) total: 36.6ms remaining: 21.5ms
63: learn: 0.1464015 test: 0.1460717 best: 0.1460717 (63) total: 37.1ms remaining: 20.9ms
64: learn: 0.1432026 test: 0.1425708 best: 0.1425708 (64) total: 37.5ms remaining: 20.2ms
65: learn: 0.1407988 test: 0.1397788 best: 0.1397788 (65) total: 37.9ms remaining: 19.5ms
66: learn: 0.1373769 test: 0.1361622 best: 0.1361622 (66) total: 38ms remaining: 18.7ms
67: learn: 0.1352371 test: 0.1339499 best: 0.1339499 (67) total: 38.6ms remaining: 18.2ms
68: learn: 0.1329267 test: 0.1321146 best: 0.1321146 (68) total: 39.2ms remaining: 17.6ms
69: learn: 0.1311242 test: 0.1298305 best: 0.1298305 (69) total: 39.5ms remaining: 16.9ms
70: learn: 0.1286417 test: 0.1275977 best: 0.1275977 (70) total: 39.8ms remaining: 16.3ms
71: learn: 0.1266342 test: 0.1258356 best: 0.1258356 (71) total: 40ms remaining: 15.6ms
72: learn: 0.1248747 test: 0.1238523 best: 0.1238523 (72) total: 40.7ms remaining: 15ms
73: learn: 0.1228734 test: 0.1220811 best: 0.1220811 (73) total: 40.9ms remaining: 14.4ms
74: learn: 0.1209351 test: 0.1200008 best: 0.1200008 (74) total: 41.5ms remaining: 13.8ms
75: learn: 0.1195315 test: 0.1184948 best: 0.1184948 (75) total: 41.9ms remaining: 13.2ms
76: learn: 0.1178860 test: 0.1166961 best: 0.1166961 (76) total: 42.4ms remaining: 12.7ms
77: learn: 0.1163729 test: 0.1154974 best: 0.1154974 (77) total: 43.2ms remaining: 12.2ms
78: learn: 0.1145512 test: 0.1137628 best: 0.1137628 (78) total: 43.4ms remaining: 11.5ms
79: learn: 0.1119710 test: 0.1111307 best: 0.1111307 (79) total: 43.5ms remaining: 10.9ms
80: learn: 0.1103215 test: 0.1095540 best: 0.1095540 (80) total: 43.7ms remaining: 10.3ms
81: learn: 0.1085648 test: 0.1073594 best: 0.1073594 (81) total: 44.3ms remaining: 9.72ms
82: learn: 0.1067429 test: 0.1056900 best: 0.1056900 (82) total: 44.8ms remaining: 9.18ms
83: learn: 0.1052000 test: 0.1046463 best: 0.1046463 (83) total: 45.2ms remaining: 8.61ms
84: learn: 0.1041010 test: 0.1033099 best: 0.1033099 (84) total: 45.9ms remaining: 8.1ms
85: learn: 0.1016863 test: 0.1009213 best: 0.1009213 (85) total: 46ms remaining: 7.49ms
86: learn: 0.1001383 test: 0.0990904 best: 0.0990904 (86) total: 46.6ms remaining: 6.97ms
87: learn: 0.0982364 test: 0.0971337 best: 0.0971337 (87) total: 46.8ms remaining: 6.38ms
88: learn: 0.0972124 test: 0.0959806 best: 0.0959806 (88) total: 47.3ms remaining: 5.85ms
89: learn: 0.0961530 test: 0.0946992 best: 0.0946992 (89) total: 47.7ms remaining: 5.3ms
90: learn: 0.0949806 test: 0.0939389 best: 0.0939389 (90) total: 48.2ms remaining: 4.77ms
91: learn: 0.0935288 test: 0.0925390 best: 0.0925390 (91) total: 48.6ms remaining: 4.22ms
92: learn: 0.0918799 test: 0.0907919 best: 0.0907919 (92) total: 48.7ms remaining: 3.67ms
93: learn: 0.0909273 test: 0.0896887 best: 0.0896887 (93) total: 49.3ms remaining: 3.15ms
94: learn: 0.0900152 test: 0.0885220 best: 0.0885220 (94) total: 49.5ms remaining: 2.61ms
95: learn: 0.0890811 test: 0.0877734 best: 0.0877734 (95) total: 49.7ms remaining: 2.07ms
96: learn: 0.0880719 test: 0.0866968 best: 0.0866968 (96) total: 50.5ms remaining: 1.56ms
97: learn: 0.0863268 test: 0.0850610 best: 0.0850610 (97) total: 50.9ms remaining: 1.04ms
98: learn: 0.0853987 test: 0.0839297 best: 0.0839297 (98) total: 51.3ms remaining: 518us
99: learn: 0.0842132 test: 0.0828672 best: 0.0828672 (99) total: 51.9ms remaining: 0us
bestTest = 0.08286718022
bestIteration = 99
# Predict hard class labels for the test pool.
test_prediction <- catboost.predict(model, pool_test,
                                    prediction_type = "Class")

# Give predictions and reference one shared, explicit set of levels so
# confusionMatrix() does not error when a class is never predicted (the
# first level, "0", remains the positive class).
class_levels <- sort(unique(c(test_prediction, test_y)))
cm <- caret::confusionMatrix(
  factor(test_prediction, levels = class_levels),
  factor(test_y, levels = class_levels),
  mode = "everything"
)
cm
Confusion Matrix and Statistics
Reference
Prediction 0 1
0 16 0
1 0 14
Accuracy : 1
95% CI : (0.8843, 1)
No Information Rate : 0.5333
P-Value [Acc > NIR] : 6.456e-09
Kappa : 1
Mcnemar's Test P-Value : NA
Sensitivity : 1.0000
Specificity : 1.0000
Pos Pred Value : 1.0000
Neg Pred Value : 1.0000
Precision : 1.0000
Recall : 1.0000
F1 : 1.0000
Prevalence : 0.5333
Detection Rate : 0.5333
Detection Prevalence : 0.5333
Balanced Accuracy : 1.0000
'Positive' Class : 0