Árboles

Árboles de Decisión

Ajustando Árboles de Clasificación

La librería tree se usa para construir árboles de clasificación y regresión.

library(tree)
library(ISLR2)
library(tidyverse)
library(patchwork)
library(caret)
data(Carseats)

Primero usamos árboles de clasificación para analizar el conjunto de datos Carseats. Buscamos ajustar un árbol de clasificación para predecir ventas altas usando todas las variables (excepto Sales). Para ello, definimos la columna High si las ventas son mayores a 8 (miles):

Carseats <- Carseats |> 
  mutate(High = factor(ifelse(Sales <= 8, "No", "Yes")))

Ahora, usemos la función tree() cuya sintaxis es bastante similar a la de la función lm().

tree.carseats <- tree(High ~ . - Sales, Carseats)

La función summary() lista las variables que se usan como nodos internos en el árbol, el número de nodos terminales y la tasa de error (de entrenamiento).

summary(tree.carseats)

Classification tree:
tree(formula = High ~ . - Sales, data = Carseats)
Variables actually used in tree construction:
[1] "ShelveLoc"   "Price"       "Income"      "CompPrice"   "Population" 
[6] "Advertising" "Age"         "US"         
Number of terminal nodes:  27 
Residual mean deviance:  0.4575 = 170.7 / 373 
Misclassification error rate: 0.09 = 36 / 400 

Vemos que la tasa de error de entrenamiento es del \(9\%\). Para árboles de clasificación, la desviación (deviance) reportada en la salida de summary() está dada por \[ -2 \sum*m* \sum*k n*{mk} \log \hat{p}{mk},\] donde \(n_{mk}\) es el número de observaciones en el \(m\)-ésimo nodo terminal que pertenecen a la \(k\)-ésima clase. Esto está cercanamente relacionado con la entropía. Una desviación pequeña indica un árbol que proporciona un buen ajuste a los datos (de entrenamiento). La desviación media residual reportada es simplemente la desviación dividida por \(n-|{T}_0|\), que en este caso es \(400-27=373\).

Una de las propiedades más atractivas de los árboles es que pueden ser mostrados gráficamente. Usamos la función plot() para mostrar la estructura del árbol, y la función text() para mostrar las etiquetas de los nodos. El argumento pretty = 0 instruye a R para que incluya los nombres de las categorías para cualquier predictor cualitativo, en lugar de simplemente mostrar una letra para cada categoría.

plot(tree.carseats)
text(tree.carseats, pretty = 0)

El indicador más importante de Sales parece ser la ubicación en estantería (shelving location), ya que la primera rama diferencia las ubicaciones Good (derecha) de las Bad y Medium (izquierda).

Si solo escribimos el nombre del objeto del árbol, R imprime la salida correspondiente a cada rama del árbol. R muestra el criterio de división (p.ej. Price < 92.5), el número de observaciones en esa rama, la desviación, la predicción general para la rama (Yes o No), y la fracción de observaciones en esa rama que toman los valores Yes y No. Las ramas que llevan a nodos terminales se indican con asteriscos.

tree.carseats
node), split, n, deviance, yval, (yprob)
      * denotes terminal node

  1) root 400 541.500 No ( 0.59000 0.41000 )  
    2) ShelveLoc: Bad,Medium 315 390.600 No ( 0.68889 0.31111 )  
      4) Price < 92.5 46  56.530 Yes ( 0.30435 0.69565 )  
        8) Income < 57 10  12.220 No ( 0.70000 0.30000 )  
         16) CompPrice < 110.5 5   0.000 No ( 1.00000 0.00000 ) *
         17) CompPrice > 110.5 5   6.730 Yes ( 0.40000 0.60000 ) *
        9) Income > 57 36  35.470 Yes ( 0.19444 0.80556 )  
         18) Population < 207.5 16  21.170 Yes ( 0.37500 0.62500 ) *
         19) Population > 207.5 20   7.941 Yes ( 0.05000 0.95000 ) *
      5) Price > 92.5 269 299.800 No ( 0.75465 0.24535 )  
       10) Advertising < 13.5 224 213.200 No ( 0.81696 0.18304 )  
         20) CompPrice < 124.5 96  44.890 No ( 0.93750 0.06250 )  
           40) Price < 106.5 38  33.150 No ( 0.84211 0.15789 )  
             80) Population < 177 12  16.300 No ( 0.58333 0.41667 )  
              160) Income < 60.5 6   0.000 No ( 1.00000 0.00000 ) *
              161) Income > 60.5 6   5.407 Yes ( 0.16667 0.83333 ) *
             81) Population > 177 26   8.477 No ( 0.96154 0.03846 ) *
           41) Price > 106.5 58   0.000 No ( 1.00000 0.00000 ) *
         21) CompPrice > 124.5 128 150.200 No ( 0.72656 0.27344 )  
           42) Price < 122.5 51  70.680 Yes ( 0.49020 0.50980 )  
             84) ShelveLoc: Bad 11   6.702 No ( 0.90909 0.09091 ) *
             85) ShelveLoc: Medium 40  52.930 Yes ( 0.37500 0.62500 )  
              170) Price < 109.5 16   7.481 Yes ( 0.06250 0.93750 ) *
              171) Price > 109.5 24  32.600 No ( 0.58333 0.41667 )  
                342) Age < 49.5 13  16.050 Yes ( 0.30769 0.69231 ) *
                343) Age > 49.5 11   6.702 No ( 0.90909 0.09091 ) *
           43) Price > 122.5 77  55.540 No ( 0.88312 0.11688 )  
             86) CompPrice < 147.5 58  17.400 No ( 0.96552 0.03448 ) *
             87) CompPrice > 147.5 19  25.010 No ( 0.63158 0.36842 )  
              174) Price < 147 12  16.300 Yes ( 0.41667 0.58333 )  
                348) CompPrice < 152.5 7   5.742 Yes ( 0.14286 0.85714 ) *
                349) CompPrice > 152.5 5   5.004 No ( 0.80000 0.20000 ) *
              175) Price > 147 7   0.000 No ( 1.00000 0.00000 ) *
       11) Advertising > 13.5 45  61.830 Yes ( 0.44444 0.55556 )  
         22) Age < 54.5 25  25.020 Yes ( 0.20000 0.80000 )  
           44) CompPrice < 130.5 14  18.250 Yes ( 0.35714 0.64286 )  
             88) Income < 100 9  12.370 No ( 0.55556 0.44444 ) *
             89) Income > 100 5   0.000 Yes ( 0.00000 1.00000 ) *
           45) CompPrice > 130.5 11   0.000 Yes ( 0.00000 1.00000 ) *
         23) Age > 54.5 20  22.490 No ( 0.75000 0.25000 )  
           46) CompPrice < 122.5 10   0.000 No ( 1.00000 0.00000 ) *
           47) CompPrice > 122.5 10  13.860 No ( 0.50000 0.50000 )  
             94) Price < 125 5   0.000 Yes ( 0.00000 1.00000 ) *
             95) Price > 125 5   0.000 No ( 1.00000 0.00000 ) *
    3) ShelveLoc: Good 85  90.330 Yes ( 0.22353 0.77647 )  
      6) Price < 135 68  49.260 Yes ( 0.11765 0.88235 )  
       12) US: No 17  22.070 Yes ( 0.35294 0.64706 )  
         24) Price < 109 8   0.000 Yes ( 0.00000 1.00000 ) *
         25) Price > 109 9  11.460 No ( 0.66667 0.33333 ) *
       13) US: Yes 51  16.880 Yes ( 0.03922 0.96078 ) *
      7) Price > 135 17  22.070 No ( 0.64706 0.35294 )  
       14) Income < 46 6   0.000 No ( 1.00000 0.00000 ) *
       15) Income > 46 11  15.160 Yes ( 0.45455 0.54545 ) *

Para evaluar adecuadamente el rendimiento de un árbol de clasificación en estos datos, debemos estimar el error de prueba en lugar de simplemente calcular el error de entrenamiento. Dividimos las observaciones en un conjunto de entrenamiento y un conjunto de prueba, construimos el árbol usando el conjunto de entrenamiento, y evaluamos su rendimiento en los datos de prueba. La función predict() se puede usar para este propósito. En el caso de un árbol de clasificación, el argumento type = "class" instruye a R para que devuelva la predicción de clase real. Este enfoque lleva a predicciones correctas para alrededor del \(77 \%\) de las ubicaciones en el conjunto de datos de prueba.

set.seed(2)
train <- sample(1:nrow(Carseats), 200)
Carseats.test <- Carseats[-train, ]

tree.carseats <- update(tree.carseats, subset = train)

tree.pred <- predict(tree.carseats, 
                     Carseats.test,
                     type = "class")

confusionMatrix(tree.pred, Carseats.test[['High']])
Confusion Matrix and Statistics

          Reference
Prediction  No Yes
       No  104  33
       Yes  13  50
                                          
               Accuracy : 0.77            
                 95% CI : (0.7054, 0.8264)
    No Information Rate : 0.585           
    P-Value [Acc > NIR] : 2.938e-08       
                                          
                  Kappa : 0.5091          
                                          
 Mcnemar's Test P-Value : 0.005088        
                                          
            Sensitivity : 0.8889          
            Specificity : 0.6024          
         Pos Pred Value : 0.7591          
         Neg Pred Value : 0.7937          
             Prevalence : 0.5850          
         Detection Rate : 0.5200          
   Detection Prevalence : 0.6850          
      Balanced Accuracy : 0.7456          
                                          
       'Positive' Class : No              
                                          

(Si vuelves a ejecutar la función predict() podrías obtener resultados ligeramente diferentes, debido a “empates”: por ejemplo, esto puede suceder cuando las observaciones de entrenamiento que corresponden a un nodo terminal se dividen equitativamente entre los valores de respuesta Yes y No.)

A continuación, consideramos si podar (pruning) el árbol podría llevar a mejores resultados. La función cv.tree() realiza validación cruzada para determinar el nivel óptimo de complejidad del árbol; se utiliza la poda de complejidad de costo (cost complexity pruning) para seleccionar una secuencia de árboles por considerar. Usamos el argumento FUN = prune.misclass para indicar que queremos que la tasa de error de clasificación guíe el proceso de validación cruzada y poda, en lugar del valor predeterminado de la función cv.tree(), que es la desviación. La función cv.tree() reporta el número de nodos terminales de cada árbol considerado (size), así como la tasa de error correspondiente y el valor del parámetro de costo-complejidad utilizado (k, que corresponde a \(\alpha\) en las notas de clase).

set.seed(7)
cv.carseats <- cv.tree(tree.carseats, FUN = prune.misclass)
names(cv.carseats)
[1] "size"   "dev"    "k"      "method"
cv.carseats
$size
[1] 21 19 14  9  8  5  3  2  1

$dev
[1] 75 75 75 74 82 83 83 85 82

$k
[1] -Inf  0.0  1.0  1.4  2.0  3.0  4.0  9.0 18.0

$method
[1] "misclass"

attr(,"class")
[1] "prune"         "tree.sequence"

A pesar de su nombre, dev corresponde al número de errores de validación cruzada. El árbol con 9 nodos terminales resulta en solo 74 errores de validación cruzada. Graficamos la tasa de error como una función tanto de size como de k.

cv_carseats <- cv.carseats |> 
  unclass() |> 
  as.data.frame()

error_size <- ggplot(cv_carseats, aes(x = size, y = dev)) + geom_point() + geom_line()
error_k <- ggplot(cv_carseats, aes(x = k, y = dev)) + geom_point() + geom_line()
error_size + error_k

Ahora aplicamos la función prune.misclass() para podar el árbol y obtener el árbol de nueve nodos.

prune.carseats <- prune.misclass(tree.carseats, best = 9)
plot(prune.carseats)
text(prune.carseats, pretty = 0)

¿Qué tan bien se desempeña este árbol podado en el conjunto de datos de prueba? Una vez más, aplicamos la función predict().

tree.pred <- predict(prune.carseats, 
                     Carseats.test,
                    type = "class")

confusionMatrix(tree.pred, Carseats.test[['High']])
Confusion Matrix and Statistics

          Reference
Prediction No Yes
       No  97  25
       Yes 20  58
                                          
               Accuracy : 0.775           
                 95% CI : (0.7108, 0.8309)
    No Information Rate : 0.585           
    P-Value [Acc > NIR] : 1.206e-08       
                                          
                  Kappa : 0.5325          
                                          
 Mcnemar's Test P-Value : 0.551           
                                          
            Sensitivity : 0.8291          
            Specificity : 0.6988          
         Pos Pred Value : 0.7951          
         Neg Pred Value : 0.7436          
             Prevalence : 0.5850          
         Detection Rate : 0.4850          
   Detection Prevalence : 0.6100          
      Balanced Accuracy : 0.7639          
                                          
       'Positive' Class : No              
                                          

Ahora el \(77.5 \%\) de las observaciones de prueba se clasifican correctamente, por lo que el proceso de poda no solo ha producido un árbol más interpretable, sino que también ha mejorado ligeramente la precisión de la clasificación.

Si aumentamos el valor de best, obtenemos un árbol podado más grande con una precisión de clasificación más baja:

prune.carseats <- prune.misclass(tree.carseats, best = 14)
plot(prune.carseats)
text(prune.carseats, pretty = 0)

tree.pred <- predict(prune.carseats, Carseats.test,
    type = "class")
confusionMatrix(tree.pred, Carseats.test[['High']])
Confusion Matrix and Statistics

          Reference
Prediction  No Yes
       No  102  31
       Yes  15  52
                                          
               Accuracy : 0.77            
                 95% CI : (0.7054, 0.8264)
    No Information Rate : 0.585           
    P-Value [Acc > NIR] : 2.938e-08       
                                          
                  Kappa : 0.5127          
                                          
 Mcnemar's Test P-Value : 0.02699         
                                          
            Sensitivity : 0.8718          
            Specificity : 0.6265          
         Pos Pred Value : 0.7669          
         Neg Pred Value : 0.7761          
             Prevalence : 0.5850          
         Detection Rate : 0.5100          
   Detection Prevalence : 0.6650          
      Balanced Accuracy : 0.7492          
                                          
       'Positive' Class : No              
                                          

Ajustando Árboles de Regresión

Aquí ajustamos un árbol de regresión al conjunto de datos Boston. Primero, creamos un conjunto de entrenamiento y ajustamos el árbol a los datos de entrenamiento.

set.seed(1)
train <- sample(nrow(Boston), size = nrow(Boston) / 2)
tree_boston <- tree(medv ~ ., data = Boston, subset = train)
summary(tree_boston)

Regression tree:
tree(formula = medv ~ ., data = Boston, subset = train)
Variables actually used in tree construction:
[1] "rm"    "lstat" "crim"  "age"  
Number of terminal nodes:  7 
Residual mean deviance:  10.38 = 2555 / 246 
Distribution of residuals:
    Min.  1st Qu.   Median     Mean  3rd Qu.     Max. 
-10.1800  -1.7770  -0.1775   0.0000   1.9230  16.5800 

Observa que la salida de summary() indica que solo cuatro de las variables se han utilizado en la construcción del árbol. En el contexto de un árbol de regresión, la desviación es simplemente la suma de errores cuadráticos para el árbol. Ahora graficamos el árbol.

plot(tree_boston)
text(tree_boston, pretty = 0)

La variable lstat mide el porcentaje de individuos con {menor estatus socioeconómico}, mientras que la variable rm corresponde al número promedio de habitaciones. El árbol indica que valores más grandes de rm, o valores más bajos de lstat, corresponden a casas más caras. Por ejemplo, el árbol predice un precio mediano de vivienda de $\(45{,}400\) para hogares en tramos censales (census tracts) en los que rm >= 7.553.

Vale la pena notar que podríamos haber ajustado un árbol mucho más grande, pasando control = tree.control(nobs = length(train), mindev = 0) a la función tree().

Ahora usamos la función cv.tree() para ver si podar el árbol mejorará el rendimiento.

cv_boston <- cv.tree(tree_boston)

cv_boston_data <- cv_boston |> 
  unclass() |> 
  as.data.frame()

cv_boston_data |> 
  ggplot(aes(x = size, y = dev)) +
  geom_point() + geom_line()

En este caso, el árbol más complejo bajo consideración es seleccionado por validación cruzada. Sin embargo, si deseamos podar el árbol, podríamos hacerlo usando la función prune.tree():

prune_boston <- prune.tree(tree_boston, best = 5)
plot(prune_boston)
text(prune_boston, pretty = 0)

De acuerdo con los resultados de la validación cruzada, usamos el árbol sin podar para hacer predicciones en el conjunto de prueba.

yhat <- predict(tree_boston, newdata = Boston[-train, ])
boston_test <- Boston[-train, "medv"]

yhat_plot_data <- data.frame(
  y_hat = yhat,
  y_test = boston_test
)

yhat_plot_data |> 
  ggplot(aes(x = yhat, y = y_test)) + 
  geom_jitter(alpha = 0.25) +
  geom_smooth(method = 'lm', se = FALSE, colour='black', linetype = 'dashed')
`geom_smooth()` using formula = 'y ~ x'

mean((yhat - boston_test)^2)
[1] 35.28688

En otras palabras, el ECM (Error Cuadrático Medio) del conjunto de prueba asociado con el árbol de regresión es \(35.29\). La raíz cuadrada del ECM es, por lo tanto, alrededor de \(5.941\), lo que indica que este modelo conduce a predicciones de prueba que están (en promedio) dentro de aproximadamente $\(5{,}941\) del verdadero valor mediano de la vivienda para el tramo censal.

Bagging y Bosques Aleatorios

Aquí aplicamos bagging y bosques aleatorios a los datos Boston, usando el paquete randomForest en R. Los resultados exactos obtenidos en esta sección pueden depender de la versión de R y la versión del paquete randomForest instaladas en tu computadora. Recuerda que bagging es simplemente un caso especial de un bosque aleatorio con \(m=p\). Por lo tanto, la función randomForest() se puede usar para realizar tanto bosques aleatorios como bagging. Realizamos bagging de la siguiente manera:

library(randomForest)
randomForest 4.7-1.2
Type rfNews() to see new features/changes/bug fixes.

Attaching package: 'randomForest'
The following object is masked from 'package:dplyr':

    combine
The following object is masked from 'package:ggplot2':

    margin
set.seed(1)
bag_boston <- randomForest(medv ~ ., data = Boston,
    subset = train, mtry = 12, importance = TRUE)
bag_boston

Call:
 randomForest(formula = medv ~ ., data = Boston, mtry = 12, importance = TRUE,      subset = train) 
               Type of random forest: regression
                     Number of trees: 500
No. of variables tried at each split: 12

          Mean of squared residuals: 11.40162
                    % Var explained: 85.17

El argumento mtry = 12 indica que los \(12\) predictores deben ser considerados para cada división del árbol, en otras palabras, que se debe hacer bagging. ¿Qué tan bien se desempeña este modelo de bagging en el conjunto de prueba?

yhat_bag <- predict(bag_boston, newdata = Boston[-train, ])

plot(yhat_bag, boston_test)
abline(0, 1)

mean((yhat.bag - boston.test)^2)
Error: object 'yhat.bag' not found

El ECM del conjunto de prueba asociado con el árbol de regresión con bagging es \(23.42\), aproximadamente dos tercios del obtenido usando un solo árbol podado óptimamente. Podríamos cambiar el número de árboles cultivados por randomForest() usando el argumento ntree:

bag.boston <- randomForest(medv ~ ., data = Boston,
    subset = train, mtry = 12, ntree = 25)
yhat.bag <- predict(bag.boston, newdata = Boston[-train, ])
mean((yhat.bag - boston.test)^2)
Error: object 'boston.test' not found

Cultivar un bosque aleatorio procede exactamente de la misma manera, excepto que usamos un valor más pequeño del argumento mtry. Por defecto, randomForest() usa \(p/3\) variables al construir un bosque aleatorio de árboles de regresión, y \(\sqrt{p}\) variables al construir un bosque aleatorio de árboles de clasificación. Aquí usamos mtry = 6.

set.seed(1)
rf.boston <- randomForest(medv ~ ., data = Boston,
    subset = train, mtry = 6, importance = TRUE)
yhat.rf <- predict(rf.boston, newdata = Boston[-train, ])
mean((yhat.rf - boston.test)^2)
Error: object 'boston.test' not found

El ECM del conjunto de prueba es \(20.07\); esto indica que los bosques aleatorios produjeron una mejora sobre bagging en este caso.

Usando la función importance(), podemos ver la importancia de cada variable.

importance(rf.boston)
          %IncMSE IncNodePurity
crim    19.435587    1070.42307
zn       3.091630      82.19257
indus    6.140529     590.09536
chas     1.370310      36.70356
nox     13.263466     859.97091
rm      35.094741    8270.33906
age     15.144821     634.31220
dis      9.163776     684.87953
rad      4.793720      83.18719
tax      4.410714     292.20949
ptratio  8.612780     902.20190
lstat   28.725343    5813.04833

Se reportan dos medidas de importancia de variable. La primera se basa en la disminución media de la precisión en las predicciones en las muestras “out of bag” cuando una variable dada es permutada. La segunda es una medida de la disminución total en la impureza del nodo que resulta de las divisiones sobre esa variable, promediada sobre todos los árboles (esto se graficó en la Figura 8.9). En el caso de árboles de regresión, la impureza del nodo se mide por el RSS (suma residual de cuadrados) de entrenamiento, y para árboles de clasificación por la devianza. Se pueden producir gráficos de estas medidas de importancia usando la función varImpPlot().

varImpPlot(rf.boston)

Los resultados indican que, a través de todos los árboles considerados en el bosque aleatorio, la riqueza de la comunidad (lstat) y el tamaño de la casa (rm) son, con diferencia, las dos variables más importantes.

Boosting

Aquí usamos el paquete gbm, y dentro de él la función gbm(), para ajustar árboles de regresión potenciados (boosted) a los datos Boston. Ejecutamos gbm() con la opción distribution = "gaussian" ya que este es un problema de regresión; si fuera un problema de clasificación binaria, usaríamos distribution = "bernoulli". El argumento n.trees = 5000 indica que queremos \(5000\) árboles, y la opción interaction.depth = 4 limita la profundidad de cada árbol.

library(gbm)
Loaded gbm 2.2.2
This version of gbm is no longer under development. Consider transitioning to gbm3, https://github.com/gbm-developers/gbm3
set.seed(1)
boost.boston <- gbm(medv ~ ., data = Boston[train, ],
    distribution = "gaussian", n.trees = 5000,
    interaction.depth = 4)

La función summary() produce un gráfico de influencia relativa y también emite las estadísticas de influencia relativa.

summary(boost.boston)

            var     rel.inf
rm           rm 44.48249588
lstat     lstat 32.70281223
crim       crim  4.85109954
dis         dis  4.48693083
nox         nox  3.75222394
age         age  3.19769210
ptratio ptratio  2.81354826
tax         tax  1.54417603
indus     indus  1.03384666
rad         rad  0.87625748
zn           zn  0.16220479
chas       chas  0.09671228

Vemos que lstat y rm son, con diferencia, las variables más importantes. También podemos producir gráficos de dependencia parcial para estas dos variables. Estos gráficos ilustran el efecto marginal de las variables seleccionadas en la respuesta después de integrar las otras variables. En este caso, como podríamos esperar, los precios medianos de las viviendas aumentan con rm y disminuyen con lstat.

plot(boost.boston, i = "rm")

plot(boost.boston, i = "lstat")

Ahora usamos el modelo potenciado para predecir medv en el conjunto de prueba:

yhat.boost <- predict(boost.boston,
    newdata = Boston[-train, ], n.trees = 5000)
mean((yhat.boost - boston.test)^2)
Error: object 'boston.test' not found

El ECM de prueba obtenido es \(18.39\): esto es superior al ECM de prueba de los bosques aleatorios y bagging. Si queremos, podemos realizar boosting con un valor diferente del parámetro de contracción (shrinkage) \(\lambda\) en (8.10). El valor predeterminado es \(0.001\), pero esto se modifica fácilmente. Aquí tomamos \(\lambda=0.2\).

boost.boston <- gbm(medv ~ ., data = Boston[train, ],
    distribution = "gaussian", n.trees = 5000,
    interaction.depth = 4, shrinkage = 0.2, verbose = F)
yhat.boost <- predict(boost.boston,
    newdata = Boston[-train, ], n.trees = 5000)
mean((yhat.boost - boston.test)^2)
Error: object 'boston.test' not found

En este caso, usar \(\lambda=0.2\) lleva a un ECM de prueba más bajo que \(\lambda=0.001\).

Árboles de Regresión Aditivos Bayesianos (BART)

En esta sección usamos el paquete BART, y dentro de él la función gbart(), para ajustar un modelo de árboles de regresión aditivos bayesianos a los datos de vivienda Boston. La función gbart() está diseñada para variables de resultado cuantitativas. Para resultados binarios, lbart() y pbart() están disponibles.

Para ejecutar la función gbart(), primero debemos crear matrices de predictores para los datos de entrenamiento y prueba. Ejecutamos BART con la configuración predeterminada.

library(BART)
Loading required package: nlme

Attaching package: 'nlme'
The following object is masked from 'package:dplyr':

    collapse
Loading required package: survival

Attaching package: 'survival'
The following object is masked from 'package:caret':

    cluster
x <- Boston[, 1:12]
y <- Boston[, "medv"]
xtrain <- x[train, ]
ytrain <- y[train]
xtest <- x[-train, ]
ytest <- y[-train]
set.seed(1)
bartfit <- gbart(xtrain, ytrain, x.test = xtest)
*****Calling gbart: type=1
*****Data:
data:n,p,np: 253, 12, 253
y1,yn: 0.213439, -5.486561
x1,x[n*p]: 0.109590, 20.080000
xp1,xp[np*p]: 0.027310, 7.880000
*****Number of Trees: 200
*****Number of Cut Points: 100 ... 100
*****burn,nd,thin: 100,1000,1
*****Prior:beta,alpha,tau,nu,lambda,offset: 2,0.95,0.795495,3,3.71636,21.7866
*****sigma: 4.367914
*****w (weights): 1.000000 ... 1.000000
*****Dirichlet:sparse,theta,omega,a,b,rho,augment: 0,0,1,0.5,1,12,0
*****printevery: 100

MCMC
done 0 (out of 1100)
done 100 (out of 1100)
done 200 (out of 1100)
done 300 (out of 1100)
done 400 (out of 1100)
done 500 (out of 1100)
done 600 (out of 1100)
done 700 (out of 1100)
done 800 (out of 1100)
done 900 (out of 1100)
done 1000 (out of 1100)
time: 2s
trcnt,tecnt: 1000,1000

A continuación calculamos el error de prueba.

yhat.bart <- bartfit$yhat.test.mean
mean((ytest - yhat.bart)^2)
[1] 15.91912

En este conjunto de datos, el error de prueba de BART es más bajo que el error de prueba de bosques aleatorios y boosting.

Ahora podemos verificar cuántas veces apareció cada variable en la colección de árboles.

ord <- order(bartfit$varcount.mean, decreasing = T)
bartfit$varcount.mean[ord]
    nox   lstat     rad      rm     tax ptratio    chas     age   indus      zn 
 22.973  21.653  21.638  20.725  20.021  19.615  19.283  19.278  19.073  15.576 
    dis    crim 
 13.800  11.607