Objetivo

Este tutorial recoge los pasos básicos para llevar a cabo un clasificador multietiqueta basado en una red neuronal artificial haciendo uso del paquete neuralnet de R. En este ejemplo realizaremos un clasificador de tipos de vino en función de 13 variables obtenidas del análisis químico de los mismos pudiendo ser clasificados como 1, 2 y 3.

Fuente de datos

La fuente de datos de este ejemplo la obtendremos de la UCI Machine Learning Repository, concretamente de la siguiente URL https://archive.ics.uci.edu/ml/machine-learning-database/ine/wine.data que contiene los análisis químicos de tres clases diferentes de vino.

if(!file.exists('myfile.csv')) # descargamos el archivo sólo si no se encuentra ya en nuestro directorio de trabajo
{
    url <- "https://archive.ics.uci.edu/ml/machine-learning-databases/wine/wine.data"
    download.file(url, destfile = 'myfile.csv', method = 'curl')
}
data <- read.csv('myfile.csv', header = FALSE, sep=',')

Exploración

Empezamos echando un breve vistazo a los datos descargados y así tener una idea clara las características de los datos con los que estamos trabajando.

head(data)
##   V1    V2   V3   V4   V5  V6   V7   V8   V9  V10  V11  V12  V13  V14
## 1  1 14.23 1.71 2.43 15.6 127 2.80 3.06 0.28 2.29 5.64 1.04 3.92 1065
## 2  1 13.20 1.78 2.14 11.2 100 2.65 2.76 0.26 1.28 4.38 1.05 3.40 1050
## 3  1 13.16 2.36 2.67 18.6 101 2.80 3.24 0.30 2.81 5.68 1.03 3.17 1185
## 4  1 14.37 1.95 2.50 16.8 113 3.85 3.49 0.24 2.18 7.80 0.86 3.45 1480
## 5  1 13.24 2.59 2.87 21.0 118 2.80 2.69 0.39 1.82 4.32 1.04 2.93  735
## 6  1 14.20 1.76 2.45 15.2 112 3.27 3.39 0.34 1.97 6.75 1.05 2.85 1450
str(data)
## 'data.frame':    178 obs. of  14 variables:
##  $ V1 : int  1 1 1 1 1 1 1 1 1 1 ...
##  $ V2 : num  14.2 13.2 13.2 14.4 13.2 ...
##  $ V3 : num  1.71 1.78 2.36 1.95 2.59 1.76 1.87 2.15 1.64 1.35 ...
##  $ V4 : num  2.43 2.14 2.67 2.5 2.87 2.45 2.45 2.61 2.17 2.27 ...
##  $ V5 : num  15.6 11.2 18.6 16.8 21 15.2 14.6 17.6 14 16 ...
##  $ V6 : int  127 100 101 113 118 112 96 121 97 98 ...
##  $ V7 : num  2.8 2.65 2.8 3.85 2.8 3.27 2.5 2.6 2.8 2.98 ...
##  $ V8 : num  3.06 2.76 3.24 3.49 2.69 3.39 2.52 2.51 2.98 3.15 ...
##  $ V9 : num  0.28 0.26 0.3 0.24 0.39 0.34 0.3 0.31 0.29 0.22 ...
##  $ V10: num  2.29 1.28 2.81 2.18 1.82 1.97 1.98 1.25 1.98 1.85 ...
##  $ V11: num  5.64 4.38 5.68 7.8 4.32 6.75 5.25 5.05 5.2 7.22 ...
##  $ V12: num  1.04 1.05 1.03 0.86 1.04 1.05 1.02 1.06 1.08 1.01 ...
##  $ V13: num  3.92 3.4 3.17 3.45 2.93 2.85 3.58 3.58 2.85 3.55 ...
##  $ V14: int  1065 1050 1185 1480 735 1450 1290 1295 1045 1045 ...

Nuestro dataframe está compuesto por 178 observaciones de 14 variables representando la primera de ellas el tipo de vino como 1, 2 y 3. Dado que el dataframe no contiene cabecera, será necesario incluirla y asi facilitar el trabajo posterior de análisis. Para ayudarnos en la denominación de cada una de las variables es recomendable revisar el documento Data Set Description que contiene un breve descripción de todas las variables. Vamos a añadir las cabeceras al dataframe original. Conviene resaltar que los nombres de las variables deben ser lo suficientemente explícitos como para poder entender el contenido de las mismas a simple vista.

names(data) <- c('Wine','Alcohol','Malic_acid','Ash','Alcalinity_ash','Magnesium','Total_phenols','Flavanoids','Nonflavanoinds_phenols','Proanthocyanins','Color_intensity','Hue','OD280_OD315_of_diluted_wines','Proline')

Ahora tenemos nuestro dataframe con los nombres de las variables suficientemente descriptivos.

A continuación vamos a comprobar si todas las observaciones están completas y no falta ningún valor. Si se diera el caso, tendriamos que decidir si omitir las observaciones incompletas o completar las mismas con valores promedio.

colSums(is.na(data))
##                         Wine                      Alcohol 
##                            0                            0 
##                   Malic_acid                          Ash 
##                            0                            0 
##               Alcalinity_ash                    Magnesium 
##                            0                            0 
##                Total_phenols                   Flavanoids 
##                            0                            0 
##       Nonflavanoinds_phenols              Proanthocyanins 
##                            0                            0 
##              Color_intensity                          Hue 
##                            0                            0 
## OD280_OD315_of_diluted_wines                      Proline 
##                            0                            0

Podemos comprobar como en nuestro dataframe todas las observaciones están completas y por tanto continuar con el análisis.

Observamos como los rangos de valor de cada una de las variables son muy diferentes en magnitud. Esto requerirá que los mismos sean normalizados, acción que llevaremos a cabo en el siguiente apartado.

summary(data)
##       Wine          Alcohol        Malic_acid         Ash       
##  Min.   :1.000   Min.   :11.03   Min.   :0.740   Min.   :1.360  
##  1st Qu.:1.000   1st Qu.:12.36   1st Qu.:1.603   1st Qu.:2.210  
##  Median :2.000   Median :13.05   Median :1.865   Median :2.360  
##  Mean   :1.938   Mean   :13.00   Mean   :2.336   Mean   :2.367  
##  3rd Qu.:3.000   3rd Qu.:13.68   3rd Qu.:3.083   3rd Qu.:2.558  
##  Max.   :3.000   Max.   :14.83   Max.   :5.800   Max.   :3.230  
##  Alcalinity_ash    Magnesium      Total_phenols     Flavanoids   
##  Min.   :10.60   Min.   : 70.00   Min.   :0.980   Min.   :0.340  
##  1st Qu.:17.20   1st Qu.: 88.00   1st Qu.:1.742   1st Qu.:1.205  
##  Median :19.50   Median : 98.00   Median :2.355   Median :2.135  
##  Mean   :19.49   Mean   : 99.74   Mean   :2.295   Mean   :2.029  
##  3rd Qu.:21.50   3rd Qu.:107.00   3rd Qu.:2.800   3rd Qu.:2.875  
##  Max.   :30.00   Max.   :162.00   Max.   :3.880   Max.   :5.080  
##  Nonflavanoinds_phenols Proanthocyanins Color_intensity       Hue        
##  Min.   :0.1300         Min.   :0.410   Min.   : 1.280   Min.   :0.4800  
##  1st Qu.:0.2700         1st Qu.:1.250   1st Qu.: 3.220   1st Qu.:0.7825  
##  Median :0.3400         Median :1.555   Median : 4.690   Median :0.9650  
##  Mean   :0.3619         Mean   :1.591   Mean   : 5.058   Mean   :0.9574  
##  3rd Qu.:0.4375         3rd Qu.:1.950   3rd Qu.: 6.200   3rd Qu.:1.1200  
##  Max.   :0.6600         Max.   :3.580   Max.   :13.000   Max.   :1.7100  
##  OD280_OD315_of_diluted_wines    Proline      
##  Min.   :1.270                Min.   : 278.0  
##  1st Qu.:1.938                1st Qu.: 500.5  
##  Median :2.780                Median : 673.5  
##  Mean   :2.612                Mean   : 746.9  
##  3rd Qu.:3.170                3rd Qu.: 985.0  
##  Max.   :4.000                Max.   :1680.0

También podemos visualizar gráficamente la relación existente entre cada una de las variables y si dicha relación podría ayudarnos a clasificar el tipo de vino.

pairs(data[2:14], col=data$Wine)

A la vista de los resultados, puede resultar interesante hacer un zoom específico sobre alguna de estas relacciones para visualizar con más detalle el agrupamiento que se produce:

library(ggplot2)
qplot(Alcohol, Flavanoids, data=data, color=factor(data$Wine), geom=c("point","smooth"), main = "Flavanoids and Alcohol in Wine")
## `geom_smooth()` using method = 'loess'

qplot(Hue, Proline, data=data, color=factor(data$Wine), geom=c("point","smooth"), main = "Proline and Hue in Wine")
## `geom_smooth()` using method = 'loess'

En ambos gráficos puede observarse una interesante dependencia de las variables usadas que podrían determinar el tipo de vino. Del mismo modo también sería posible a través de estos análisis previos eliminar alguna variable si consideramos que la misma no colabora positivamente a realizar la clasificación.

Normalización

Durante el proceso de normalizado tenemos que abordar dos tareas diferente:

Normalización de Categorías

Para abordar esta tarea haremos dummy coding con la ayuda del paquete R dummies y crearemos un nuevo dataframe.

library(dummies)
## dummies-1.5.6 provided by Decision Patterns
data_dm <- dummy.data.frame(data=data, names="Wine", sep="_")
str(data_dm)
## 'data.frame':    178 obs. of  16 variables:
##  $ Wine_1                      : int  1 1 1 1 1 1 1 1 1 1 ...
##  $ Wine_2                      : int  0 0 0 0 0 0 0 0 0 0 ...
##  $ Wine_3                      : int  0 0 0 0 0 0 0 0 0 0 ...
##  $ Alcohol                     : num  14.2 13.2 13.2 14.4 13.2 ...
##  $ Malic_acid                  : num  1.71 1.78 2.36 1.95 2.59 1.76 1.87 2.15 1.64 1.35 ...
##  $ Ash                         : num  2.43 2.14 2.67 2.5 2.87 2.45 2.45 2.61 2.17 2.27 ...
##  $ Alcalinity_ash              : num  15.6 11.2 18.6 16.8 21 15.2 14.6 17.6 14 16 ...
##  $ Magnesium                   : int  127 100 101 113 118 112 96 121 97 98 ...
##  $ Total_phenols               : num  2.8 2.65 2.8 3.85 2.8 3.27 2.5 2.6 2.8 2.98 ...
##  $ Flavanoids                  : num  3.06 2.76 3.24 3.49 2.69 3.39 2.52 2.51 2.98 3.15 ...
##  $ Nonflavanoinds_phenols      : num  0.28 0.26 0.3 0.24 0.39 0.34 0.3 0.31 0.29 0.22 ...
##  $ Proanthocyanins             : num  2.29 1.28 2.81 2.18 1.82 1.97 1.98 1.25 1.98 1.85 ...
##  $ Color_intensity             : num  5.64 4.38 5.68 7.8 4.32 6.75 5.25 5.05 5.2 7.22 ...
##  $ Hue                         : num  1.04 1.05 1.03 0.86 1.04 1.05 1.02 1.06 1.08 1.01 ...
##  $ OD280_OD315_of_diluted_wines: num  3.92 3.4 3.17 3.45 2.93 2.85 3.58 3.58 2.85 3.55 ...
##  $ Proline                     : int  1065 1050 1185 1480 735 1450 1290 1295 1045 1045 ...
##  - attr(*, "dummies")=List of 1
##   ..$ Wine: int  1 2 3

Normalización de Variables Numéricas

Para la normalización de los valores de la variables definiremos la siguiente función:

normaliza <- function(x) {return ((x-min(x))/(max(x)-min(x)))}

Ahora aplicaremos la función creada normaliza a cada una de las columna.

data_norm <- as.data.frame(lapply(data_dm, normaliza))

Por último revisamos los valores normalizados de las variables que se encontrarán en todos los casos entre 0 y 1.

summary(data_norm)
##      Wine_1           Wine_2           Wine_3          Alcohol      
##  Min.   :0.0000   Min.   :0.0000   Min.   :0.0000   Min.   :0.0000  
##  1st Qu.:0.0000   1st Qu.:0.0000   1st Qu.:0.0000   1st Qu.:0.3507  
##  Median :0.0000   Median :0.0000   Median :0.0000   Median :0.5316  
##  Mean   :0.3315   Mean   :0.3989   Mean   :0.2697   Mean   :0.5186  
##  3rd Qu.:1.0000   3rd Qu.:1.0000   3rd Qu.:1.0000   3rd Qu.:0.6967  
##  Max.   :1.0000   Max.   :1.0000   Max.   :1.0000   Max.   :1.0000  
##    Malic_acid          Ash         Alcalinity_ash     Magnesium     
##  Min.   :0.0000   Min.   :0.0000   Min.   :0.0000   Min.   :0.0000  
##  1st Qu.:0.1705   1st Qu.:0.4545   1st Qu.:0.3402   1st Qu.:0.1957  
##  Median :0.2223   Median :0.5348   Median :0.4588   Median :0.3043  
##  Mean   :0.3155   Mean   :0.5382   Mean   :0.4585   Mean   :0.3233  
##  3rd Qu.:0.4629   3rd Qu.:0.6404   3rd Qu.:0.5619   3rd Qu.:0.4022  
##  Max.   :1.0000   Max.   :1.0000   Max.   :1.0000   Max.   :1.0000  
##  Total_phenols      Flavanoids     Nonflavanoinds_phenols Proanthocyanins 
##  Min.   :0.0000   Min.   :0.0000   Min.   :0.0000         Min.   :0.0000  
##  1st Qu.:0.2629   1st Qu.:0.1825   1st Qu.:0.2642         1st Qu.:0.2650  
##  Median :0.4741   Median :0.3787   Median :0.3962         Median :0.3612  
##  Mean   :0.4535   Mean   :0.3564   Mean   :0.4375         Mean   :0.3725  
##  3rd Qu.:0.6276   3rd Qu.:0.5348   3rd Qu.:0.5802         3rd Qu.:0.4858  
##  Max.   :1.0000   Max.   :1.0000   Max.   :1.0000         Max.   :1.0000  
##  Color_intensity       Hue         OD280_OD315_of_diluted_wines
##  Min.   :0.0000   Min.   :0.0000   Min.   :0.0000              
##  1st Qu.:0.1655   1st Qu.:0.2459   1st Qu.:0.2445              
##  Median :0.2910   Median :0.3943   Median :0.5531              
##  Mean   :0.3224   Mean   :0.3882   Mean   :0.4915              
##  3rd Qu.:0.4198   3rd Qu.:0.5203   3rd Qu.:0.6960              
##  Max.   :1.0000   Max.   :1.0000   Max.   :1.0000              
##     Proline      
##  Min.   :0.0000  
##  1st Qu.:0.1587  
##  Median :0.2821  
##  Mean   :0.3344  
##  3rd Qu.:0.5043  
##  Max.   :1.0000

Entrenador y test

En este paso dividiremos nuestro conjunto de datos en dos subconjuntos, una para entrenar la red neuronal que denominaremos train, y otro para testearla que denominamos test. Para llevarlo a cabo haremos uso de la función R sample que nos extraerá una muestra aleatoria de registros para crear cada uno de los conjuntos de entrenamiento y test.

set.seed(3141592) #necesario si se quiere reproducir de nuevo el mismo código y obtener los mimsos resultados
index <- sample(nrow(data_norm), round(0.75*nrow(data_norm)))
train <- data_norm[index,] # crea train a partir del indice de la muestra
test <- data_norm[-index,] # crea test a partir del resto de la muestra
head(train)
##     Wine_1 Wine_2 Wine_3   Alcohol Malic_acid       Ash Alcalinity_ash
## 103      0      1      0 0.3447368 0.33794466 0.5882353      0.5360825
## 21       1      0      0 0.7973684 0.17588933 0.4919786      0.2783505
## 85       0      1      0 0.2131579 0.02964427 0.6524064      0.3814433
## 46       1      0      0 0.8368421 0.65217391 0.5775401      0.4278351
## 73       0      1      0 0.6473684 0.18181818 0.4705882      0.6907216
## 43       1      0      0 0.7500000 0.22727273 0.6577540      0.2268041
##     Magnesium Total_phenols Flavanoids Nonflavanoinds_phenols
## 103 0.3043478     0.5448276  0.3734177              0.3962264
## 21  0.6086957     0.6965517  0.5970464              0.2075472
## 85  0.2608696     0.4206897  0.3945148              0.1698113
## 46  0.4456522     0.6448276  0.4873418              0.3207547
## 73  0.1847826     0.3103448  0.3164557              0.2641509
## 43  0.3369565     0.7827586  0.6793249              0.0754717
##     Proanthocyanins Color_intensity       Hue OD280_OD315_of_diluted_wines
## 103       0.2839117       0.1296928 0.2601626                    0.7728938
## 21        0.5331230       0.3728669 0.4959350                    0.8937729
## 85        0.6119874       0.1510239 0.2520325                    0.6630037
## 46        0.2649842       0.3378840 0.3170732                    0.7545788
## 73        0.1955836       0.2098976 0.4065041                    0.5531136
## 43        0.4069401       0.3540956 0.3252033                    0.8388278
##       Proline
## 103 0.1141227
## 21  0.3580599
## 85  0.1726106
## 46  0.5720399
## 73  0.1383738
## 43  0.5827389
head(test)
##    Wine_1 Wine_2 Wine_3   Alcohol Malic_acid       Ash Alcalinity_ash
## 1       1      0      0 0.8421053  0.1916996 0.5721925      0.2577320
## 4       1      0      0 0.8789474  0.2391304 0.6096257      0.3195876
## 9       1      0      0 1.0000000  0.1778656 0.4331551      0.1752577
## 10      1      0      0 0.7447368  0.1205534 0.4866310      0.2783505
## 18      1      0      0 0.7368421  0.1640316 0.6737968      0.4845361
## 19      1      0      0 0.8315789  0.1679842 0.5989305      0.3041237
##    Magnesium Total_phenols Flavanoids Nonflavanoinds_phenols
## 1  0.6195652     0.6275862  0.5738397              0.2830189
## 4  0.4673913     0.9896552  0.6645570              0.2075472
## 9  0.2934783     0.6275862  0.5569620              0.3018868
## 10 0.3043478     0.6896552  0.5928270              0.1698113
## 18 0.4891304     0.6793103  0.6455696              0.5094340
## 19 0.4130435     0.8000000  0.7573840              0.3584906
##    Proanthocyanins Color_intensity       Hue OD280_OD315_of_diluted_wines
## 1        0.5930599       0.3720137 0.4552846                    0.9706960
## 4        0.5583596       0.5563140 0.3089431                    0.7985348
## 9        0.4952681       0.3344710 0.4878049                    0.5787546
## 10       0.4542587       0.5068259 0.4308943                    0.8351648
## 18       0.4132492       0.4539249 0.5284553                    0.4761905
## 19       0.4574132       0.6331058 0.6097561                    0.5677656
##      Proline
## 1  0.5613409
## 4  0.8573466
## 9  0.5470756
## 10 0.5470756
## 18 0.6077033
## 19 1.0000000

ANN: Entrenamiento

Ha llegado el momento de entrenar a nuestra red neuronal. Vamos a probar con la siguiente topología de diseño:

Además repetiremos el entrenamiento 5 veces con objeto de quedars con la versión qu obtenga una menor aerror.

set.seed(3141592)
library(neuralnet)
ann_model <- neuralnet(Wine_1+Wine_2+Wine_3~Alcohol+Malic_acid+Ash+Alcalinity_ash+Magnesium+Total_phenols+Flavanoids+Nonflavanoinds_phenols+Proanthocyanins+Color_intensity+Hue+OD280_OD315_of_diluted_wines+Proline, data=train, hidden=c(5,5), lifesign = "minimal", linear.output = FALSE, rep =10)
## hidden: 5, 5    thresh: 0.01    rep:  1/10    steps:      94 error: 0.01165  time: 0.08 secs
## hidden: 5, 5    thresh: 0.01    rep:  2/10    steps:      85 error: 0.01288  time: 0.11 secs
## hidden: 5, 5    thresh: 0.01    rep:  3/10    steps:      95 error: 0.00756  time: 0.02 secs
## hidden: 5, 5    thresh: 0.01    rep:  4/10    steps:      86 error: 0.00643  time: 0.02 secs
## hidden: 5, 5    thresh: 0.01    rep:  5/10    steps:      93 error: 0.01692  time: 0.03 secs
## hidden: 5, 5    thresh: 0.01    rep:  6/10    steps:     114 error: 0.0059   time: 0.03 secs
## hidden: 5, 5    thresh: 0.01    rep:  7/10    steps:      72 error: 0.01431  time: 0.02 secs
## hidden: 5, 5    thresh: 0.01    rep:  8/10    steps:      79 error: 0.02378  time: 0.02 secs
## hidden: 5, 5    thresh: 0.01    rep:  9/10    steps:     118 error: 0.01621  time: 0.03 secs
## hidden: 5, 5    thresh: 0.01    rep: 10/10    steps:     106 error: 0.01384  time: 0.03 secs
plot(ann_model, rep="best")

ANN: Test

A continuación vamos a determinar la validez del modelo aplicándolo al grupo de datos test y echando un vistazo a los resultados de la predicción.

ann_pred <- compute(ann_model,test[,4:16])
head(ann_pred$net.result)
##            [,1]            [,2]             [,3]
## 1  0.9997126354 0.0002974652830 0.00003118417630
## 4  0.9998115596 0.0001865797449 0.00003961208542
## 9  0.9997440051 0.0002612252912 0.00003317438964
## 10 0.9996983231 0.0003105456212 0.00003189655690
## 18 0.9993760497 0.0005844696221 0.00003680862309
## 19 0.9997950501 0.0002054798920 0.00003718932602

Redondeemos los resultados para tener un resultado claro en el clasificador.

ann_pred_round <- as.data.frame(round(ann_pred$net.result))
head(ann_pred_round, 10)
##    V1 V2 V3
## 1   1  0  0
## 4   1  0  0
## 9   1  0  0
## 10  1  0  0
## 18  1  0  0
## 19  1  0  0
## 26  0  1  0
## 30  1  0  0
## 31  1  0  0
## 32  1  0  0

Evaluación del modelo

Ahora necesitamos comparar la predicción de la clasificación con la clasificación real del conjunto test. Haremos la correspondiente tabla de contingencia con la ayuda del paquete R gmodels. Primero vamos a convertir en un vector la predicción:

predic<-max.col(ann_pred_round)
head(predic, 100)
##  [1] 1 1 1 1 1 1 2 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 3 3 3 3 3
## [36] 3 3 3 3 3 3 3 3 3

Seguidamente hacemos lo mismo con la clasificación correcta de test

test_res <- max.col(test[,1:3])
head(test_res, 100)
##  [1] 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 3 3 3 3 3
## [36] 3 3 3 3 3 3 3 3 3

Ambos vectores tienen el mismo número de elementos como cabría esperar. Ahora ya podemos montar la matriz de contigencia.

library(gmodels)
CrossTable(x=test_res, y=predic)
## 
##  
##    Cell Contents
## |-------------------------|
## |                       N |
## | Chi-square contribution |
## |           N / Row Total |
## |           N / Col Total |
## |         N / Table Total |
## |-------------------------|
## 
##  
## Total Observations in Table:  44 
## 
##  
##              | predic 
##     test_res |         1 |         2 |         3 | Row Total | 
## -------------|-----------|-----------|-----------|-----------|
##            1 |        14 |         1 |         0 |        15 | 
##              |    17.839 |     3.638 |     4.773 |           | 
##              |     0.933 |     0.067 |     0.000 |     0.341 | 
##              |     1.000 |     0.062 |     0.000 |           | 
##              |     0.318 |     0.023 |     0.000 |           | 
## -------------|-----------|-----------|-----------|-----------|
##            2 |         0 |        15 |         0 |        15 | 
##              |     4.773 |    16.705 |     4.773 |           | 
##              |     0.000 |     1.000 |     0.000 |     0.341 | 
##              |     0.000 |     0.938 |     0.000 |           | 
##              |     0.000 |     0.341 |     0.000 |           | 
## -------------|-----------|-----------|-----------|-----------|
##            3 |         0 |         0 |        14 |        14 | 
##              |     4.455 |     5.091 |    20.455 |           | 
##              |     0.000 |     0.000 |     1.000 |     0.318 | 
##              |     0.000 |     0.000 |     1.000 |           | 
##              |     0.000 |     0.000 |     0.318 |           | 
## -------------|-----------|-----------|-----------|-----------|
## Column Total |        14 |        16 |        14 |        44 | 
##              |     0.318 |     0.364 |     0.318 |           | 
## -------------|-----------|-----------|-----------|-----------|
## 
## 

En esta ocasión el modelo nos ha dado una precisión bastante buena ya que la misma ha sido correcta en un 97,72% y por tanto el error cometido es del 2,28%. En terminos cuantitativo, de las 44 observaciones realizadas en el test, 43 fueron correctamente clasificadas y se falló en 1.

Si queremos is un poco más allá en la evaluación de nuestra red neuronal podemos recurrir al paquete R caret que nos genera igualmente una matriz de confusión además de otros indicadores más avanzados sobre la bondad del modelo.

library(caret)
## Loading required package: lattice
confusionMatrix(table(test_res, predic)) #usar table para que caret sea capaz de generar la matriz de confusión. Si se omite da error.
## Confusion Matrix and Statistics
## 
##         predic
## test_res  1  2  3
##        1 14  1  0
##        2  0 15  0
##        3  0  0 14
## 
## Overall Statistics
##                                                   
##                Accuracy : 0.9772727               
##                  95% CI : (0.8797584, 0.9994248)  
##     No Information Rate : 0.3636364               
##     P-Value [Acc > NIR] : < 0.00000000000000022204
##                                                   
##                   Kappa : 0.9658915               
##  Mcnemar's Test P-Value : NA                      
## 
## Statistics by Class:
## 
##                       Class: 1  Class: 2  Class: 3
## Sensitivity          1.0000000 0.9375000 1.0000000
## Specificity          0.9666667 1.0000000 1.0000000
## Pos Pred Value       0.9333333 1.0000000 1.0000000
## Neg Pred Value       1.0000000 0.9655172 1.0000000
## Prevalence           0.3181818 0.3636364 0.3181818
## Detection Rate       0.3181818 0.3409091 0.3181818
## Detection Prevalence 0.3409091 0.3409091 0.3181818
## Balanced Accuracy    0.9833333 0.9687500 1.0000000

Cross-Validation

Aun podemos profundizar más sobre la precisión del modelo aplicando la técnica cross-validation en la elección de los conjuntos de train y test. De esta manera minimizamos el efecto de la elección del conjunto de datos a la hora de determinar la bondad del modelo.

En primer lugar creamos 10 sobres de datos folds.

set.seed(3141592)
myfolds <- createFolds(data$Wine, k=10)

A continuación creamos la función responsable de llevar a cabo la cross_validation sobre cada uno de los subconjuntos de datos folds y devolver la precisión y la kappa del modelo.

CV_ANN <- function(x)
{
  train <- data_norm[-x,]
  test <- data_norm[x,]
  ann_model <- neuralnet(Wine_1+Wine_2+Wine_3~Alcohol+Malic_acid+Ash+Alcalinity_ash+Magnesium+Total_phenols+Flavanoids+Nonflavanoinds_phenols+Proanthocyanins+Color_intensity+Hue+OD280_OD315_of_diluted_wines+Proline, data=train, hidden=c(5,5), lifesign = "minimal", linear.output = FALSE, rep =1)
  ann_pred <- compute(ann_model,test[,4:16])
  ann_pred_round <- as.data.frame(round(ann_pred$net.result))
  predic <- max.col(ann_pred_round)
  test_res <- max.col(test[,1:3])
  myconfusion <- confusionMatrix(table(test_res, predic))
  return (myconfusion$overall[1:2]) # el rango [1:2] corresponde a los índices donde se encuentran las variables Accuracy y Kappa dentro de nuestra matriz de confusión
}

Finalmente aplicamos la función a cada uno de los conjuntos de datos folds y guardamos los resultados en una lista.

CV_result <- lapply(myfolds, CV_ANN)
CV_result_DF <-as.data.frame(CV_result)
CV_result_DF
##          Fold01 Fold02 Fold03 Fold04 Fold05       Fold06 Fold07 Fold08
## Accuracy      1      1      1      1      1 0.9444444444      1      1
## Kappa         1      1      1      1      1 0.9154929577      1      1
##                Fold09       Fold10
## Accuracy 0.9444444444 0.9444444444
## Kappa    0.9130434783 0.9130434783

Para finalizar calculamos la media de los valores obtenidos de precisión y kappa.

rowMeans(CV_result_DF)
##     Accuracy        Kappa 
## 0.9833333333 0.9741579914

Nuestro clasificador ha logrado una precisión del 98,3% y una Kappa del 97,4% que podemos calificar como de muy buen resultado para un sencillo clasificator de vinos. Por tanto, para futuras clasificaciones podremos usar la red neural almacenada en ann_model en su version best.

plot(ann_model, rep="best")