Introducción

Los árboles de decisión son una herramienta extremadamente útil en los procesos de minería de datos y generación de conocimiento.

Su ventaja principal reside en su simplicidad. Los arboles de decisión son fácilmente interpretables por cualquier persona sin necesidad de conocimientos estadísticos.

La representación clara de las ideas que transmite un árbol de decisión parecería ser una tarea sencilla pero a veces no lo es. Las librerias como rpart, party o partykit suelen generar representaciones confusas o poco prácticas para arboles muy complejos.

Con el objetivo de solucionar las dificultades mencionadas previamente, desarrolle la función Plotear_Arbol para generar visualizaciones amigables de árboles provenientes de partykit. Con apoyo de librerías orientadas a representar redes (visnetwork) y utilizando los objetos previamente generados por partykit, se combinan la estructura de los modelos de decisión junto con conceptos simples de entender para el usuario (como escalas de colores y tamaños variables).

Por otro lado, extraer las reglas principales de un árbol suele ser una tarea tediosa. Para resolver este último problema, los recursos de la librería flextable son ideales para construir tablas resumen de la estructura de estos arboles de decisión. La función Tabla_Arbol es mi aporte para resolver esta cuestión.

Obtener las funciones

Las funciones pueden obtenerse facilmente desde el siguiente Link. Luego de descargar los archivos .rds, crear las funciones es muy fácil:

readRDS("Plotear_Arbol.rds")->Plotear_Arbol
readRDS("Tabla_Arbol.rds")->Tabla_Arbol

Objeto partykit

Ambas funciones, Plotear_Arbol y Tabla_Arbol trabajan en base a un objeto partykit. A continuación, un ejemplo de cómo generar un modelo.

Dataset

El origen de datos que vamos a utilizar es el dataset diamonds del paquete ggplot2.

library(ggplot2)
diamonds
## # A tibble: 53,940 x 10
##    carat cut       color clarity depth table price     x     y     z
##    <dbl> <ord>     <ord> <ord>   <dbl> <dbl> <int> <dbl> <dbl> <dbl>
##  1 0.23  Ideal     E     SI2      61.5    55   326  3.95  3.98  2.43
##  2 0.21  Premium   E     SI1      59.8    61   326  3.89  3.84  2.31
##  3 0.23  Good      E     VS1      56.9    65   327  4.05  4.07  2.31
##  4 0.290 Premium   I     VS2      62.4    58   334  4.2   4.23  2.63
##  5 0.31  Good      J     SI2      63.3    58   335  4.34  4.35  2.75
##  6 0.24  Very Good J     VVS2     62.8    57   336  3.94  3.96  2.48
##  7 0.24  Very Good I     VVS1     62.3    57   336  3.95  3.98  2.47
##  8 0.26  Very Good H     SI1      61.9    55   337  4.07  4.11  2.53
##  9 0.22  Fair      E     VS2      65.1    61   337  3.87  3.78  2.49
## 10 0.23  Very Good H     VS1      59.4    61   338  4     4.05  2.39
## # ... with 53,930 more rows

Entrenar el Arbol

El árbol de decisión va a ser entrenado utilizando como target la variable “price” y como predictores “cut”, “color”, “depth” y “table”. Utilizamos la función ctree del paquete partykit.

Se aplica el parámetro minbucket=5000 en busca de que los nodos finales del árbol no sean demasiado pequeños y la estructura del mismo no sea muy compleja.

library(partykit)
ctree(price~cut+color+depth+table,data = diamonds,
      control = ctree_control(minbucket = 3000))->objeto
objeto
## 
## Model formula:
## price ~ cut + color + depth + table
## 
## Fitted party:
## [1] root
## |   [2] color <= G
## |   |   [3] table <= 56.2
## |   |   |   [4] color <= E
## |   |   |   |   [5] depth <= 61.9: 2431.640 (n = 3056, err = 21270845792.1)
## |   |   |   |   [6] depth > 61.9: 2723.681 (n = 3195, err = 32362339136.4)
## |   |   |   [7] color > E
## |   |   |   |   [8] table <= 55.3: 3180.086 (n = 4003, err = 56078366079.6)
## |   |   |   |   [9] table > 55.3: 3467.166 (n = 4073, err = 54405226329.8)
## |   |   [10] table > 56.2
## |   |   |   [11] color <= E
## |   |   |   |   [12] table <= 57.7: 2972.188 (n = 3026, err = 30825126877.4)
## |   |   |   |   [13] table > 57.7: 3631.568 (n = 7295, err = 9.7525e+10)
## |   |   |   [14] color > E
## |   |   |   |   [15] color <= F: 4031.954 (n = 5937, err = 90143968004.6)
## |   |   |   |   [16] color > F: 4385.243 (n = 6821, err = 117216287159.0)
## |   [17] color > G
## |   |   [18] table <= 56.1: 3809.161 (n = 5591, err = 90916434219.8)
## |   |   [19] table > 56.1
## |   |   |   [20] color <= H: 4974.747 (n = 5381, err = 1.00664e+11)
## |   |   |   [21] color > H: 5708.125 (n = 5562, err = 123990279494.9)
## 
## Number of inner nodes:    10
## Number of terminal nodes: 11

Visualización de partykit

Al intentar obtener una visualización del árbol nos encontramos con lo siguiente:

plot(objeto)

Plotear_Arbol

Los nodos finales del árbol se superponen entre si y es difícil distinguir el impacto de los cortes en variable respuesta o en la cantidad de observaciones que caen en cada bucket.

La función Plotear_Arbol aplica una escala de colores entre todos los nodos del árbol (ya sean nodo de corte o finales) de tal forma que a simple vista pueda apreciarse los efectos de los cortes sobre la variable respuesta.

Por otro lado, los nodos finales poseen un tamaño diferente de acuerdo a la cantidad de observaciones que contienen.

Independientemente de la información que puede obtenerse a golpe de vista también pueden obtenerse datos relevantes de cada nodo de corte o final con solo posicionar el cursor sobre cada uno de ellos.

La posibilidad del Zoom permite también acercar o alejar la imagen hasta obtener la visualización ideal.

Ejemplo de Uso

Visualización por Default

Visualización sin modificar parámetros:

Plotear_Arbol(objeto)

Cambiando la gama de colores

Puede usarse otra combinación de colores:

Plotear_Arbol(objeto,colores = c("blue","pink","red"))

Colocando cada Nodo Final en su nivel

Los nodos finales poseen un nivel propio según el orden de la estructura donde fueron generados. Por defecto, la función los ubica a todos en el nivel mas bajo pero esto puede desactivarse:

Plotear_Arbol(objeto,Nodos.Finales.Abajo = FALSE)

Anulando Configuración de Tamaños

Puede anularse la vinculación del tamaño de los nodos finales a la cantidad de observaciones contenidas en cada uno:

Plotear_Arbol(objeto,Nodos.Finales.Size = FALSE)

Utilizando Iconos en vez de Figuras

Pueden utilizarse iconos en vez de las figuras tradicionales (El catalogo puede encontrarse en: https://fontawesome.com/icons)

Plotear_Arbol(objeto,Icono = "f0c0")

Código de la Función

A continuación el código completo para crear la función:

Plotear_Arbol<-function(objeto,
                        colores=c("red","yellow","darkgreen"),
                        Nodos.Finales.Abajo=TRUE,
                        Nodos.Finales.Size=TRUE,
                        Icono=NULL) {

require(partykit)
require(stringr)
require(dplyr)
require(data.table)
require(visNetwork)
require(grDevices)

## extraer reglas
partykit:::.list.rules.party(objeto)->reglas

## partirlas por "&"
str_split(reglas," & ")->reglas2

## Crear data.table con todas las particiones
lapply(reglas2,function(x) {
  
  regla_completa<-""  
  for (i in 1:length(x)) {
    regla_completa[i]<-paste(x[1:i],collapse ="|")
  }
  
  data.table(
    Regla=x,
    Nivel=1:length(x),
    Identificacion=regla_completa,
    Identificacion_Ant=c("Inicio",regla_completa[-length(x)])
  )
})->reglas2
for (i in 1:length(reglas2)) {reglas2[[i]]$Nodo_Orig<-names(reglas)[i]}
reglas3<-rbindlist(reglas2)

## desarmar la particion entre variable, valor y operador
rbindlist(lapply(strsplit(reglas3$Regla, " > | <= | %in% "),function(x) {
  data.table(Variable=x[1],Valor=x[2])
}))->composicion

composicion$Valor<-str_remove_all(composicion$Valor,"%in%")
composicion$Valor<-str_remove_all(composicion$Valor,"c\\(")
composicion$Valor<-str_remove_all(composicion$Valor,"\\)")
composicion$Valor<-unlist(lapply(str_split(composicion$Valor,","),function(x) {
  paste(x[!x %in% c(" \"NA\"","\"NA\"")],collapse = "\n")
}))

composicion$Operador<-ifelse(str_detect(reglas3$Regla,"<="),"<=",
         ifelse(str_detect(reglas3$Regla,">"),">","%in%"))
cbind(reglas3,composicion)->reglas4

## identifico las reglas del registro y su antecesor
data.table(
  Identificacion=c("Inicio",unique(reglas4$Identificacion))
)->referencias
referencias$ID<-1:nrow(referencias)
reglas4 %>% left_join(referencias,by="Identificacion")->reglas4
names(referencias)<-c("Identificacion_Ant","ID_Prev")
reglas4 %>% left_join(referencias,by="Identificacion_Ant")->reglas4
reglas4$Identificacion_Ant<-NULL
reglas4$Identificacion<-NULL
reglas4$Regla<-NULL

## identifico los nodos finales
reglas4 %>% group_by(Nodo_Orig) %>% summarise(MaxNivel=max(Nivel)) %>%
  right_join(reglas4,by="Nodo_Orig")->reglas4
reglas4$EsNodoFinal<-ifelse(reglas4$Nivel==reglas4$MaxNivel,1,0)

## calculo sus metricas
data.table(
  Nodo_Orig=predict(objeto,type="node"),
  Estimacion=predict(objeto,type="response"),
  Cantidad=1
) %>% group_by(Nodo_Orig) %>% summarise(Estimacion=max(Estimacion),
                                        Cantidad=sum(Cantidad))->ref_nodos

## le calculo las metricas a todos los nodos
reglas4$Nodo_Orig<-as.numeric(reglas4$Nodo_Orig)
ref_nodos %>% right_join(reglas4,by="Nodo_Orig")->reglas5

## construyo la tabla de nodos previos
reglas5 %>% select(ID_Prev,Variable,Nivel)->nodos
nodos[!duplicated(nodos),]->nodos
names(nodos)<-c("id","label","level")
nodos$shape<-"dot"
nodos$value<-10

## construyo la tabla de nodos finales
reglas5 %>% filter(EsNodoFinal==1) %>% select(Nodo_Orig,ID,Variable,Nivel)->finales
finales$Nodo_Orig<-as.numeric(finales$Nodo_Orig)
right_join(ref_nodos,finales,by="Nodo_Orig")->finales
finales[!duplicated(finales),]->finales

## unifico ambas tablas
round(15*finales$Cantidad/max(finales$Cantidad),0)->tamanio
data.table(
  id=finales$ID,
  label=paste("Cant:",finales$Cantidad,"\n Est:",round(finales$Estimacion,2)),
  level=finales$Nivel+1,
  shape="square",
  value=ifelse(tamanio==0,1,tamanio)
) %>% rbind(nodos)->nodos_todos

## contruyo la tabla de edges
reglas4 %>% select(ID_Prev,ID,Valor,Operador)->uniones
uniones[!duplicated(uniones),]->uniones
edges <- data.frame(from = uniones$ID_Prev, 
                    to = uniones$ID,
                    label=paste(ifelse(uniones$Operador=="%in%",
                                       "",uniones$Operador),uniones$Valor)
)


## le agrego la estimacion a todos los nodos 
  nodos_todos$ID=nodos_todos$id
  
  nodos_todos %>% left_join(
    reglas5 %>% 
      filter(!duplicated(reglas5 %>% 
                           select(ID,Nodo_Orig))) %>% 
      select(ID,Nodo_Orig)
    ,"ID") %>% left_join(ref_nodos,"Nodo_Orig") %>% 
    group_by(ID) %>% 
    summarise(Estimacion_Calculada=sum(Estimacion*Cantidad)/sum(Cantidad),
              Cantidad_Calculada=sum(Cantidad))->Estimacion_calculada
  Estimacion_calculada$Cantidad_Calculada[Estimacion_calculada$ID==1]<-sum(ref_nodos$Cantidad)
  Estimacion_calculada$Estimacion_Calculada[Estimacion_calculada$ID==1]<-
    sum(ref_nodos$Estimacion*ref_nodos$Cantidad)/sum(ref_nodos$Cantidad)
  
  nodos_todos %>% left_join(Estimacion_calculada,by="ID")->nodos_todos2
  
  rbPal <- colorRampPalette(colores)
  nodos_todos2$color<-rbPal(nrow(nodos_todos2))[as.numeric(cut(1-nodos_todos2$Estimacion_Calculada,
                                                               breaks = nrow(nodos_todos2)))]
  
  ## ajustes menores
  nodos_todos2$shadow = TRUE
  nodos_todos2$title=paste0("<p><b>ID:",nodos_todos2$id,"</b><br> Est:",
                            format(round(nodos_todos2$Estimacion_Calculada,4),big.mark = ".",decimal.mark = ","),
                            "<br> Cant:",
                            format(nodos_todos2$Cantidad_Calculada,big.mark = ".",decimal.mark = ","),
                            "</p>")
  
## todos los nodos finales al mismo nivel
if (Nodos.Finales.Abajo) {
  nodos_todos2$level<-ifelse(nodos_todos2$shape!="dot",
                             max(nodos_todos2$level),nodos_todos2$level)
}

if (!Nodos.Finales.Size) {
  nodos_todos2$value<-15
}
  
nodos_todos2$group<-nodos_todos2$color

## construyo la red

## con iconos
if (!is.null(Icono)) {
  nodos_todos2$shape<-NULL
  grafo<-visNetwork(nodos_todos2, edges,width = "100%")%>% 
    visHierarchicalLayout() 
  
  for (i in unique(nodos_todos2$group)) {
    grafo<-grafo %>%
      visGroups(groupname = i, shape = "icon", 
                icon = list(code = Icono, color=i))
  }
  
  grafo<-grafo %>%
    addFontAwesome()
  
} else { ## sin iconos
  
  grafo<-visNetwork(nodos_todos2, edges,width = "100%")%>% 
    visHierarchicalLayout() 
}

return(grafo)
}

Tabla_Arbol

Al final del día, una buena visualización requiere una tabla de soporte. La función Tabla_Arbol expone de forma simple los cortes del árbol y permite compartir estas ideas de forma más práctica con clientes y otras áreas:

Ejemplo de Uso

Tabla_Arbol(objeto)

Condicion.1

Condicion.2

Condicion.3

Condicion.4

Condicion.5

Estimacion

Cantidad

carat <= 0.99

carat <= 0.62

carat <= 0.45

x <= 4.62

clarity in "I1", "SI1", "VS1", "VVS1"

634.260

5428

clarity in "SI2", "VS2", "VVS2", "IF"

756.302

6452

x > 4.62

960.287

5409

carat > 0.45

1675.049

7498

carat > 0.62

y <= 5.83

2588.292

5038

y > 5.83

3527.311

5055

carat > 0.99

carat <= 1.49

clarity in "I1", "SI1", "VS1", "VVS1"

4923.082

7033

clarity in "SI2", "VS2", "VVS2", "IF"

7617.413

5792

carat > 1.49

12260.564

6235

Codigo de la Función

Tabla_Arbol<-function(objeto) {

require(flextable)
require(partykit)
require(dbplyr)
require(stringr)
require(data.table)

## extraer reglas
partykit:::.list.rules.party(objeto)->reglas

## partirlas por "&"
str_split(reglas," & ")->reglas2

## Crear data.table con todas las particiones
lapply(reglas2,function(x) {
  
  regla_completa<-""  
  for (i in 1:length(x)) {
    regla_completa[i]<-paste(x[1:i],collapse ="|")
  }
  
  data.table(
    Regla=x,
    Nivel=1:length(x),
    Identificacion=regla_completa,
    Identificacion_Ant=c("Inicio",regla_completa[-length(x)])
  )
})->reglas2
for (i in 1:length(reglas2)) {reglas2[[i]]$Nodo_Orig<-names(reglas)[i]}
reglas3<-rbindlist(reglas2)

## desarmar la particion entre variable, valor y operador
rbindlist(lapply(strsplit(reglas3$Regla, " > | <= | %in% "),function(x) {
  data.table(Variable=x[1],Valor=x[2])
}))->composicion

composicion$Valor<-str_remove_all(composicion$Valor,"%in%")
composicion$Valor<-str_remove_all(composicion$Valor,"c\\(")
composicion$Valor<-str_remove_all(composicion$Valor,"\\)")
composicion$Valor<-unlist(lapply(str_split(composicion$Valor,","),function(x) {
  paste(x[!x %in% c(" \"NA\"","\"NA\"")],collapse = ",")
}))

composicion$Operador<-ifelse(str_detect(reglas3$Regla,"<="),"<=",
                             ifelse(str_detect(reglas3$Regla,">"),">","in"))
cbind(reglas3,composicion)->reglas4

reglas4$Condicion<-paste(reglas4$Variable,reglas4$Operador,reglas4$Valor)
reglas4 %>% select(Nodo_Orig,Nivel,Condicion) %>%
  reshape(timevar="Nivel",idvar="Nodo_Orig",direction="wide")->resumen

## calculo sus metricas
data.table(
  Nodo_Orig=predict(objeto,type="node"),
  Estimacion=predict(objeto,type="response"),
  Cantidad=1
) %>% group_by(Nodo_Orig) %>% summarise(Estimacion=max(Estimacion),
                                        Cantidad=sum(Cantidad))->ref_nodos

resumen$Nodo_Orig<-as.numeric(as.character(resumen$Nodo_Orig))
resumen %>% left_join(ref_nodos,"Nodo_Orig")->resumen2

resumen2<-resumen2 %>% arrange(Nodo_Orig)
resumen2$Cantidad<-as.integer(resumen2$Cantidad)
flextable(resumen2[,-1]) %>% theme_box()->hh
hh1<-merge_v(hh, j = c(names(resumen2)[-c(1,ncol(resumen2)-1,ncol(resumen2))]))
return(hh1)
}