En este documento se ejemplifica la realización de árboles de regresión y clasificación haciendo uso de la librería tree. Los ejemplos aquí mostrados no son de nuestra autoría, sino que están basados en análisis hechos por Joaquín Amat Rodrigo y João Neto. Puede recurrir a los documentos de ambos autores a partir de los siguientes enlaces:

Árboles de regresión

Instalación y carga de paquetes

En primer lugar, se realiza la instalación y/o carga de paquetes en RStudio, tanto para análisis de los datos como para la aplicación de árboles de regresión:

# Librerías para análisis y tratamiento de los datos

#install.packages("skimr")

library(MASS)
library(dplyr)
library(tidyr)
library(skimr)


# Librerías para gráficos de datos

library(ggplot2)
library(ggpubr)


# Librería para implementar árboles de regresión y clasificación

#install.packages("tree")

library(tree)

Carga y observación de la base de datos

El conjunto de datos a estudiar se denomina “Boston” y contiene información de precios de viviendas en la ciudad estadounidense de la misma denominación; además, presenta información socioeconómica según el suburbio en el que se encuentre la vivienda. El archivo de datos se encuentra en el paquete MASS.

data("Boston") # Carga del dataset
head(Boston) # Muestra de algunos datos
##      crim zn indus chas   nox    rm  age    dis rad tax ptratio  black lstat
## 1 0.00632 18  2.31    0 0.538 6.575 65.2 4.0900   1 296    15.3 396.90  4.98
## 2 0.02731  0  7.07    0 0.469 6.421 78.9 4.9671   2 242    17.8 396.90  9.14
## 3 0.02729  0  7.07    0 0.469 7.185 61.1 4.9671   2 242    17.8 392.83  4.03
## 4 0.03237  0  2.18    0 0.458 6.998 45.8 6.0622   3 222    18.7 394.63  2.94
## 5 0.06905  0  2.18    0 0.458 7.147 54.2 6.0622   3 222    18.7 396.90  5.33
## 6 0.02985  0  2.18    0 0.458 6.430 58.7 6.0622   3 222    18.7 394.12  5.21
##   medv
## 1 24.0
## 2 21.6
## 3 34.7
## 4 33.4
## 5 36.2
## 6 28.7

El objetivo en este caso es ajustar un modelo de regresión basado en árboles que permita predecir el precio medio de una vivienda (variable medv) en dicha ciudad basándonos en las características disponibles.

En este sentido es de utilidad realizar en primera instancia un resumen descriptivo del archivo de datos:

skim(Boston) # Resumen del archivo de datos
Data summary
Name Boston
Number of rows 506
Number of columns 14
_______________________
Column type frequency:
numeric 14
________________________
Group variables None

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
crim 0 1 3.61 8.60 0.01 0.08 0.26 3.68 88.98 ▇▁▁▁▁
zn 0 1 11.36 23.32 0.00 0.00 0.00 12.50 100.00 ▇▁▁▁▁
indus 0 1 11.14 6.86 0.46 5.19 9.69 18.10 27.74 ▇▆▁▇▁
chas 0 1 0.07 0.25 0.00 0.00 0.00 0.00 1.00 ▇▁▁▁▁
nox 0 1 0.55 0.12 0.38 0.45 0.54 0.62 0.87 ▇▇▆▅▁
rm 0 1 6.28 0.70 3.56 5.89 6.21 6.62 8.78 ▁▂▇▂▁
age 0 1 68.57 28.15 2.90 45.02 77.50 94.07 100.00 ▂▂▂▃▇
dis 0 1 3.80 2.11 1.13 2.10 3.21 5.19 12.13 ▇▅▂▁▁
rad 0 1 9.55 8.71 1.00 4.00 5.00 24.00 24.00 ▇▂▁▁▃
tax 0 1 408.24 168.54 187.00 279.00 330.00 666.00 711.00 ▇▇▃▁▇
ptratio 0 1 18.46 2.16 12.60 17.40 19.05 20.20 22.00 ▁▃▅▅▇
black 0 1 356.67 91.29 0.32 375.38 391.44 396.22 396.90 ▁▁▁▁▇
lstat 0 1 12.65 7.14 1.73 6.95 11.36 16.96 37.97 ▇▇▅▂▁
medv 0 1 22.53 9.20 5.00 17.02 21.20 25.00 50.00 ▂▇▅▁▁

Ajuste del modelo de regresión basado en árboles de decisión

Para el ajuste del modelo se hará uso de la función tree disponible en el paquete homónimo. Por defecto, la función ajusta un modelo de regresión o de clasificación basado en el tipo de variable de respuesta, ya sea cuantitativa o cualitativa (numeric o factor, de acuerdo con el lenguaje propio de R). La variable medv es cuantitativa, por lo que el modelo ajustado será de regresión.

Sin embargo, es necesario realizar de manera previa una división del total de datos en datos de entrenamiento y datos de prueba:

set.seed(500) # Establecimiento de la semilla
entrenamiento <- sample(1:nrow(Boston), size = nrow(Boston)/2) # Muestra de datos para entrenamiento
datos_entrenamiento <- Boston[entrenamiento,] # Datos de entrenamiento
datos_prueba <- Boston[-entrenamiento,] # Datos de prueba

Ya con los datos de entrenamiento es posible entonces crear el modelo de regresión:

set.seed(500) # Establecimiento de la semilla
arbol_regresion <- tree::tree(
        formula = medv ~ ., # Variables de respuesta y predictoras
        data = datos_entrenamiento, # Datos a utilizar para la creación del modelo
        split = "deviance", # Criterio de división
        mincut = 20, # Número mínimo de observaciones para que se produzca la división
        minsize = 50 # Número mínimo de observaciones para que un nodo pueda ramificarse
)

summary(arbol_regresion) # Resumen del árbol de regresión
## 
## Regression tree:
## tree::tree(formula = medv ~ ., data = datos_entrenamiento, split = "deviance", 
##     mincut = 20, minsize = 50)
## Variables actually used in tree construction:
## [1] "rm"    "lstat" "crim" 
## Number of terminal nodes:  6 
## Residual mean deviance:  18.46 = 4559 / 247 
## Distribution of residuals:
##      Min.   1st Qu.    Median      Mean   3rd Qu.      Max. 
## -20.82000  -1.97200   0.08049   0.00000   2.22800  23.83000

El resumen anterior nos permite identificar algunos resultados de interés: las variables utilizadas para la construcción del árbol son rm (número promedio de habitaciones por vivienda), lstat (menor estatus porcentual de la población) y crim (tasa de criminalidad per cápita); hay un total de seis nodos terminales u hojas del árbol y la desvianza residual media (definida como el cociente entre la SSR y el número de observaciones menos la cantidad de nodos terminales) es relativamente baja, lo que indica que el ajuste del árbol a los datos de entrenamiento es bueno.

Ahora bien, la visualización del árbol de decisión establecido mediante la función tree se programa de la siguiente manera:

par(mar = c(1,1,1,1)) # Márgenes del gráfico
plot(x = arbol_regresion, type = "proportional") # Árbol de decisión con ramas proporcionales a la impureza (heterogeneidad) de las hojas
text(x = arbol_regresion, splits = TRUE, pretty = 0, cex = 0.8, col = "firebrick") # Etiquetas del árbol

Poda por coste-complejidad del árbol

Con el objetivo de obtener predicciones significativamente buenas, contribuyendo a su vez a la disminución de la variabilidad del modelo y evitando un sobreajuste del modelo, es necesario que el árbol de decisión pase por un proceso de podado, en este caso el de coste-complejidad. Para esto, se aplica primeramente el proceso de validación cruzada sobre el árbol de tamaño completo:

# Se hace crecer el árbol tanto como sea posible

arbol_regresion <- tree( 
                    formula = medv ~ .,
                    data    = datos_entrenamiento,
                    split   = "deviance",
                    mincut  = 1,
                    minsize = 2,
                    mindev  = 0
                  )

set.seed(500) # Establecimiendo de la semilla
cv_arbol <- cv.tree(arbol_regresion, K = 5) # Aplicación del proceso de validación cruzada

cv_arbol # Muestra de resultados por validación cruzada
## $size
##   [1] 217 214 211 205 203 202 194 192 190 186 185 183 179 177 175 174 170 169
##  [19] 167 164 161 160 158 155 154 153 151 150 149 147 146 144 143 142 141 140
##  [37] 139 138 136 134 133 132 131 130 129 128 127 126 125 124 123 122 120 119
##  [55] 115 114 113 112 111 110 109 108 107 106 105 104 103 102 101 100  99  97
##  [73]  95  94  93  92  90  89  88  86  85  84  83  82  81  79  78  77  76  75
##  [91]  74  73  72  71  70  69  68  67  66  65  64  62  61  60  59  58  57  56
## [109]  55  54  53  52  51  50  48  47  46  45  44  43  41  40  39  38  37  36
## [127]  35  34  33  32  31  29  28  27  25  24  23  21  20  19  18  17  16  15
## [145]  14  13  12  11  10   9   8   7   6   5   4   3   2   1
## 
## $dev
##   [1]  6158.889  6158.889  6158.889  6158.889  6158.889  6158.889  6158.889
##   [8]  6158.889  6158.889  6158.889  6158.889  6158.889  6158.889  6158.889
##  [15]  6158.889  6158.889  6158.889  6158.889  6158.889  6158.889  6158.889
##  [22]  6158.889  6158.889  6158.889  6158.889  6158.889  6158.889  6158.889
##  [29]  6158.889  6158.889  6158.889  6158.889  6158.889  6158.889  6158.889
##  [36]  6158.889  6158.889  6158.889  6158.889  6158.889  6158.889  6158.889
##  [43]  6158.889  6158.889  6158.889  6158.889  6158.889  6158.889  6158.889
##  [50]  6158.889  6158.889  6158.889  6158.889  6158.889  6158.889  6158.889
##  [57]  6158.889  6158.889  6158.889  6158.889  6158.889  6158.889  6158.889
##  [64]  6158.889  6158.889  6158.889  6158.889  6158.889  6158.889  6158.889
##  [71]  6158.889  6158.889  6158.889  6158.889  6158.889  6158.889  6158.889
##  [78]  6158.889  6158.889  6158.889  6158.889  6158.889  6158.889  6158.889
##  [85]  6158.889  6158.889  6158.889  6158.889  6158.889  6158.889  6158.889
##  [92]  6158.889  6158.889  6158.889  6158.889  6158.889  6158.889  6158.889
##  [99]  6158.889  6158.889  6158.889  6158.889  6158.889  6158.889  6158.889
## [106]  6158.889  6158.889  6158.889  6158.889  6158.889  6158.889  6158.889
## [113]  6158.889  6158.889  6158.889  6158.889  6158.889  6158.889  6158.889
## [120]  6158.889  6158.889  6158.889  6158.889  6158.889  6158.889  6158.889
## [127]  6158.889  6158.889  6158.889  6158.889  6158.889  6158.889  6158.889
## [134]  6158.889  6158.889  6158.889  6158.889  6158.889  6158.889  6158.889
## [141]  6158.889  6158.889  6158.889  6158.889  6158.889  6158.889  6158.889
## [148]  6158.889  6158.889  6158.889  7870.794  8020.069  8620.205  8620.205
## [155]  8620.205 11196.544 12628.095 21475.059
## 
## $k
##   [1]         -Inf 2.666667e-02 4.166667e-02 4.500000e-02 6.000000e-02
##   [6] 6.750000e-02 8.000000e-02 8.166667e-02 1.066667e-01 1.250000e-01
##  [11] 1.350000e-01 1.666667e-01 1.800000e-01 2.016667e-01 2.400000e-01
##  [16] 2.408333e-01 2.450000e-01 2.700000e-01 2.816667e-01 3.200000e-01
##  [21] 3.266667e-01 3.333333e-01 3.750000e-01 4.050000e-01 4.408333e-01
##  [26] 4.816667e-01 5.000000e-01 5.633333e-01 6.050000e-01 6.533333e-01
##  [31] 6.666667e-01 7.200000e-01 7.350000e-01 7.605000e-01 7.810714e-01
##  [36] 8.066667e-01 8.100000e-01 8.533333e-01 9.633333e-01 9.800000e-01
##  [41] 1.000000e+00 1.037143e+00 1.041667e+00 1.080000e+00 1.125000e+00
##  [46] 1.140833e+00 1.142857e+00 1.203333e+00 1.210000e+00 1.215000e+00
##  [51] 1.333333e+00 1.400833e+00 1.408333e+00 1.440000e+00 1.445000e+00
##  [56] 1.466786e+00 1.541333e+00 1.562500e+00 1.822500e+00 1.860500e+00
##  [61] 1.922000e+00 2.000000e+00 2.205000e+00 2.253333e+00 2.312000e+00
##  [66] 2.666667e+00 2.749383e+00 2.821333e+00 2.880000e+00 3.000000e+00
##  [71] 3.226667e+00 3.313500e+00 3.341250e+00 3.526667e+00 3.610000e+00
##  [76] 3.698000e+00 3.745333e+00 3.920000e+00 3.967500e+00 4.041667e+00
##  [81] 4.573333e+00 4.702222e+00 4.805000e+00 4.836571e+00 5.760000e+00
##  [86] 5.772083e+00 5.801667e+00 5.832000e+00 6.002500e+00 6.125000e+00
##  [91] 6.503712e+00 6.600833e+00 6.816333e+00 7.105333e+00 7.290000e+00
##  [96] 7.475556e+00 7.525051e+00 7.800238e+00 7.935000e+00 8.171905e+00
## [101] 9.375000e+00 1.036923e+01 1.056250e+01 1.140050e+01 1.148167e+01
## [106] 1.162076e+01 1.276900e+01 1.349944e+01 1.352000e+01 1.400022e+01
## [111] 1.405727e+01 1.410667e+01 1.430519e+01 1.431125e+01 1.507250e+01
## [116] 1.837500e+01 1.867778e+01 2.090006e+01 2.145606e+01 2.153470e+01
## [121] 2.172524e+01 2.204167e+01 2.236810e+01 2.281500e+01 2.298640e+01
## [126] 2.320714e+01 2.398050e+01 2.643431e+01 2.670083e+01 2.738000e+01
## [131] 3.118674e+01 3.383274e+01 3.629186e+01 3.691778e+01 4.156365e+01
## [136] 4.437600e+01 5.031627e+01 5.076055e+01 5.680278e+01 6.106133e+01
## [141] 6.765019e+01 8.433514e+01 9.035627e+01 9.261000e+01 9.506913e+01
## [146] 1.175078e+02 1.298996e+02 1.421652e+02 1.539922e+02 1.592155e+02
## [151] 4.314780e+02 5.104021e+02 6.195786e+02 7.192168e+02 7.664787e+02
## [156] 1.958970e+03 2.575767e+03 1.124657e+04
## 
## $method
## [1] "deviance"
## 
## attr(,"class")
## [1] "prune"         "tree.sequence"

Así, puede observarse entonces el número de nodos terminales por árbol (size), el error por validación cruzada de cada árbol (dev), el rango de valores de penalización y el método empleado para la construcción del árbol.

size_optimo <- rev(cv_arbol$size)[which.min(rev(cv_arbol$dev))] # Búsqueda del tamaño óptimo
paste("Tamaño óptimo del número de nodos terminales:", size_optimo)
## [1] "Tamaño óptimo del número de nodos terminales: 9"
# Dataset con los valores obtenidos por validación cruzada

resultados_cv <- data.frame(
                   n_nodos  = cv_arbol$size,
                   deviance = cv_arbol$dev,
                   alpha    = cv_arbol$k
                 )


# Gráfico del decremento del error en relación con el número de nodos terminales

p1 <- ggplot(data = resultados_cv, aes(x = n_nodos, y = deviance)) +
      geom_line() + 
      geom_point() +
      geom_vline(xintercept = size_optimo, color = "red") +
      labs(title = "Error vs tamaño del árbol") +
      theme_bw() 
  

# Gráfico que contrasta los valores del parámetro de ajuste y el error obtenido con ellos

p2 <- ggplot(data = resultados_cv, aes(x = alpha, y = deviance)) +
      geom_line() + 
      geom_point() +
      labs(title = "Error vs penalización alpha") +
      theme_bw() 

ggarrange(p1, p2)

Con esto en mente, se procede entonces a graficar el mejor árbol de decisión encontrado:

# Poda del árbol de decisión por coste-complejidad

arbol_final <- prune.tree(
                  tree = arbol_regresion, # Árbol de tamaño completo
                  best = size_optimo # Número de nodos terminales óptimo
               )


# Gráfico del árbol de decisión posterior al proceso de podado

par(mar = c(1,1,1,1))
plot(x = arbol_final, type = "proportional")
text(x = arbol_final, splits = TRUE, pretty = 0, cex = 0.8, col = "firebrick")

Predicción y evaluación del modelo

Finalmente, se realiza la evaluación de la capacidad predictiva del modelo haciendo uso de los datos de prueba:

# Predicciones y error de predicción del modelo de tamaño completo

predicciones <- predict(arbol_regresion, newdata = datos_prueba)
test_rmse    <- sqrt(mean((predicciones - datos_prueba$medv)^2))
paste("Error de test del árbol inicial:", round(test_rmse,2))
## [1] "Error de test del árbol inicial: 4.98"
# Predicciones y error de predicción del modelo de tamaño completo

predicciones <- predict(arbol_final, newdata = datos_prueba)
test_rmse    <- sqrt(mean((predicciones - datos_prueba$medv)^2))
paste("Error de test del árbol final:", round(test_rmse,2))
## [1] "Error de test del árbol final: 4.92"

Se puede ver entonces que el modelo final presenta un menor error, en comparación con el modelo en su totalidad, lo que quiere decir que las predicciones serán algo mejores con el árbol podado, alejándose en promedio 4.92 unidades (4920 dólares) del valor original.

Árboles de clasificación

Para este ejemplo se trabajó con la base de datos Iris disponible en la librería datasets. La base se encuentra compuesta por 150 observaciones de flores de la planta iris, la variables o atributos que se miden de cada flor son:

Para llegar a cabo la clasificación de árbol se procederá entonces a predecir el tipo de flor a partir de las variables sepal.width y petal.width. Entonces, lo primero que se debe realizar es la elección de los datos de entrenamiento, para este caso al considerar la validación cruzada tomaremos un \(\alpha\) de 0.7, luego extraemos una muestra proporcional a \(\alpha\) del conjunto de datos Iris y los datos restantes son los de prueba.

set.seed(101)
alpha<- 0.7 #Porcentaje del conjunto de entrenamiento
inTrain<- sample(1:nrow(iris), alpha * nrow(iris)) #Muestra de los datos (ent.)
train.set <- iris[inTrain,] #Extracción de los datos ent. de la base Iris.
test.set  <- iris[-inTrain,] #Extracción de los datos de prueba.


#Clasificación de árbol
tree.model <- tree(Species ~ Sepal.Width + Petal.Width, data=train.set)
tree.model
## node), split, n, deviance, yval, (yprob)
##       * denotes terminal node
## 
##  1) root 105 230.50 virginica ( 0.31429 0.33333 0.35238 )  
##    2) Petal.Width < 0.8 33   0.00 setosa ( 1.00000 0.00000 0.00000 ) *
##    3) Petal.Width > 0.8 72  99.76 virginica ( 0.00000 0.48611 0.51389 )  
##      6) Petal.Width < 1.7 38  20.99 versicolor ( 0.00000 0.92105 0.07895 )  
##       12) Petal.Width < 1.35 21   0.00 versicolor ( 0.00000 1.00000 0.00000 ) *
##       13) Petal.Width > 1.35 17  15.84 versicolor ( 0.00000 0.82353 0.17647 ) *
##      7) Petal.Width > 1.7 34   0.00 virginica ( 0.00000 0.00000 1.00000 ) *
summary(tree.model)
## 
## Classification tree:
## tree(formula = Species ~ Sepal.Width + Petal.Width, data = train.set)
## Variables actually used in tree construction:
## [1] "Petal.Width"
## Number of terminal nodes:  4 
## Residual mean deviance:  0.1569 = 15.84 / 101 
## Misclassification error rate: 0.02857 = 3 / 105

Lo siguiente a mostrar es las probabilidades con las cuales es probable que una observación se clasifique en cierta categoría.

#Predicciones de los datos de prueba a partir del algoritmo
prediccion.obs <- predict(tree.model, test.set) # probabilidad por cada clase
head(prediccion.obs) #muestra las primeras 6 probabilidades
##    setosa versicolor virginica
## 1       1          0         0
## 5       1          0         0
## 11      1          0         0
## 13      1          0         0
## 16      1          0         0
## 23      1          0         0

Por ejemplo, para la flor 5 con probabilidad 1 se espera que se clasifique como setosa. Del mismo modo podríamos observar para los otros datos de prueba.

Lo siguiente que se realizó fue una tabla para resumir de los datos de prueba a qué tipo de planta fueron clasificados a partir del algoritmo.

#Identificación de qué tan bien se realizó la clasificación con el método
maxim <- function(arg) { #función para dentificar la máxima probabilidad
    return(which(arg == max(arg)))
}
idx <- apply(prediccion.obs, c(1), maxim) #Identificar la especie a la que
prediccion <- c('setosa', 'versicolor', 'virginica')[idx] #pertenece
table(prediccion, test.set$Species)
##             
## prediccion   setosa versicolor virginica
##   setosa         17          0         0
##   versicolor      0         13         1
##   virginica       0          2        12

Se procede entonces a ilustrar el árbol de clasificación:

#Gráfico de la clasificación de árbol
plot(tree.model)
text(tree.model)

Otra forma en que se puede ilustrar al conjunto de datos es:

#Otra forma de graficar
library(ggplot2)
library(tree)
ggplot(iris,
       aes(Petal.Width, Sepal.Width, color=Species)) +
  geom_point() + theme_bw()+
  gg.partition.tree(tree(Species ~ Sepal.Width + Petal.Width, data=iris),
                    label="Species", color = "black")

Es posible entonces podar el árbol para evitar el sobreajuste. La función prune.tree() permite elegir cuantas hojas queremos que tenga el árbol y devuelve el mejor árbol con el tamaño necesario.

#Podación del árbol.
poda.arbol <- prune.tree(tree.model, best=4)
plot(poda.arbol)
text(poda.arbol)