1. Librerías

library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr     1.2.1     ✔ readr     2.2.0
## ✔ forcats   1.0.1     ✔ stringr   1.6.0
## ✔ ggplot2   4.0.3     ✔ tibble    3.3.1
## ✔ lubridate 1.9.5     ✔ tidyr     1.3.2
## ✔ purrr     1.2.2     
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(caret)
## Loading required package: lattice
## 
## Attaching package: 'caret'
## 
## The following object is masked from 'package:purrr':
## 
##     lift
library(nnet)
library(NeuralNetTools)
library(corrplot)
## corrplot 0.95 loaded
library(DataExplorer)

2. Carga de datos

# Work on a copy of the built-in iris data set so the original object stays untouched.
data <- iris

3. EDA

# Structural overview: row/column counts plus each column's type and first values.
glimpse(data)
## Rows: 150
## Columns: 5
## $ Sepal.Length <dbl> 5.1, 4.9, 4.7, 4.6, 5.0, 5.4, 4.6, 5.0, 4.4, 4.9, 5.4, 4.…
## $ Sepal.Width  <dbl> 3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 2.9, 3.1, 3.7, 3.…
## $ Petal.Length <dbl> 1.4, 1.4, 1.3, 1.5, 1.4, 1.7, 1.4, 1.5, 1.4, 1.5, 1.5, 1.…
## $ Petal.Width  <dbl> 0.2, 0.2, 0.2, 0.2, 0.2, 0.4, 0.3, 0.2, 0.2, 0.1, 0.2, 0.…
## $ Species      <fct> setosa, setosa, setosa, setosa, setosa, setosa, setosa, s…
# Per-column summary: quantiles and mean for the numeric columns, counts per level for Species.
summary(data)
##   Sepal.Length    Sepal.Width     Petal.Length    Petal.Width   
##  Min.   :4.300   Min.   :2.000   Min.   :1.000   Min.   :0.100  
##  1st Qu.:5.100   1st Qu.:2.800   1st Qu.:1.600   1st Qu.:0.300  
##  Median :5.800   Median :3.000   Median :4.350   Median :1.300  
##  Mean   :5.843   Mean   :3.057   Mean   :3.758   Mean   :1.199  
##  3rd Qu.:6.400   3rd Qu.:3.300   3rd Qu.:5.100   3rd Qu.:1.800  
##  Max.   :7.900   Max.   :4.400   Max.   :6.900   Max.   :2.500  
##        Species  
##  setosa    :50  
##  versicolor:50  
##  virginica :50  
##                 
##                 
## 
# Histograms of the numeric measurements; Species is dropped because it is
# the categorical response.
data %>%
  select(-Species) %>%
  plot_histogram()

# Correlation heat map of the four numeric columns (columns 1 to 4).
data[, 1:4] %>%
  cor() %>%
  corrplot(method = "color")

4. Preprocesamiento

# Make sure the response is a factor, then standardise the four numeric
# predictors (columns 1 to 4) by centering and scaling them.
# NOTE(review): the centering/scaling statistics are estimated on the full
# data set, i.e. before the train/test split, so the held-out rows influence
# the transformation. Consider fitting preProcess() on the training rows only.
data_scaled <- data
data_scaled$Species <- as.factor(data_scaled$Species)

preproc <- preProcess(data_scaled[, 1:4], method = c("center", "scale"))
data_scaled[, 1:4] <- predict(preproc, newdata = data_scaled[, 1:4])

5. Split train/test

# Seeded partition on the response: 70% of the rows go to `train`,
# the remaining rows to `test`.
set.seed(123)
trainIndex <- createDataPartition(
  y = data_scaled$Species,
  p = 0.7,
  list = FALSE
)

# Positive indices select the training rows; negating them selects the rest.
train <- data_scaled[trainIndex, ]
test <- data_scaled[-trainIndex, ]

6. Modelo base

# Baseline network: one hidden layer with 5 units, at most 300 iterations,
# fitted on all predictors. The seed is set first so the fit is repeatable.
set.seed(123)
model_nn <- nnet(
  Species ~ .,
  data = train,
  size = 5,
  maxit = 300,
  trace = FALSE
)

7. Predicciones

# Predicted class labels for the test rows, stored as a factor that carries
# every level of the training response so the confusion matrix has all classes.
pred <- factor(
  predict(model_nn, newdata = test, type = "class"),
  levels = levels(train$Species)
)

# Compare the baseline predictions against the true test labels.
confusionMatrix(data = pred, reference = test$Species)
## Confusion Matrix and Statistics
## 
##             Reference
## Prediction   setosa versicolor virginica
##   setosa         14          0         0
##   versicolor      1         15         3
##   virginica       0          0        12
## 
## Overall Statistics
##                                           
##                Accuracy : 0.9111          
##                  95% CI : (0.7878, 0.9752)
##     No Information Rate : 0.3333          
##     P-Value [Acc > NIR] : 8.467e-16       
##                                           
##                   Kappa : 0.8667          
##                                           
##  Mcnemar's Test P-Value : NA              
## 
## Statistics by Class:
## 
##                      Class: setosa Class: versicolor Class: virginica
## Sensitivity                 0.9333            1.0000           0.8000
## Specificity                 1.0000            0.8667           1.0000
## Pos Pred Value              1.0000            0.7895           1.0000
## Neg Pred Value              0.9677            1.0000           0.9091
## Prevalence                  0.3333            0.3333           0.3333
## Detection Rate              0.3111            0.3333           0.2667
## Detection Prevalence        0.3111            0.4222           0.2667
## Balanced Accuracy           0.9667            0.9333           0.9000

8. Tuning de hiperparámetros

# Candidate hyper-parameters: every combination of hidden-layer size and
# weight decay (3 x 3 = 9 settings).
grid <- expand.grid(
  size = c(3, 5, 7),
  decay = c(0.01, 0.1, 0.5)
)

# 5-fold cross-validation is used to score each candidate.
control <- trainControl(method = "cv", number = 5)

# `train()` is caret's fitting function; `data = train` is the training data
# frame of the same name (R resolves the call to the function).
set.seed(123)
model_tuned <- train(
  Species ~ .,
  data = train,
  method = "nnet",
  trControl = control,
  tuneGrid = grid,
  maxit = 300,
  trace = FALSE
)

# Show the resampling results and the selected hyper-parameters.
print(model_tuned)
## Neural Network 
## 
## 105 samples
##   4 predictor
##   3 classes: 'setosa', 'versicolor', 'virginica' 
## 
## No pre-processing
## Resampling: Cross-Validated (5 fold) 
## Summary of sample sizes: 84, 84, 84, 84, 84 
## Resampling results across tuning parameters:
## 
##   size  decay  Accuracy   Kappa    
##   3     0.01   0.9523810  0.9285714
##   3     0.10   0.9619048  0.9428571
##   3     0.50   0.9619048  0.9428571
##   5     0.01   0.9523810  0.9285714
##   5     0.10   0.9619048  0.9428571
##   5     0.50   0.9619048  0.9428571
##   7     0.01   0.9523810  0.9285714
##   7     0.10   0.9619048  0.9428571
##   7     0.50   0.9619048  0.9428571
## 
## Accuracy was used to select the optimal model using the largest value.
## The final values used for the model were size = 3 and decay = 0.5.

9. Evaluación modelo optimizado

# Class predictions of the tuned model on the test rows, then the confusion
# matrix against the true labels.
pred2 <- predict(model_tuned, newdata = test)
confusionMatrix(data = pred2, reference = test$Species)
## Confusion Matrix and Statistics
## 
##             Reference
## Prediction   setosa versicolor virginica
##   setosa         15          0         0
##   versicolor      0         14         2
##   virginica       0          1        13
## 
## Overall Statistics
##                                          
##                Accuracy : 0.9333         
##                  95% CI : (0.8173, 0.986)
##     No Information Rate : 0.3333         
##     P-Value [Acc > NIR] : < 2.2e-16      
##                                          
##                   Kappa : 0.9            
##                                          
##  Mcnemar's Test P-Value : NA             
## 
## Statistics by Class:
## 
##                      Class: setosa Class: versicolor Class: virginica
## Sensitivity                 1.0000            0.9333           0.8667
## Specificity                 1.0000            0.9333           0.9667
## Pos Pred Value              1.0000            0.8750           0.9286
## Neg Pred Value              1.0000            0.9655           0.9355
## Prevalence                  0.3333            0.3333           0.3333
## Detection Rate              0.3333            0.3111           0.2889
## Detection Prevalence        0.3333            0.3556           0.3111
## Balanced Accuracy           1.0000            0.9333           0.9167

10. Importancia de variables

# Variable importance for the tuned (cross-validated) model.
varImp(model_tuned)
## nnet variable importance
## 
##   variables are sorted by maximum importance across the classes
##                Overall    setosa versicolor virginica
## Petal.Width  1.000e+02 1.000e+02     100.00    100.00
## Petal.Length 8.961e+01 8.961e+01      89.61     89.61
## Sepal.Width  1.759e+01 1.759e+01      17.59     17.59
## Sepal.Length 6.677e-15 2.003e-14       0.00      0.00

11. Visualización red neuronal

# Draw the network diagram. Note that this plots the baseline model
# (`model_nn`, size = 5), not the tuned model.
plotnet(model_nn)

12. Análisis de errores

# Attach the tuned model's predictions to the test rows, then keep only the
# rows where the predicted class differs from the true class.
errors <- mutate(test, pred = pred2)
errors <- filter(errors, pred != Species)

# Show the first misclassified rows.
errors %>% head()
##   Sepal.Length Sepal.Width Petal.Length Petal.Width    Species       pred
## 1    1.0345390  -0.1315388    0.7035638   0.6568380 versicolor  virginica
## 2    0.1891958  -1.9669641    0.7035638   0.3944526  virginica versicolor
## 3    0.5514857  -0.5903951    0.7602115   0.3944526  virginica versicolor

13. Probabilidades

# Per-class predicted probabilities from the tuned model (one column per class).
prob <- predict(model_tuned, newdata = test, type = "prob")
prob %>% head()
##       setosa versicolor  virginica
## 1  0.9248054 0.06365334 0.01154127
## 2  0.9024100 0.08327805 0.01431192
## 6  0.9262915 0.06229885 0.01140968
## 16 0.9351209 0.05466058 0.01021848
## 18 0.9224031 0.06572697 0.01186992
## 20 0.9301218 0.05900211 0.01087607

14. Curva ROC (One vs Rest)

library(pROC)
## Type 'citation("pROC")' for a citation.
## 
## Attaching package: 'pROC'
## The following objects are masked from 'package:stats':
## 
##     cov, smooth, var
# One-vs-rest ROC for the class held in the first probability column.
# The response is recoded to 0 (any other class) / 1 (that class). Passing the
# three-level factor directly made pROC warn and compare only the first two
# levels, silently dropping the third class from the curve.
# `levels` and `direction` are given explicitly: controls (0) are expected to
# have a lower predicted probability than cases (1).
roc_obj <- roc(
  response = as.integer(test$Species == colnames(prob)[1]),
  predictor = prob[, 1],
  levels = c(0, 1),
  direction = "<"
)
plot(roc_obj)

auc(roc_obj)
## Area under the curve: 1

15. Overfitting check

# Accuracy of the tuned model on the rows it was trained on versus the
# held-out rows; a large gap between the two would point to overfitting.
train_acc <- mean(predict(model_tuned, newdata = train) == train$Species)
test_acc <- mean(pred2 == test$Species)

print(train_acc)
## [1] 0.9714286
print(test_acc)
## [1] 0.9333333

16. Insights

# Print the closing insight as plain text (no trailing newline is added).
cat("El modelo muestra buena capacidad predictiva, con validación cruzada para evitar overfitting.")
## El modelo muestra buena capacidad predictiva, con validación cruzada para evitar overfitting.

17. Recomendaciones

# Print the recommendation as plain text (no trailing newline is added).
cat("Se recomienda probar redes profundas con keras para problemas más complejos.")
## Se recomienda probar redes profundas con keras para problemas más complejos.