Esta entrada aprece en el Blog Punto de Acumulación
En esta entrada vamos a tratar otro tema habitual en el análisis de datos: el ajuste de los parámetros de los algoritmos de aprendizaje. Para los detalles teóricos, se puede consultar model training and tuning del paquete Caret.
(NOTA: En esta entrada se hace una búsqueda bastante exhaustiva en el espacio de parámetros y puede tardar bastante tiempo. Para una primera aproximación, se podría hacer una búsqueda con menos valores.)
Aquí iremos directamente al grano y ajustaremos cuatro de los parámetros de gbm: interaction.depth, n.trees, shrinkage y n.minobsinnode. Se va a realizar sobre un conjunto de datos de ejemplo, algae bloom o reproducción de algas, que nos descargaremos directamente de la web y que se pueden encontrar aqui.
La manera más habitual de ajustar estos parámetros es probar todas las combinaciones de valores de parámetros en lo que se llama una búsqueda en rejilla o grid search. Para cada combinacion de parámetros, se realiza una validacion cruzada. En este caso se han usado 5 folds, aunque si el tamaño de los datos es muy grande, puede ser conveniente usar sólo 2 folds, o directamente entrenamiento / validación. Podríamos realizar dicha búsqueda en rejilla mediante un cuadruple bucle anidado, pero en esta entrada usaremos algunos elementos más avanzados. En concreto:
En lugar de bucles usaremos iteradores, mediante las librerías iterators e itertools, que son elementos que permiten iterar a lo largo de una estructura, o ir generando nuevos valores según se vayan necesitando. En este caso, como la búsqueda en rejilla requiere probar todas las combinaciones de valores de parámetros, usaremos iteradores para ir generando cada combinación.
Para aprovechar máquinas multi-core, usaremos el paquete doMC junto con foreach, para que cada combinacion de parámetros se ejecute en un core distinto. Esto funciona en Linux, pero no en Windows. Para Windows basta con cambiar %dopar% por %do% y el código se ejecutará en un único core. foreach se encarga de ir recorriendo la lista de combinaciones de parámetros.
En mi experiencia, cuando se manejan conjuntos de datos grandes y se usa foreach y multi-core, algunas iteraciones fallan y no generan resultado, aunque la ejecucion continua. También puede ocurrir que se vaya la corriente en medio de una ejecucion, etc. Para evitar repetir la ejecucion del programa para todas las combinaciones de parámetros, el código va salvando en un fichero los resultados de cada iteración, mediante la funcion catappend. Si alguna iteracion falló, basta con volver a ejecutar el código y sólo se repetirán aquellas iteraciones (combinaciones de parámetros) que no se hayan realizado en la ejecucion anterior.
Además de almacenarlos en un fichero, los resultados de cada iteración se guardan en una data.table. Data.table es una estructura similar a los data.frame, pero de acceso más rápido, especialmente si se usan claves. Por ejemplo, para saber si una combinación de parámetros p se ha ejecutado ya, bastan con comprobar si nrow(parametros[as.data.table(p), nomatch=0])==0. En esta última expresión se está haciendo un join entre p y parametros.
Además de almacenar el error en entrenamiento y en test, y el tiempo que cuesta construir el modelo para cada combinación de parámetros y cada fold, tambien se almacena el modelo en sí mediante saveRDS. Esta instrucción permite guardar estructuras R en un fichero binario. Podemos volver a cargar el modelo y usarlo para hacer nuevas predicciones. De todas maneras, en general no necesitaremos estos modelos, por lo que se recomienda comentar la línea con saveRDS para ahorrar espacio en disco y tiempo de proceso.
Las particiones de validación cruzada se realizan mediante la funcion createFolds del paquete de aprendizaje automático Caret. De hecho, el propio Caret puede realizar ajuste de parámetros de manera automática, pero debido a que tiene ciertas limitaciones, en esta entrada vamos a programar el ajuste de parámetros desde cero.
Para que la ejecución sea reproducible, la variable semilla contiene la semilla aleatoria inicial, que se usa tanto al dividir el conjunto de datos en folds, como al ejecutar el algoritmo de aprendizaje automatico (en este caso gbm)
La función CurryL del paquete functional permite transformar una función en otra, tras fijar ciertos parámetros. Por ejemplo, la expresión model = CurryL(gbm, formula=a1~., data=entrenamiento) crea una nueva función model a partir de gbm, pero donde ya se han fijado los parámetros formula y data.
La función splat del paquete plyr permite transformar una función en otra que toma los argumentos de un data.frame. Por ejemplo, si gbm requiere los parámetros n.trees y shrinkage, y disponemos de un data.frame asi p=data.frame(n.trees=1000, shrinkage=0.1), podemos llamar a la función gbm de esta manera: splat(gbm)(p) en lugar de gbm(n.trees=1000, shrinkage=0.1). La función do.call permite llamar a una función cuyos argumentos se encuentran en una lista, en lugar de en un data.frame, pero no la hemos usado aqui.
El parámetro n.trees de gbm es especial en el sentido de que sólo merece la pena crear un modelo con el valor máximo de ese parámetro. Es decir, no merece la pena construir un modelo con 100 árboles, otro con 500, y otro con 1000, sino que es mejor construir el de 1000, y después pasarle los datos de validación usando los primeros 100 árboles, los primeros 500, y los primeros 1000. Esto es lo que hace el bucle interno que aparece en foreach. De esta manera ahorramos tiempo, porque construir un modelo es un proceso muy costoso en tiempo, mientras que probar un modelo con mas o menos árboles es bastante rápido. Por ello, nos podremos permitir el lujo de probar muchos valores de este parámetro.
Una aclaración sobre la nomenclatura. Los datos de entrenamiento son los que se usan para construir el modelo gbm. Los de validación son los que se usan para evaluar cada combinación de parámetros en cada punto de la rejilla. Dado que se usa validación cruzada, los datos de entrenamiento y validación cambian para cada fold. una vez que se ha seleccionado la mejor combinación de parámetros, se construirá un único modelo con todos los datos y se evaluará con los datos de test, que hemos guardado hasta el final para este propósito.
El código tiene algunas características avanzadas de R (uso de data.table, iteradores, programación funcional, etc), por lo que puede ser difícil de seguir, pero se puede aprender bastante intentando entenderlo.
A continuación se muestra el código:
options(width=500)
catappend = function(...,file="",header=FALSE) {
fe <- file.exists(file)
if(header && fe) warning(paste("catappend: File ", file," exists\n",immediate.=TRUE))
# if(header) cat(paste("#",Sys.time()), append=TRUE,fill=TRUE,file=file)
if(!header || (header && !fe)) cat(..., append=TRUE,fill=TRUE,file=file)
}
# Function that returns Root Mean Squared Error
rmse = function(error) sqrt(mean(error^2))
# Function that returns Mean Absolute Error
mae = function(error) mean(abs(error))
library(plyr) # para la función Splat
library(data.table) # alternativa a data.frame
library(iterators) # iteradores
library(itertools)
library("foreach")
library("functional") # programacion funcional
library(gbm) # gradient boosting
library(caret) # Mineria de datos
library("doMC") # multicore
registerDoMC(4) # Usa 4 cores
semilla = 1 # semilla aleatoria
rango_nFolds = 1:5 # Validación cruzada de 5 folds
# rango_n.trees = c(100, 1000, 5000, 10000)
# rangos_parametros = list(c(1, 2, 4, 6, 8, 10),
# c(0.001, 0.01, 0.05, 0.1, 0.5),
# c(2, 5, 10, 15, 20),
# rango_nFolds
# )
rango_n.trees = c(100, seq(500,15000, 500))
rangos_parametros = list(c(1, 2, 4, 6, 8, 10, 12, 14),
c(0.0001, 0.0005, 0.001, 0.01, 0.05, 0.1, 0.5),
c(2, 3, 4, 5, 10, 15, 20, 25, 30),
rango_nFolds
)
nombres_parametros = c("interaction.depth", "shrinkage", "n.minobsinnode", "nFold")
names(rangos_parametros) = nombres_parametros
claves = c(nombres_parametros, "n.trees")
# Cargamos los datos para hacer las validaciones cruzadas en cada punto de la rejilla. Al final cargaremos los de test para hacer el test final.
load(url("http://www.dcc.fc.up.pt/~ltorgo/DataMiningWithR/DataSets/algae.RData"))
# La variable de salida (a1) la ponemos como primera columna para tenerla fácilmente accesible. De esa manera, datos[,1] es la salida y algae_a1[,-1] son las entradas. Este problema tenia múltiples salidas (a1 hasta a7). Descartamos todas menos a1.
datos = subset(algae, select=c(a1, season:Chla))
funcion = gbm
formula = a1~.
# Creamos los folds de validacion cruzada con librería Caret
set.seed(semilla)
folds = createFolds(datos[,1], k = 5, list=FALSE)
# En este fichero iremos almacenando los resultados de entrenamiento y validación para cada combinación de parámetros y cada fold
nf = paste0("parametrosGBM-",semilla,".txt")
# Si el fichero en el que se guardan los resultados existe, cárgalo en memoria en una data.table y haz que los nombres de los parámetros sean las claves para buscar en dicha data.table. Si no, simplemente crea el fichero de resultados únicamente con la cabecera de los nombres de los atributos.
if(file.exists(nf)) {
# Usamos unique en caso de que por la razón que sea, haya duplicados. Damos por supuesto que como se ha fijado la semilla aleatoria, los duplicados siempre van a dar el mismo resultado y se pueden elegir arbitrariamente
parametros = fread(nf)
setkeyv(parametros, claves)
parametros = unique(parametros)
} else {
catappend(sprintf("%s Train Validacion Time", paste0(claves, collapse=" ")), file=nf, header=TRUE)
}
# Creamos el iterador. product genera todas las posibles combinaciones de parámetros de los rangos que hemos puesto. Es similar a expand.grid
it <- ihasNext(splat(product)(rangos_parametros))
# Por fin comienza el bucle principal donde se exploran todas las posibles combinaciones de parámetros. Aquellas combinaciones que ya hayan sido evaluadas y que estan almacenadas en el data.table parametros, no se ejecutarán. Esto se consigue mágicamente comprobando que nrow(parametros[as.data.table(p)])==0)
foreach(p = it) %:% when(!exists("parametros") || nrow(parametros)==0 || nrow(parametros[as.data.table(p), nomatch=0])==0) %do% {
entrenamiento = datos[folds != p$nFold,]
validacion = datos[folds == p$nFold,]
# Aquí definimos la función que lanza el algoritmo de aprendizaje
modela = splat(CurryL(funcion, formula=formula, data=entrenamiento, n.trees=max(rango_n.trees)))
set.seed(semilla)
comienzo = Sys.time()
# Usa todos los parámetros en p, menos el último, porque es el número de fold
modelo = modela(head(p,-1))
fin = Sys.time()
# Opcional: salvamos cada modelo por si tenemos que usarlo en el futuro. Hay que tener cuidado porque en algunos casos (como gbm), R incluye mucha informacion en el modelo. En el caso de gbm, se incluyen todos los datos de entrenamiento, con lo que el fichero puede llegar a ser bastante grande.
# saveRDS(modelo, file=paste0("GBMmodel-", semilla, "-", paste0(unlist(p), collapse="-")))
# Una vez construido el modelo, con el máximo de número de árboles, lo probamos con los valores de número de árboles menores al máximo.
for(n.trees in rango_n.trees){
salidasTrain = predict(modelo, newdata = entrenamiento, n.trees=n.trees)
salidasValidacion = predict(modelo, newdata = validacion, n.trees=n.trees)
# Salvamos resultados en el fichero
catappend(sprintf("%s %d %f %f %f",
paste0(p, collapse = " "),
n.trees,
mae(salidasTrain-entrenamiento[[1]]),
mae(salidasValidacion-validacion[[1]]),
as.numeric(fin-comienzo)
),
file=nf)
}
}
# Si cada combinación de parámetros y folds tiene menos ejecuciones que número de diferentes valores de árboles
parametros = fread(nf)
setkeyv(parametros, claves)
parametros = unique(parametros)
if(nrow(parametros)<length(rango_n.trees)*prod(sapply(rangos_parametros,length))) cat("Algunas combinaciones de parámetros han fallado. Se recomienda volver a ejecutar el código")
El bucle anterior ha probado todas las posibles combinaciones de parámetros de gbm. Caso de que alguna combinación hubiera fallado en su ejecución (o se fue la corriente), basta con volver a ejecutar el código, el cual sólo ejecutará aquellas combinaciones de parámetros que queden por realizar. Tenemos los resultados para cada uno de los 5 folds, como se puede ver a continuación:
head(parametros,20)
## interaction.depth shrinkage n.minobsinnode nFold n.trees Train Validacion Time
## 1: 1 1e-04 2 1 100 17.15 14.81 0.4847
## 2: 1 1e-04 2 1 500 16.85 14.58 0.4847
## 3: 1 1e-04 2 1 1000 16.49 14.28 0.4847
## 4: 1 1e-04 2 1 1500 16.17 14.00 0.4847
## 5: 1 1e-04 2 1 2000 15.86 13.72 0.4847
## 6: 1 1e-04 2 1 2500 15.57 13.46 0.4847
## 7: 1 1e-04 2 1 3000 15.29 13.21 0.4847
## 8: 1 1e-04 2 1 3500 15.03 12.97 0.4847
## 9: 1 1e-04 2 1 4000 14.79 12.73 0.4847
## 10: 1 1e-04 2 1 4500 14.56 12.52 0.4847
## 11: 1 1e-04 2 1 5000 14.34 12.33 0.4847
## 12: 1 1e-04 2 1 5500 14.12 12.14 0.4847
## 13: 1 1e-04 2 1 6000 13.92 11.97 0.4847
## 14: 1 1e-04 2 1 6500 13.73 11.79 0.4847
## 15: 1 1e-04 2 1 7000 13.54 11.63 0.4847
## 16: 1 1e-04 2 1 7500 13.36 11.47 0.4847
## 17: 1 1e-04 2 1 8000 13.19 11.30 0.4847
## 18: 1 1e-04 2 1 8500 13.04 11.16 0.4847
## 19: 1 1e-04 2 1 9000 12.89 11.01 0.4847
## 20: 1 1e-04 2 1 9500 12.75 10.87 0.4847
Ahora calcularemos las medias de Train, Validacion y Time para los 5 folds asi:
# Hacemos la media de validacion cruzada
parametros = parametros[, list(Train=mean(Train), Validacion=mean(Validacion), Time=mean(Time)), by=eval(claves[claves!="nFold"])]
head(parametros, 20)
## interaction.depth shrinkage n.minobsinnode n.trees Train Validacion Time
## 1: 1 1e-04 2 100 16.54 16.57 0.5627
## 2: 1 1e-04 2 500 16.25 16.32 0.5627
## 3: 1 1e-04 2 1000 15.90 16.03 0.5627
## 4: 1 1e-04 2 1500 15.58 15.74 0.5627
## 5: 1 1e-04 2 2000 15.29 15.48 0.5627
## 6: 1 1e-04 2 2500 15.00 15.24 0.5627
## 7: 1 1e-04 2 3000 14.74 15.01 0.5627
## 8: 1 1e-04 2 3500 14.48 14.79 0.5627
## 9: 1 1e-04 2 4000 14.25 14.57 0.5627
## 10: 1 1e-04 2 4500 14.03 14.37 0.5627
## 11: 1 1e-04 2 5000 13.81 14.18 0.5627
## 12: 1 1e-04 2 5500 13.61 14.00 0.5627
## 13: 1 1e-04 2 6000 13.42 13.82 0.5627
## 14: 1 1e-04 2 6500 13.23 13.65 0.5627
## 15: 1 1e-04 2 7000 13.06 13.49 0.5627
## 16: 1 1e-04 2 7500 12.89 13.35 0.5627
## 17: 1 1e-04 2 8000 12.73 13.20 0.5627
## 18: 1 1e-04 2 8500 12.58 13.07 0.5627
## 19: 1 1e-04 2 9000 12.43 12.95 0.5627
## 20: 1 1e-04 2 9500 12.29 12.83 0.5627
La mejor combinación de parámetros corresponde a:
(p = parametros[which.min(Validacion)])
## interaction.depth shrinkage n.minobsinnode n.trees Train Validacion Time
## 1: 4 0.05 10 100 7.227 10.51 1.645
p = p[,1:(ncol(p)-3), with=FALSE] # Quitamos los valores de Train, Validación y Time que no son parámetros realmente
Ahora construimos el modelo con todos los datos y los mejores parámetros y lo probamos con los datos de test, que nos traemos de la web. El Mean Absolute Error con los datos de test se puede ver al final. No es muy distinto a lo que nos salía antes en validación.
load(url("http://www.dcc.fc.up.pt/~ltorgo/DataMiningWithR/DataSets/testAlgae.RData"))
load(url("http://www.dcc.fc.up.pt/~ltorgo/DataMiningWithR/DataSets/algaeSols.RData"))
datosTest = cbind(a1=algae.sols$a1, test.algae)
modela = splat(CurryL(funcion, formula=formula, data=datos))
modelo = modela(p)
## Distribution not specified, assuming gaussian ...
mae(predict(modelo, newdata=datosTest, n.trees=p$n.trees)-datosTest[,1])
## [1] 10.35