1. Librerías
library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr 1.2.1 ✔ readr 2.2.0
## ✔ forcats 1.0.1 ✔ stringr 1.6.0
## ✔ ggplot2 4.0.3 ✔ tibble 3.3.1
## ✔ lubridate 1.9.5 ✔ tidyr 1.3.2
## ✔ purrr 1.2.2
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(caret)
## Loading required package: lattice
##
## Attaching package: 'caret'
##
## The following object is masked from 'package:purrr':
##
## lift
library(nnet)
library(NeuralNetTools)
library(corrplot)
## corrplot 0.95 loaded
library(DataExplorer)
2. Carga de datos
# Working copy of the built-in iris data set (150 rows, 4 numeric measurements
# plus the Species factor).
data <- datasets::iris
3. EDA
# Compact column-wise overview: one line per variable with its type and the
# first few values.
dplyr::glimpse(data)
## Rows: 150
## Columns: 5
## $ Sepal.Length <dbl> 5.1, 4.9, 4.7, 4.6, 5.0, 5.4, 4.6, 5.0, 4.4, 4.9, 5.4, 4.…
## $ Sepal.Width <dbl> 3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 2.9, 3.1, 3.7, 3.…
## $ Petal.Length <dbl> 1.4, 1.4, 1.3, 1.5, 1.4, 1.7, 1.4, 1.5, 1.4, 1.5, 1.5, 1.…
## $ Petal.Width <dbl> 0.2, 0.2, 0.2, 0.2, 0.2, 0.4, 0.3, 0.2, 0.2, 0.1, 0.2, 0.…
## $ Species <fct> setosa, setosa, setosa, setosa, setosa, setosa, setosa, s…
# Five-number summary and mean for each numeric column, level counts for
# the Species factor.
summary(object = data)
## Sepal.Length Sepal.Width Petal.Length Petal.Width
## Min. :4.300 Min. :2.000 Min. :1.000 Min. :0.100
## 1st Qu.:5.100 1st Qu.:2.800 1st Qu.:1.600 1st Qu.:0.300
## Median :5.800 Median :3.000 Median :4.350 Median :1.300
## Mean :5.843 Mean :3.057 Mean :3.758 Mean :1.199
## 3rd Qu.:6.400 3rd Qu.:3.300 3rd Qu.:5.100 3rd Qu.:1.800
## Max. :7.900 Max. :4.400 Max. :6.900 Max. :2.500
## Species
## setosa :50
## versicolor:50
## virginica :50
##
##
##
# Histograms of every column except the Species label.
plot_histogram(select(data, -Species))

# Correlation matrix of the first four columns (the measurements), drawn as
# a colour-coded grid.
corrplot(
  corr = cor(data[, 1:4]),
  method = "color"
)

4. Preprocesamiento
# Standardise the four measurement columns (columns 1:4) to zero mean and
# unit variance; Species is coerced to a factor so it is treated as the
# class label downstream.
# NOTE(review): the centering/scaling parameters are estimated on all rows,
# before the train/test split performed later in this script, so the test
# rows influence the scaling. Consider fitting preProcess() on the training
# rows only.
data_scaled <- mutate(data, Species = as.factor(Species))
preproc <- preProcess(
  data_scaled[, 1:4],
  method = c("center", "scale")
)
data_scaled[, 1:4] <- predict(preproc, newdata = data_scaled[, 1:4])
5. Split train/test
# Partition on Species with p = 0.7: the selected rows become the training
# set and all remaining rows become the test set. The seed is fixed so the
# partition is reproducible.
set.seed(123)
trainIndex <- createDataPartition(
  y = data_scaled$Species,
  p = 0.7,
  list = FALSE
)
train <- data_scaled[trainIndex, ]
test <- data_scaled[-trainIndex, ]
6. Modelo base
# Baseline network: one hidden layer with 5 units, at most 300 iterations,
# fitting log suppressed. The seed is set immediately before the call so the
# fit is reproducible.
set.seed(123)
model_nn <- nnet(
  Species ~ .,
  data = train,
  size = 5,
  maxit = 300,
  trace = FALSE
)
7. Predicciones
# Predict class labels on the test set with the baseline network, then
# rebuild them as a factor carrying the training levels before passing them
# to confusionMatrix() together with the reference labels.
pred <- factor(
  predict(model_nn, newdata = test, type = "class"),
  levels = levels(train$Species)
)
confusionMatrix(data = pred, reference = test$Species)
## Confusion Matrix and Statistics
##
## Reference
## Prediction setosa versicolor virginica
## setosa 14 0 0
## versicolor 1 15 3
## virginica 0 0 12
##
## Overall Statistics
##
## Accuracy : 0.9111
## 95% CI : (0.7878, 0.9752)
## No Information Rate : 0.3333
## P-Value [Acc > NIR] : 8.467e-16
##
## Kappa : 0.8667
##
## Mcnemar's Test P-Value : NA
##
## Statistics by Class:
##
## Class: setosa Class: versicolor Class: virginica
## Sensitivity 0.9333 1.0000 0.8000
## Specificity 1.0000 0.8667 1.0000
## Pos Pred Value 1.0000 0.7895 1.0000
## Neg Pred Value 0.9677 1.0000 0.9091
## Prevalence 0.3333 0.3333 0.3333
## Detection Rate 0.3111 0.3333 0.2667
## Detection Prevalence 0.3111 0.4222 0.2667
## Balanced Accuracy 0.9667 0.9333 0.9000
8. Tuning de hiperparámetros
# Grid search over hidden-layer size and weight decay, scored with 5-fold
# cross-validation on the training set.
control <- trainControl(method = "cv", number = 5)
grid <- expand.grid(
  size = c(3, 5, 7),
  decay = c(0.01, 0.1, 0.5)
)

# Seed set right before train() so the resampling and the fits are
# reproducible; maxit and trace are passed alongside the caret arguments.
set.seed(123)
model_tuned <- train(
  Species ~ .,
  data = train,
  method = "nnet",
  trControl = control,
  tuneGrid = grid,
  maxit = 300,
  trace = FALSE
)
model_tuned
## Neural Network
##
## 105 samples
## 4 predictor
## 3 classes: 'setosa', 'versicolor', 'virginica'
##
## No pre-processing
## Resampling: Cross-Validated (5 fold)
## Summary of sample sizes: 84, 84, 84, 84, 84
## Resampling results across tuning parameters:
##
## size decay Accuracy Kappa
## 3 0.01 0.9523810 0.9285714
## 3 0.10 0.9619048 0.9428571
## 3 0.50 0.9619048 0.9428571
## 5 0.01 0.9523810 0.9285714
## 5 0.10 0.9619048 0.9428571
## 5 0.50 0.9619048 0.9428571
## 7 0.01 0.9523810 0.9285714
## 7 0.10 0.9619048 0.9428571
## 7 0.50 0.9619048 0.9428571
##
## Accuracy was used to select the optimal model using the largest value.
## The final values used for the model were size = 3 and decay = 0.5.
9. Evaluación modelo optimizado
# Test-set predictions from the tuned model, compared against the reference
# labels.
pred2 <- predict(model_tuned, newdata = test)
confusionMatrix(data = pred2, reference = test$Species)
## Confusion Matrix and Statistics
##
## Reference
## Prediction setosa versicolor virginica
## setosa 15 0 0
## versicolor 0 14 2
## virginica 0 1 13
##
## Overall Statistics
##
## Accuracy : 0.9333
## 95% CI : (0.8173, 0.986)
## No Information Rate : 0.3333
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.9
##
## Mcnemar's Test P-Value : NA
##
## Statistics by Class:
##
## Class: setosa Class: versicolor Class: virginica
## Sensitivity 1.0000 0.9333 0.8667
## Specificity 1.0000 0.9333 0.9667
## Pos Pred Value 1.0000 0.8750 0.9286
## Neg Pred Value 1.0000 0.9655 0.9355
## Prevalence 0.3333 0.3333 0.3333
## Detection Rate 0.3333 0.3111 0.2889
## Detection Prevalence 0.3333 0.3556 0.3111
## Balanced Accuracy 1.0000 0.9333 0.9167
10. Importancia de variables
# Variable importance scores for the tuned model.
varImp(object = model_tuned)
## nnet variable importance
##
## variables are sorted by maximum importance across the classes
## Overall setosa versicolor virginica
## Petal.Width 1.000e+02 1.000e+02 100.00 100.00
## Petal.Length 8.961e+01 8.961e+01 89.61 89.61
## Sepal.Width 1.759e+01 1.759e+01 17.59 17.59
## Sepal.Length 6.677e-15 2.003e-14 0.00 0.00
11. Visualización red neuronal
# Draw the network diagram. Note that this plots the baseline model
# (model_nn), not the tuned caret model.
plotnet(mod_in = model_nn)

12. Análisis de errores
# Attach the tuned model's predictions to the test rows and keep only the
# misclassified ones (predicted label differs from the true Species).
errors <- filter(
  mutate(test, pred = pred2),
  pred != Species
)
head(errors)
## Sepal.Length Sepal.Width Petal.Length Petal.Width Species pred
## 1 1.0345390 -0.1315388 0.7035638 0.6568380 versicolor virginica
## 2 0.1891958 -1.9669641 0.7035638 0.3944526 virginica versicolor
## 3 0.5514857 -0.5903951 0.7602115 0.3944526 virginica versicolor
13. Probabilidades
# Per-class probabilities from the tuned model for the test rows (one
# column per species); show the first rows.
prob <- predict(model_tuned, newdata = test, type = "prob")
head(prob)
## setosa versicolor virginica
## 1 0.9248054 0.06365334 0.01154127
## 2 0.9024100 0.08327805 0.01431192
## 6 0.9262915 0.06229885 0.01140968
## 16 0.9351209 0.05466058 0.01021848
## 18 0.9224031 0.06572697 0.01186992
## 20 0.9301218 0.05900211 0.01087607
14. Curva ROC (One vs Rest)
library(pROC)
## Type 'citation("pROC")' for a citation.
##
## Attaching package: 'pROC'
## The following objects are masked from 'package:stats':
##
## cov, smooth, var
# One-vs-rest ROC curve for the "setosa" class.
# The three-level Species factor is collapsed to a binary indicator
# (1 = setosa, 0 = any other species) so every test row takes part in the
# curve. Handing roc() the multi-class response directly makes it warn and
# compare only the first two levels, leaving the third class out.
# `levels` and `direction` are set explicitly: controls (0) are expected to
# receive a lower setosa probability than cases (1).
roc_obj <- roc(
  response = as.integer(test$Species == "setosa"),
  predictor = prob[, "setosa"],
  levels = c(0, 1),
  direction = "<"
)
# Draw the ROC curve held in roc_obj.
plot(x = roc_obj)

# Area under that ROC curve.
auc(roc_obj)
## Area under the curve: 1
15. Overfitting check
# Accuracy of the tuned model on the training rows versus the held-out test
# rows (share of predictions equal to the true label); a large gap between
# the two would point to overfitting.
train_acc <- mean(train$Species == predict(model_tuned, newdata = train))
test_acc <- mean(test$Species == pred2)
train_acc
## [1] 0.9714286
# Show the test-set accuracy computed above for comparison with train_acc.
test_acc
## [1] 0.9333333
16. Insights
# Emit the closing insight about the model as plain text in the report.
cat("El modelo muestra buena capacidad predictiva, con validación cruzada para evitar overfitting.")
## El modelo muestra buena capacidad predictiva, con validación cruzada para evitar overfitting.
17. Recomendaciones
# Emit the follow-up recommendation as plain text in the report.
cat("Se recomienda probar redes profundas con keras para problemas más complejos.")
## Se recomienda probar redes profundas con keras para problemas más complejos.