Los árboles de decisión son una herramienta extremadamente útil en los procesos de minería de datos y generación de conocimiento.
Su ventaja principal reside en su simplicidad. Los arboles de decisión son fácilmente interpretables por cualquier persona sin necesidad de conocimientos estadísticos.
La representación clara de las ideas que transmite un árbol de decisión parecería ser una tarea sencilla pero a veces no lo es. Las librerias como rpart, party o partykit suelen generar representaciones confusas o poco prácticas para arboles muy complejos.
Con el objetivo de solucionar las dificultades mencionadas previamente, desarrolle la función Plotear_Arbol para generar visualizaciones amigables de árboles provenientes de partykit. Con apoyo de librerías orientadas a representar redes (visnetwork) y utilizando los objetos previamente generados por partykit, se combinan la estructura de los modelos de decisión junto con conceptos simples de entender para el usuario (como escalas de colores y tamaños variables).
Por otro lado, extraer las reglas principales de un árbol suele ser una tarea tediosa. Para resolver este último problema, los recursos de la librería flextable son ideales para construir tablas resumen de la estructura de estos arboles de decisión. La función Tabla_Arbol es mi aporte para resolver esta cuestión.
Las funciones pueden obtenerse facilmente desde el siguiente Link. Luego de descargar los archivos .rds, crear las funciones es muy fácil:
readRDS("Plotear_Arbol.rds")->Plotear_Arbol
readRDS("Tabla_Arbol.rds")->Tabla_Arbol
Ambas funciones, Plotear_Arbol y Tabla_Arbol trabajan en base a un objeto partykit. A continuación, un ejemplo de cómo generar un modelo.
El origen de datos que vamos a utilizar es el dataset diamonds del paquete ggplot2.
library(ggplot2)
diamonds
## # A tibble: 53,940 x 10
## carat cut color clarity depth table price x y z
## <dbl> <ord> <ord> <ord> <dbl> <dbl> <int> <dbl> <dbl> <dbl>
## 1 0.23 Ideal E SI2 61.5 55 326 3.95 3.98 2.43
## 2 0.21 Premium E SI1 59.8 61 326 3.89 3.84 2.31
## 3 0.23 Good E VS1 56.9 65 327 4.05 4.07 2.31
## 4 0.290 Premium I VS2 62.4 58 334 4.2 4.23 2.63
## 5 0.31 Good J SI2 63.3 58 335 4.34 4.35 2.75
## 6 0.24 Very Good J VVS2 62.8 57 336 3.94 3.96 2.48
## 7 0.24 Very Good I VVS1 62.3 57 336 3.95 3.98 2.47
## 8 0.26 Very Good H SI1 61.9 55 337 4.07 4.11 2.53
## 9 0.22 Fair E VS2 65.1 61 337 3.87 3.78 2.49
## 10 0.23 Very Good H VS1 59.4 61 338 4 4.05 2.39
## # ... with 53,930 more rows
El árbol de decisión va a ser entrenado utilizando como target la variable “price” y como predictores “cut”, “color”, “depth” y “table”. Utilizamos la función ctree del paquete partykit.
Se aplica el parámetro minbucket=5000 en busca de que los nodos finales del árbol no sean demasiado pequeños y la estructura del mismo no sea muy compleja.
library(partykit)
ctree(price~cut+color+depth+table,data = diamonds,
control = ctree_control(minbucket = 3000))->objeto
objeto
##
## Model formula:
## price ~ cut + color + depth + table
##
## Fitted party:
## [1] root
## | [2] color <= G
## | | [3] table <= 56.2
## | | | [4] color <= E
## | | | | [5] depth <= 61.9: 2431.640 (n = 3056, err = 21270845792.1)
## | | | | [6] depth > 61.9: 2723.681 (n = 3195, err = 32362339136.4)
## | | | [7] color > E
## | | | | [8] table <= 55.3: 3180.086 (n = 4003, err = 56078366079.6)
## | | | | [9] table > 55.3: 3467.166 (n = 4073, err = 54405226329.8)
## | | [10] table > 56.2
## | | | [11] color <= E
## | | | | [12] table <= 57.7: 2972.188 (n = 3026, err = 30825126877.4)
## | | | | [13] table > 57.7: 3631.568 (n = 7295, err = 9.7525e+10)
## | | | [14] color > E
## | | | | [15] color <= F: 4031.954 (n = 5937, err = 90143968004.6)
## | | | | [16] color > F: 4385.243 (n = 6821, err = 117216287159.0)
## | [17] color > G
## | | [18] table <= 56.1: 3809.161 (n = 5591, err = 90916434219.8)
## | | [19] table > 56.1
## | | | [20] color <= H: 4974.747 (n = 5381, err = 1.00664e+11)
## | | | [21] color > H: 5708.125 (n = 5562, err = 123990279494.9)
##
## Number of inner nodes: 10
## Number of terminal nodes: 11
Al intentar obtener una visualización del árbol nos encontramos con lo siguiente:
plot(objeto)
Los nodos finales del árbol se superponen entre si y es difícil distinguir el impacto de los cortes en variable respuesta o en la cantidad de observaciones que caen en cada bucket.
La función Plotear_Arbol aplica una escala de colores entre todos los nodos del árbol (ya sean nodo de corte o finales) de tal forma que a simple vista pueda apreciarse los efectos de los cortes sobre la variable respuesta.
Por otro lado, los nodos finales poseen un tamaño diferente de acuerdo a la cantidad de observaciones que contienen.
Independientemente de la información que puede obtenerse a golpe de vista también pueden obtenerse datos relevantes de cada nodo de corte o final con solo posicionar el cursor sobre cada uno de ellos.
La posibilidad del Zoom permite también acercar o alejar la imagen hasta obtener la visualización ideal.
Visualización sin modificar parámetros:
Plotear_Arbol(objeto)
Puede usarse otra combinación de colores:
Plotear_Arbol(objeto,colores = c("blue","pink","red"))
Los nodos finales poseen un nivel propio según el orden de la estructura donde fueron generados. Por defecto, la función los ubica a todos en el nivel mas bajo pero esto puede desactivarse:
Plotear_Arbol(objeto,Nodos.Finales.Abajo = FALSE)
Puede anularse la vinculación del tamaño de los nodos finales a la cantidad de observaciones contenidas en cada uno:
Plotear_Arbol(objeto,Nodos.Finales.Size = FALSE)
Pueden utilizarse iconos en vez de las figuras tradicionales (El catalogo puede encontrarse en: https://fontawesome.com/icons)
Plotear_Arbol(objeto,Icono = "f0c0")
A continuación el código completo para crear la función:
Plotear_Arbol<-function(objeto,
colores=c("red","yellow","darkgreen"),
Nodos.Finales.Abajo=TRUE,
Nodos.Finales.Size=TRUE,
Icono=NULL) {
require(partykit)
require(stringr)
require(dplyr)
require(data.table)
require(visNetwork)
require(grDevices)
## extraer reglas
partykit:::.list.rules.party(objeto)->reglas
## partirlas por "&"
str_split(reglas," & ")->reglas2
## Crear data.table con todas las particiones
lapply(reglas2,function(x) {
regla_completa<-""
for (i in 1:length(x)) {
regla_completa[i]<-paste(x[1:i],collapse ="|")
}
data.table(
Regla=x,
Nivel=1:length(x),
Identificacion=regla_completa,
Identificacion_Ant=c("Inicio",regla_completa[-length(x)])
)
})->reglas2
for (i in 1:length(reglas2)) {reglas2[[i]]$Nodo_Orig<-names(reglas)[i]}
reglas3<-rbindlist(reglas2)
## desarmar la particion entre variable, valor y operador
rbindlist(lapply(strsplit(reglas3$Regla, " > | <= | %in% "),function(x) {
data.table(Variable=x[1],Valor=x[2])
}))->composicion
composicion$Valor<-str_remove_all(composicion$Valor,"%in%")
composicion$Valor<-str_remove_all(composicion$Valor,"c\\(")
composicion$Valor<-str_remove_all(composicion$Valor,"\\)")
composicion$Valor<-unlist(lapply(str_split(composicion$Valor,","),function(x) {
paste(x[!x %in% c(" \"NA\"","\"NA\"")],collapse = "\n")
}))
composicion$Operador<-ifelse(str_detect(reglas3$Regla,"<="),"<=",
ifelse(str_detect(reglas3$Regla,">"),">","%in%"))
cbind(reglas3,composicion)->reglas4
## identifico las reglas del registro y su antecesor
data.table(
Identificacion=c("Inicio",unique(reglas4$Identificacion))
)->referencias
referencias$ID<-1:nrow(referencias)
reglas4 %>% left_join(referencias,by="Identificacion")->reglas4
names(referencias)<-c("Identificacion_Ant","ID_Prev")
reglas4 %>% left_join(referencias,by="Identificacion_Ant")->reglas4
reglas4$Identificacion_Ant<-NULL
reglas4$Identificacion<-NULL
reglas4$Regla<-NULL
## identifico los nodos finales
reglas4 %>% group_by(Nodo_Orig) %>% summarise(MaxNivel=max(Nivel)) %>%
right_join(reglas4,by="Nodo_Orig")->reglas4
reglas4$EsNodoFinal<-ifelse(reglas4$Nivel==reglas4$MaxNivel,1,0)
## calculo sus metricas
data.table(
Nodo_Orig=predict(objeto,type="node"),
Estimacion=predict(objeto,type="response"),
Cantidad=1
) %>% group_by(Nodo_Orig) %>% summarise(Estimacion=max(Estimacion),
Cantidad=sum(Cantidad))->ref_nodos
## le calculo las metricas a todos los nodos
reglas4$Nodo_Orig<-as.numeric(reglas4$Nodo_Orig)
ref_nodos %>% right_join(reglas4,by="Nodo_Orig")->reglas5
## construyo la tabla de nodos previos
reglas5 %>% select(ID_Prev,Variable,Nivel)->nodos
nodos[!duplicated(nodos),]->nodos
names(nodos)<-c("id","label","level")
nodos$shape<-"dot"
nodos$value<-10
## construyo la tabla de nodos finales
reglas5 %>% filter(EsNodoFinal==1) %>% select(Nodo_Orig,ID,Variable,Nivel)->finales
finales$Nodo_Orig<-as.numeric(finales$Nodo_Orig)
right_join(ref_nodos,finales,by="Nodo_Orig")->finales
finales[!duplicated(finales),]->finales
## unifico ambas tablas
round(15*finales$Cantidad/max(finales$Cantidad),0)->tamanio
data.table(
id=finales$ID,
label=paste("Cant:",finales$Cantidad,"\n Est:",round(finales$Estimacion,2)),
level=finales$Nivel+1,
shape="square",
value=ifelse(tamanio==0,1,tamanio)
) %>% rbind(nodos)->nodos_todos
## contruyo la tabla de edges
reglas4 %>% select(ID_Prev,ID,Valor,Operador)->uniones
uniones[!duplicated(uniones),]->uniones
edges <- data.frame(from = uniones$ID_Prev,
to = uniones$ID,
label=paste(ifelse(uniones$Operador=="%in%",
"",uniones$Operador),uniones$Valor)
)
## le agrego la estimacion a todos los nodos
nodos_todos$ID=nodos_todos$id
nodos_todos %>% left_join(
reglas5 %>%
filter(!duplicated(reglas5 %>%
select(ID,Nodo_Orig))) %>%
select(ID,Nodo_Orig)
,"ID") %>% left_join(ref_nodos,"Nodo_Orig") %>%
group_by(ID) %>%
summarise(Estimacion_Calculada=sum(Estimacion*Cantidad)/sum(Cantidad),
Cantidad_Calculada=sum(Cantidad))->Estimacion_calculada
Estimacion_calculada$Cantidad_Calculada[Estimacion_calculada$ID==1]<-sum(ref_nodos$Cantidad)
Estimacion_calculada$Estimacion_Calculada[Estimacion_calculada$ID==1]<-
sum(ref_nodos$Estimacion*ref_nodos$Cantidad)/sum(ref_nodos$Cantidad)
nodos_todos %>% left_join(Estimacion_calculada,by="ID")->nodos_todos2
rbPal <- colorRampPalette(colores)
nodos_todos2$color<-rbPal(nrow(nodos_todos2))[as.numeric(cut(1-nodos_todos2$Estimacion_Calculada,
breaks = nrow(nodos_todos2)))]
## ajustes menores
nodos_todos2$shadow = TRUE
nodos_todos2$title=paste0("<p><b>ID:",nodos_todos2$id,"</b><br> Est:",
format(round(nodos_todos2$Estimacion_Calculada,4),big.mark = ".",decimal.mark = ","),
"<br> Cant:",
format(nodos_todos2$Cantidad_Calculada,big.mark = ".",decimal.mark = ","),
"</p>")
## todos los nodos finales al mismo nivel
if (Nodos.Finales.Abajo) {
nodos_todos2$level<-ifelse(nodos_todos2$shape!="dot",
max(nodos_todos2$level),nodos_todos2$level)
}
if (!Nodos.Finales.Size) {
nodos_todos2$value<-15
}
nodos_todos2$group<-nodos_todos2$color
## construyo la red
## con iconos
if (!is.null(Icono)) {
nodos_todos2$shape<-NULL
grafo<-visNetwork(nodos_todos2, edges,width = "100%")%>%
visHierarchicalLayout()
for (i in unique(nodos_todos2$group)) {
grafo<-grafo %>%
visGroups(groupname = i, shape = "icon",
icon = list(code = Icono, color=i))
}
grafo<-grafo %>%
addFontAwesome()
} else { ## sin iconos
grafo<-visNetwork(nodos_todos2, edges,width = "100%")%>%
visHierarchicalLayout()
}
return(grafo)
}
Al final del día, una buena visualización requiere una tabla de soporte. La función Tabla_Arbol expone de forma simple los cortes del árbol y permite compartir estas ideas de forma más práctica con clientes y otras áreas:
Tabla_Arbol(objeto)
Condicion.1 | Condicion.2 | Condicion.3 | Condicion.4 | Condicion.5 | Estimacion | Cantidad |
carat <= 0.99 | carat <= 0.62 | carat <= 0.45 | x <= 4.62 | clarity in "I1", "SI1", "VS1", "VVS1" | 634.260 | 5428 |
clarity in "SI2", "VS2", "VVS2", "IF" | 756.302 | 6452 | ||||
x > 4.62 | 960.287 | 5409 | ||||
carat > 0.45 | 1675.049 | 7498 | ||||
carat > 0.62 | y <= 5.83 | 2588.292 | 5038 | |||
y > 5.83 | 3527.311 | 5055 | ||||
carat > 0.99 | carat <= 1.49 | clarity in "I1", "SI1", "VS1", "VVS1" | 4923.082 | 7033 | ||
clarity in "SI2", "VS2", "VVS2", "IF" | 7617.413 | 5792 | ||||
carat > 1.49 | 12260.564 | 6235 |
Tabla_Arbol<-function(objeto) {
require(flextable)
require(partykit)
require(dbplyr)
require(stringr)
require(data.table)
## extraer reglas
partykit:::.list.rules.party(objeto)->reglas
## partirlas por "&"
str_split(reglas," & ")->reglas2
## Crear data.table con todas las particiones
lapply(reglas2,function(x) {
regla_completa<-""
for (i in 1:length(x)) {
regla_completa[i]<-paste(x[1:i],collapse ="|")
}
data.table(
Regla=x,
Nivel=1:length(x),
Identificacion=regla_completa,
Identificacion_Ant=c("Inicio",regla_completa[-length(x)])
)
})->reglas2
for (i in 1:length(reglas2)) {reglas2[[i]]$Nodo_Orig<-names(reglas)[i]}
reglas3<-rbindlist(reglas2)
## desarmar la particion entre variable, valor y operador
rbindlist(lapply(strsplit(reglas3$Regla, " > | <= | %in% "),function(x) {
data.table(Variable=x[1],Valor=x[2])
}))->composicion
composicion$Valor<-str_remove_all(composicion$Valor,"%in%")
composicion$Valor<-str_remove_all(composicion$Valor,"c\\(")
composicion$Valor<-str_remove_all(composicion$Valor,"\\)")
composicion$Valor<-unlist(lapply(str_split(composicion$Valor,","),function(x) {
paste(x[!x %in% c(" \"NA\"","\"NA\"")],collapse = ",")
}))
composicion$Operador<-ifelse(str_detect(reglas3$Regla,"<="),"<=",
ifelse(str_detect(reglas3$Regla,">"),">","in"))
cbind(reglas3,composicion)->reglas4
reglas4$Condicion<-paste(reglas4$Variable,reglas4$Operador,reglas4$Valor)
reglas4 %>% select(Nodo_Orig,Nivel,Condicion) %>%
reshape(timevar="Nivel",idvar="Nodo_Orig",direction="wide")->resumen
## calculo sus metricas
data.table(
Nodo_Orig=predict(objeto,type="node"),
Estimacion=predict(objeto,type="response"),
Cantidad=1
) %>% group_by(Nodo_Orig) %>% summarise(Estimacion=max(Estimacion),
Cantidad=sum(Cantidad))->ref_nodos
resumen$Nodo_Orig<-as.numeric(as.character(resumen$Nodo_Orig))
resumen %>% left_join(ref_nodos,"Nodo_Orig")->resumen2
resumen2<-resumen2 %>% arrange(Nodo_Orig)
resumen2$Cantidad<-as.integer(resumen2$Cantidad)
flextable(resumen2[,-1]) %>% theme_box()->hh
hh1<-merge_v(hh, j = c(names(resumen2)[-c(1,ncol(resumen2)-1,ncol(resumen2))]))
return(hh1)
}