Esta entrada aprece en el Blog Punto de Acumulación

En esta entrada vamos a tratar otro tema habitual en el análisis de datos: el ajuste de los parámetros de los algoritmos de aprendizaje. Para los detalles teóricos, se puede consultar model training and tuning del paquete Caret.

(NOTA: En esta entrada se hace una búsqueda bastante exhaustiva en el espacio de parámetros y puede tardar bastante tiempo. Para una primera aproximación, se podría hacer una búsqueda con menos valores.)

Aquí iremos directamente al grano y ajustaremos cuatro de los parámetros de gbm: interaction.depth, n.trees, shrinkage y n.minobsinnode. Se va a realizar sobre un conjunto de datos de ejemplo, algae bloom o reproducción de algas, que nos descargaremos directamente de la web y que se pueden encontrar aqui.

La manera más habitual de ajustar estos parámetros es probar todas las combinaciones de valores de parámetros en lo que se llama una búsqueda en rejilla o grid search. Para cada combinacion de parámetros, se realiza una validacion cruzada. En este caso se han usado 5 folds, aunque si el tamaño de los datos es muy grande, puede ser conveniente usar sólo 2 folds, o directamente entrenamiento / validación. Podríamos realizar dicha búsqueda en rejilla mediante un cuadruple bucle anidado, pero en esta entrada usaremos algunos elementos más avanzados. En concreto:

A continuación se muestra el código:

options(width=500)
catappend = function(...,file="",header=FALSE) {
    fe <- file.exists(file)
    if(header && fe) warning(paste("catappend: File ", file," exists\n",immediate.=TRUE))
    #    if(header) cat(paste("#",Sys.time()), append=TRUE,fill=TRUE,file=file)
    if(!header || (header && !fe)) cat(..., append=TRUE,fill=TRUE,file=file)
}

# Function that returns Root Mean Squared Error
rmse = function(error) sqrt(mean(error^2))
 
# Function that returns Mean Absolute Error
mae = function(error) mean(abs(error))



library(plyr)          # para la función Splat
library(data.table)    # alternativa a data.frame 
library(iterators)     # iteradores
library(itertools)
library("foreach")
library("functional")  # programacion funcional
library(gbm)           # gradient boosting
library(caret)         # Mineria de datos

library("doMC")        # multicore
registerDoMC(4)        # Usa 4 cores

semilla = 1            # semilla aleatoria
rango_nFolds = 1:5     # Validación cruzada de 5 folds

# rango_n.trees = c(100, 1000, 5000, 10000)
# rangos_parametros = list(c(1, 2, 4, 6, 8, 10),
#                          c(0.001, 0.01, 0.05, 0.1, 0.5),
#                          c(2, 5, 10, 15, 20),
#                          rango_nFolds
#                          )

rango_n.trees = c(100, seq(500,15000, 500))
rangos_parametros = list(c(1, 2, 4, 6, 8, 10, 12, 14),
                         c(0.0001, 0.0005, 0.001, 0.01, 0.05, 0.1, 0.5),
                         c(2, 3, 4, 5, 10, 15, 20, 25, 30),
                         rango_nFolds
                         )


nombres_parametros = c("interaction.depth", "shrinkage", "n.minobsinnode", "nFold")
names(rangos_parametros) = nombres_parametros
claves = c(nombres_parametros, "n.trees")

# Cargamos los datos para hacer las validaciones cruzadas en cada punto de la rejilla. Al final cargaremos los de test para hacer el test final. 
load(url("http://www.dcc.fc.up.pt/~ltorgo/DataMiningWithR/DataSets/algae.RData"))
# La variable de salida (a1) la ponemos como primera columna para tenerla fácilmente accesible. De esa manera, datos[,1] es la salida y algae_a1[,-1] son las entradas. Este problema tenia múltiples salidas (a1 hasta a7). Descartamos todas menos a1. 
datos = subset(algae, select=c(a1, season:Chla))

funcion = gbm
formula = a1~.

# Creamos los folds de validacion cruzada con librería Caret
set.seed(semilla)
folds = createFolds(datos[,1], k = 5, list=FALSE)


# En este fichero iremos almacenando los resultados de entrenamiento y validación para cada combinación de parámetros y cada fold
nf = paste0("parametrosGBM-",semilla,".txt")

# Si el fichero en el que se guardan los resultados existe, cárgalo en memoria en una data.table y haz que los nombres de los parámetros sean las claves para buscar en dicha data.table. Si no, simplemente crea el fichero de resultados únicamente con la cabecera de los nombres de los atributos.

if(file.exists(nf)) {
  # Usamos unique en caso de que por la razón que sea, haya duplicados. Damos por supuesto que como se ha fijado la semilla aleatoria, los duplicados siempre van a dar el mismo resultado y se pueden elegir arbitrariamente
  parametros = fread(nf) 
  setkeyv(parametros, claves)
  parametros = unique(parametros)
} else {
  catappend(sprintf("%s Train Validacion Time", paste0(claves, collapse=" ")), file=nf, header=TRUE)
}



# Creamos el iterador. product genera todas las posibles combinaciones de parámetros de los rangos que hemos puesto. Es similar a expand.grid

it <- ihasNext(splat(product)(rangos_parametros))

# Por fin comienza el bucle principal donde se exploran todas las posibles combinaciones de parámetros. Aquellas combinaciones que ya hayan sido evaluadas y que estan almacenadas en el data.table parametros, no se ejecutarán. Esto se consigue mágicamente comprobando que nrow(parametros[as.data.table(p)])==0)
foreach(p = it) %:% when(!exists("parametros") || nrow(parametros)==0 ||  nrow(parametros[as.data.table(p), nomatch=0])==0) %do% {
  entrenamiento = datos[folds != p$nFold,]
  validacion = datos[folds == p$nFold,]

  # Aquí definimos la función que lanza el algoritmo de aprendizaje
  modela = splat(CurryL(funcion, formula=formula, data=entrenamiento, n.trees=max(rango_n.trees)))

  set.seed(semilla)
  comienzo = Sys.time()
  # Usa todos los parámetros en p, menos el último, porque es el número de fold
  modelo = modela(head(p,-1))
  fin = Sys.time()
  # Opcional: salvamos cada modelo por si tenemos que usarlo en el futuro. Hay que tener cuidado porque en algunos casos (como gbm), R incluye mucha informacion en el modelo. En el caso de gbm, se incluyen todos los datos de entrenamiento, con lo que el fichero puede llegar a ser bastante grande.
  # saveRDS(modelo, file=paste0("GBMmodel-", semilla, "-", paste0(unlist(p), collapse="-")))
  
  # Una vez construido el modelo, con el máximo de número de árboles, lo probamos con los valores de número de árboles menores al máximo. 
  for(n.trees in rango_n.trees){
    salidasTrain = predict(modelo, newdata = entrenamiento, n.trees=n.trees)
    salidasValidacion = predict(modelo, newdata = validacion, n.trees=n.trees)

    # Salvamos resultados en el fichero
    catappend(sprintf("%s %d %f %f %f",
                      paste0(p, collapse = " "),
                      n.trees, 
                      mae(salidasTrain-entrenamiento[[1]]),
                      mae(salidasValidacion-validacion[[1]]),
                      as.numeric(fin-comienzo)
                      ),
        file=nf)
   }
  }

# Si cada combinación de parámetros y folds tiene menos ejecuciones que número de diferentes valores de árboles

parametros = fread(nf)
setkeyv(parametros, claves)
parametros = unique(parametros)

if(nrow(parametros)<length(rango_n.trees)*prod(sapply(rangos_parametros,length))) cat("Algunas combinaciones de parámetros han fallado. Se recomienda volver a ejecutar el código")

El bucle anterior ha probado todas las posibles combinaciones de parámetros de gbm. Caso de que alguna combinación hubiera fallado en su ejecución (o se fue la corriente), basta con volver a ejecutar el código, el cual sólo ejecutará aquellas combinaciones de parámetros que queden por realizar. Tenemos los resultados para cada uno de los 5 folds, como se puede ver a continuación:

head(parametros,20)
##     interaction.depth shrinkage n.minobsinnode nFold n.trees Train Validacion   Time
##  1:                 1     1e-04              2     1     100 17.15      14.81 0.4847
##  2:                 1     1e-04              2     1     500 16.85      14.58 0.4847
##  3:                 1     1e-04              2     1    1000 16.49      14.28 0.4847
##  4:                 1     1e-04              2     1    1500 16.17      14.00 0.4847
##  5:                 1     1e-04              2     1    2000 15.86      13.72 0.4847
##  6:                 1     1e-04              2     1    2500 15.57      13.46 0.4847
##  7:                 1     1e-04              2     1    3000 15.29      13.21 0.4847
##  8:                 1     1e-04              2     1    3500 15.03      12.97 0.4847
##  9:                 1     1e-04              2     1    4000 14.79      12.73 0.4847
## 10:                 1     1e-04              2     1    4500 14.56      12.52 0.4847
## 11:                 1     1e-04              2     1    5000 14.34      12.33 0.4847
## 12:                 1     1e-04              2     1    5500 14.12      12.14 0.4847
## 13:                 1     1e-04              2     1    6000 13.92      11.97 0.4847
## 14:                 1     1e-04              2     1    6500 13.73      11.79 0.4847
## 15:                 1     1e-04              2     1    7000 13.54      11.63 0.4847
## 16:                 1     1e-04              2     1    7500 13.36      11.47 0.4847
## 17:                 1     1e-04              2     1    8000 13.19      11.30 0.4847
## 18:                 1     1e-04              2     1    8500 13.04      11.16 0.4847
## 19:                 1     1e-04              2     1    9000 12.89      11.01 0.4847
## 20:                 1     1e-04              2     1    9500 12.75      10.87 0.4847

Ahora calcularemos las medias de Train, Validacion y Time para los 5 folds asi:

# Hacemos la media de validacion cruzada
parametros = parametros[, list(Train=mean(Train), Validacion=mean(Validacion), Time=mean(Time)), by=eval(claves[claves!="nFold"])]
head(parametros, 20)
##     interaction.depth shrinkage n.minobsinnode n.trees Train Validacion   Time
##  1:                 1     1e-04              2     100 16.54      16.57 0.5627
##  2:                 1     1e-04              2     500 16.25      16.32 0.5627
##  3:                 1     1e-04              2    1000 15.90      16.03 0.5627
##  4:                 1     1e-04              2    1500 15.58      15.74 0.5627
##  5:                 1     1e-04              2    2000 15.29      15.48 0.5627
##  6:                 1     1e-04              2    2500 15.00      15.24 0.5627
##  7:                 1     1e-04              2    3000 14.74      15.01 0.5627
##  8:                 1     1e-04              2    3500 14.48      14.79 0.5627
##  9:                 1     1e-04              2    4000 14.25      14.57 0.5627
## 10:                 1     1e-04              2    4500 14.03      14.37 0.5627
## 11:                 1     1e-04              2    5000 13.81      14.18 0.5627
## 12:                 1     1e-04              2    5500 13.61      14.00 0.5627
## 13:                 1     1e-04              2    6000 13.42      13.82 0.5627
## 14:                 1     1e-04              2    6500 13.23      13.65 0.5627
## 15:                 1     1e-04              2    7000 13.06      13.49 0.5627
## 16:                 1     1e-04              2    7500 12.89      13.35 0.5627
## 17:                 1     1e-04              2    8000 12.73      13.20 0.5627
## 18:                 1     1e-04              2    8500 12.58      13.07 0.5627
## 19:                 1     1e-04              2    9000 12.43      12.95 0.5627
## 20:                 1     1e-04              2    9500 12.29      12.83 0.5627

La mejor combinación de parámetros corresponde a:

(p = parametros[which.min(Validacion)])
##    interaction.depth shrinkage n.minobsinnode n.trees Train Validacion  Time
## 1:                 4      0.05             10     100 7.227      10.51 1.645
p = p[,1:(ncol(p)-3), with=FALSE] # Quitamos los valores de Train, Validación y Time que no son parámetros realmente

Ahora construimos el modelo con todos los datos y los mejores parámetros y lo probamos con los datos de test, que nos traemos de la web. El Mean Absolute Error con los datos de test se puede ver al final. No es muy distinto a lo que nos salía antes en validación.

load(url("http://www.dcc.fc.up.pt/~ltorgo/DataMiningWithR/DataSets/testAlgae.RData"))
load(url("http://www.dcc.fc.up.pt/~ltorgo/DataMiningWithR/DataSets/algaeSols.RData"))
datosTest = cbind(a1=algae.sols$a1, test.algae)

modela = splat(CurryL(funcion, formula=formula, data=datos))
modelo = modela(p)
## Distribution not specified, assuming gaussian ...
mae(predict(modelo, newdata=datosTest, n.trees=p$n.trees)-datosTest[,1])
## [1] 10.35