file <- "https://raw.githubusercontent.com/fhernanb/datos/master/propelente"
datos <- read.table(file=file, header=TRUE) 
head(datos)
##   Resistencia  Edad
## 1     2158.70 15.50
## 2     1678.15 23.75
## 3     2316.00  8.00
## 4     2061.30 17.00
## 5     2207.50  5.50
## 6     1708.30 19.00
library(ggplot2)
## Warning: package 'ggplot2' was built under R version 4.2.3
ggplot(datos, aes(x = Edad, y = Resistencia))+geom_point()

library(neuralnet)
## Warning: package 'neuralnet' was built under R version 4.2.3

Creación de la red neuronal Antes de crear la red es necesario escalar las variables para evitar el efecto de la escala de las variables. Existen varias formas de escalar pero se usará una transformación para pasar los valores de las variables al intervalo (0,1).

Con el siguiente código se va convertir los datos originales a datos escalados y se almacenarán en el objeto scaled.

maxs <- apply(datos, 2, max) # Máximo valor de las variables
mins <- apply(datos, 2, min) # Mínimo valor de las variables
scaled <- as.data.frame(scale(datos, center=mins, scale=maxs-mins))
head(cbind(datos, scaled))
##   Resistencia  Edad Resistencia      Edad
## 1     2158.70 15.50  0.49234158 0.5869565
## 2     1678.15 23.75  0.00000000 0.9456522
## 3     2316.00  8.00  0.65350136 0.2608696
## 4     2061.30 17.00  0.39255161 0.6521739
## 5     2207.50  5.50  0.54233902 0.1521739
## 6     1708.30 19.00  0.03088981 0.7391304
summary(scaled)
##   Resistencia          Edad       
##  Min.   :0.0000   Min.   :0.0000  
##  1st Qu.:0.1079   1st Qu.:0.2228  
##  Median :0.5171   Median :0.4674  
##  Mean   :0.4643   Mean   :0.4940  
##  3rd Qu.:0.6802   3rd Qu.:0.7663  
##  Max.   :1.0000   Max.   :1.0000
library(gridExtra)
## Warning: package 'gridExtra' was built under R version 4.2.3
p1 <- ggplot(datos, aes(x = Edad, y = Resistencia))+geom_point()
p2 <- ggplot(scaled, aes(x = Edad, y = Resistencia))+geom_point()
grid.arrange(p1, p2, ncol = 2)

set.seed(12345)
mod1 <- neuralnet(Resistencia ~ Edad, data = scaled, hidden = c(1), threshold=0.01)

hidden: es un vector de enteros que especifica el número de neuronas ocultas (vértices) en cada capa. threshold: un valor numérico que especifica el umbral para los derivados parciales de la función de error como criterio de parada. Además, se puede construir un dibujo con la red ajustada usando la función plot sobre el objeto mod1, véase:

plot(mod1, rep="best")

mod1$act.fct # Activation function
## function (x) 
## {
##     1/(1 + exp(-x))
## }
## <bytecode: 0x0000018704065770>
## <environment: 0x0000018704064e78>
## attr(,"type")
## [1] "logistic"
unlist(mod1$weights)  # Obtener en formas de vector los weigths=pesos
## [1]  0.9211313 -2.2825075 -0.2875935  1.6535977

Predicción

En este ejemplo la base de datos tiene solo 20 observaciones y por esta razón el conjunto de entrenamiento y conjunto de prueba son el mismo.

En el código mostrado a continuación se crea el conjunto de prueba test solo con la covariable Edad proveniente de la base scaled. La función compute permite predecir los valores Resistencia para la informacion disponible en test teniendo como referencia una red neuronal entrenada, en este caso vamos a usar mod1.

test <- data.frame(Edad = scaled$Edad)
test
##          Edad
## 1  0.58695652
## 2  0.94565217
## 3  0.26086957
## 4  0.65217391
## 5  0.15217391
## 6  0.73913043
## 7  0.95652174
## 8  0.02173913
## 9  0.23913043
## 10 0.39130435
## 11 0.47826087
## 12 0.07608696
## 13 1.00000000
## 14 0.33695652
## 15 0.86956522
## 16 0.69565217
## 17 0.17391304
## 18 0.45652174
## 19 0.00000000
## 20 0.84782609
myprediction <- compute(x=mod1, covariate=test)
myprediction$net.result[1:5]
## [1] 0.36863916 0.08430328 0.67266955 0.31071614 0.77011036

El elemento $net.result del objeto myprediction tiene la respuesta estimada pero en la forma escalada, por esta razón es necesario aplicar la transformación inversa para obtener el resultado en la escala original. A continuación el código necesario para retornar a la escala original.

yhat_red <- myprediction$net.result * (max(datos$Resistencia)-min(datos$Resistencia))+min(datos$Resistencia)
datos$yhat_red <- yhat_red
yhat_red[1:5] 
## [1] 2037.960 1760.434 2334.709 1981.424 2429.816
ggplot(datos, aes(x=Resistencia, y=yhat_red)) + geom_point() +
  geom_abline(intercept=0, slope=1, color="blue", linetype="dashed", size=1)
## Warning: Using `size` aesthetic for lines was deprecated in ggplot2 3.4.0.
## ℹ Please use `linewidth` instead.
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was
## generated.

mod2 <- neuralnet(Resistencia ~ Edad, data=scaled, 
                  hidden=c(2), threshold=0.01)
plot(mod2, rep = "best")

mod3 <- neuralnet(Resistencia ~ Edad, data=scaled, 
                  hidden=c(2, 3), threshold=0.01)
plot(mod3, rep="best")

luis1 <- neuralnet(Resistencia ~ Edad, data=scaled, 
                  hidden=c(5, 5, 5, 5), threshold=0.001)
plot(luis1, rep = "best")

luis1$result.matrix
##                                   [,1]
## error                     1.872673e-02
## reached.threshold         9.496481e-04
## steps                     2.482000e+03
## Intercept.to.1layhid1     1.353679e+00
## Edad.to.1layhid1         -5.258114e+00
## Intercept.to.1layhid2     3.690259e+00
## Edad.to.1layhid2          3.511034e+00
## Intercept.to.1layhid3    -1.323108e+00
## Edad.to.1layhid3          3.193331e+01
## Intercept.to.1layhid4    -6.430845e+00
## Edad.to.1layhid4          1.636461e+01
## Intercept.to.1layhid5     4.408880e-01
## Edad.to.1layhid5         -5.272489e+00
## Intercept.to.2layhid1     1.562933e+00
## 1layhid1.to.2layhid1     -1.808430e+00
## 1layhid2.to.2layhid1      9.239058e-01
## 1layhid3.to.2layhid1     -2.694796e+00
## 1layhid4.to.2layhid1     -1.360308e+00
## 1layhid5.to.2layhid1     -1.006264e+00
## Intercept.to.2layhid2     9.718105e-01
## 1layhid1.to.2layhid2     -4.709215e+00
## 1layhid2.to.2layhid2     -1.459366e+00
## 1layhid3.to.2layhid2      1.954213e+01
## 1layhid4.to.2layhid2      4.450427e+00
## 1layhid5.to.2layhid2      5.033692e-01
## Intercept.to.2layhid3    -1.249576e+00
## 1layhid1.to.2layhid3      6.176124e+01
## 1layhid2.to.2layhid3     -1.315281e+00
## 1layhid3.to.2layhid3     -3.683942e-01
## 1layhid4.to.2layhid3     -2.852663e+00
## 1layhid5.to.2layhid3      4.100069e+01
## Intercept.to.2layhid4     4.401733e-01
## 1layhid1.to.2layhid4     -2.339533e+00
## 1layhid2.to.2layhid4      1.894067e+00
## 1layhid3.to.2layhid4     -2.922802e+00
## 1layhid4.to.2layhid4     -8.544833e-01
## 1layhid5.to.2layhid4      1.585064e-01
## Intercept.to.2layhid5     1.976112e-01
## 1layhid1.to.2layhid5      3.974441e-01
## 1layhid2.to.2layhid5     -7.270210e-01
## 1layhid3.to.2layhid5     -1.914132e+00
## 1layhid4.to.2layhid5      9.714482e-01
## 1layhid5.to.2layhid5     -1.298724e+00
## Intercept.to.3layhid1     5.591862e-01
## 2layhid1.to.3layhid1     -2.919869e+00
## 2layhid2.to.3layhid1      6.538219e-02
## 2layhid3.to.3layhid1     -1.285745e+00
## 2layhid4.to.3layhid1     -1.357096e+00
## 2layhid5.to.3layhid1      2.168412e+00
## Intercept.to.3layhid2     1.450016e+00
## 2layhid1.to.3layhid2      4.238290e+00
## 2layhid2.to.3layhid2      9.246122e-01
## 2layhid3.to.3layhid2     -1.874908e+00
## 2layhid4.to.3layhid2      6.079382e-01
## 2layhid5.to.3layhid2     -7.090449e-01
## Intercept.to.3layhid3     2.730243e-02
## 2layhid1.to.3layhid3      1.528167e+00
## 2layhid2.to.3layhid3     -9.224517e-01
## 2layhid3.to.3layhid3      2.709117e+00
## 2layhid4.to.3layhid3      6.882306e-01
## 2layhid5.to.3layhid3     -3.808138e+00
## Intercept.to.3layhid4     6.806519e-01
## 2layhid1.to.3layhid4      1.803466e-01
## 2layhid2.to.3layhid4      5.950632e-01
## 2layhid3.to.3layhid4      1.742603e+01
## 2layhid4.to.3layhid4     -8.531162e-01
## 2layhid5.to.3layhid4      5.579552e+00
## Intercept.to.3layhid5     3.183076e+00
## 2layhid1.to.3layhid5     -7.401655e+00
## 2layhid2.to.3layhid5     -9.543732e-01
## 2layhid3.to.3layhid5      6.744790e-01
## 2layhid4.to.3layhid5     -7.788762e-01
## 2layhid5.to.3layhid5      5.202928e+00
## Intercept.to.4layhid1     8.333141e-01
## 3layhid1.to.4layhid1      4.353984e+00
## 3layhid2.to.4layhid1     -4.372944e-01
## 3layhid3.to.4layhid1     -1.154812e+01
## 3layhid4.to.4layhid1      1.736486e+00
## 3layhid5.to.4layhid1     -1.357283e+00
## Intercept.to.4layhid2     2.302392e-01
## 3layhid1.to.4layhid2     -9.307255e-01
## 3layhid2.to.4layhid2     -2.125904e-01
## 3layhid3.to.4layhid2      1.686561e+00
## 3layhid4.to.4layhid2     -3.196114e-01
## 3layhid5.to.4layhid2      4.104426e-01
## Intercept.to.4layhid3    -1.134749e+00
## 3layhid1.to.4layhid3     -1.485353e+00
## 3layhid2.to.4layhid3      1.350050e+00
## 3layhid3.to.4layhid3      1.419199e+00
## 3layhid4.to.4layhid3      4.783948e-02
## 3layhid5.to.4layhid3     -4.162780e-01
## Intercept.to.4layhid4     2.741253e-01
## 3layhid1.to.4layhid4      3.084084e-01
## 3layhid2.to.4layhid4      6.039901e+00
## 3layhid3.to.4layhid4      1.245033e+00
## 3layhid4.to.4layhid4      1.657360e+00
## 3layhid5.to.4layhid4     -5.977450e+00
## Intercept.to.4layhid5     2.207307e-01
## 3layhid1.to.4layhid5      2.888739e+00
## 3layhid2.to.4layhid5     -5.229474e-01
## 3layhid3.to.4layhid5     -1.266610e+00
## 3layhid4.to.4layhid5     -1.938977e+00
## 3layhid5.to.4layhid5      1.024676e+00
## Intercept.to.Resistencia -1.991872e+00
## 4layhid1.to.Resistencia   1.119023e+00
## 4layhid2.to.Resistencia   1.777216e+00
## 4layhid3.to.Resistencia   1.330327e+00
## 4layhid4.to.Resistencia   6.565543e-01
## 4layhid5.to.Resistencia  -1.103547e+00
luis1$response
##    Resistencia
## 1   0.49234158
## 2   0.00000000
## 3   0.65350136
## 4   0.39255161
## 5   0.54233902
## 6   0.03088981
## 7   0.10916449
## 8   0.91885662
## 9   0.69642949
## 10  0.59274627
## 11  0.49900108
## 12  0.73910148
## 13  0.10414425
## 14  0.67476051
## 15  0.08928846
## 16  0.38456022
## 17  0.75431586
## 18  0.53516726
## 19  1.00000000
## 20  0.07740382
myprediction <- compute(x=luis1, covariate=test)
yhat_red <- myprediction$net.result * (max(datos$Resistencia)-min(datos$Resistencia))+min(datos$Resistencia)
yhat_red[1:20]
##  [1] 2128.525 1753.602 2348.614 2115.842 2306.375 1725.408 1753.000 2587.214
##  [9] 2342.276 2272.387 2165.420 2386.045 1751.102 2330.982 1754.396 2013.755
## [17] 2313.038 2187.125 2644.552 1747.232
ggplot(datos, aes(x=Resistencia, y=yhat_red)) + geom_point() +
  geom_abline(intercept=0, slope=1, color="blue", linetype="dashed", size=1)