1 - Introduccion
En el siguiente trabajo evaluaremos la posibilidad de implementar un modelo de árbol de decisión utilizando una de los métodos de la libraría caret llamado rpart (Partición recursiva para árboles de clasificación, regresión y supervivencia) para los datos suministrados por Kaggle sobre observaciones del Titanic y sus pasajeros.
El data set tiene variables que caracterizan a los pasajeros, estas features junto a la variable objetivo, la de supervivencia, nos sirven para encontrar los parámetros que mejor estimen nuestra variable dependiente.
El set de datos vino separado por Test y Training, es decir, un conjunto de datos para testear el modelo y otro conjunto de datos para entrenarlo. Los datos se pueden descargar de Kaggle.
Para la manipulación de los mismos primero se hace un merge entre ambos data set, una descripcion de cada una de las variables, su distribucion, conteno de valores nulos por variable y una limpieza del data set completo para luego separarlo en entrenamiento y testeo y probar nuestro modelo.
Librerias
library(tidyverse)
library(ggplot2)
library(dplyr)
library(crayon)
library(gridExtra)
library(prettydoc)
library(pacman)
library(plotly)
library(cvms)
library(tibble)
library(DT)
library(caret)
library(rpart.plot)
library(skimr)
library(highcharter)
library(quantmod)Lectura de los Datos
df = read_csv('train.csv')
test = read.csv('test.csv')
values = read.csv('gender_submission.csv')
test$Survived = values$SurvivedMerge
df = rbind(df, test)Dimensión del data set:
## El dataset tiene una dimension de 1309 filas y 12 columnas.
Chequeo de la estructura del archivo importado:
## spec_tbl_df [1,309 x 12] (S3: spec_tbl_df/tbl_df/tbl/data.frame)
## $ PassengerId: num [1:1309] 1 2 3 4 5 6 7 8 9 10 ...
## $ Survived : num [1:1309] 0 1 1 1 0 0 0 0 1 1 ...
## $ Pclass : num [1:1309] 3 1 3 1 3 3 1 3 3 2 ...
## $ Name : chr [1:1309] "Braund, Mr. Owen Harris" "Cumings, Mrs. John Bradley (Florence Briggs Thayer)" "Heikkinen, Miss. Laina" "Futrelle, Mrs. Jacques Heath (Lily May Peel)" ...
## $ Sex : chr [1:1309] "male" "female" "female" "female" ...
## $ Age : num [1:1309] 22 38 26 35 35 NA 54 2 27 14 ...
## $ SibSp : num [1:1309] 1 1 0 1 0 0 0 3 0 1 ...
## $ Parch : num [1:1309] 0 0 0 0 0 0 0 1 2 0 ...
## $ Ticket : chr [1:1309] "A/5 21171" "PC 17599" "STON/O2. 3101282" "113803" ...
## $ Fare : num [1:1309] 7.25 71.28 7.92 53.1 8.05 ...
## $ Cabin : chr [1:1309] NA "C85" NA "C123" ...
## $ Embarked : chr [1:1309] "S" "C" "S" "S" ...
## - attr(*, "spec")=
## .. cols(
## .. PassengerId = col_double(),
## .. Survived = col_double(),
## .. Pclass = col_double(),
## .. Name = col_character(),
## .. Sex = col_character(),
## .. Age = col_double(),
## .. SibSp = col_double(),
## .. Parch = col_double(),
## .. Ticket = col_character(),
## .. Fare = col_double(),
## .. Cabin = col_character(),
## .. Embarked = col_character()
## .. )
## - attr(*, "problems")=<externalptr>
| \[ \textbf{Variable}\] | \[ \textbf{Descripcion}\] | |
|---|---|---|
| Pclass | Clase del Pasaje (1 = 1st; 2 = 2nd; 3 = 3rd) | |
| survival | Sobrevivio (0 = No; 1 = Yes) | |
| name | Nombre | |
| sex | Sexo | |
| age | Edad | |
| sibsp | Numero de hermanas / conyuges a bordo | |
| parch | Numero de padres / niños a bordo | |
| ticket | Numero de Ticket | |
| fare | Tarifa de pasajero (libra esterlina) | |
| cabin | Cabina | |
| embarked | Puerto de embarque (C = Cherbourg; Q = Queenstown; S = Southampton |
2 - Analisis estadístico descriptivo de todo el dataset
## PassengerId Survived Pclass Name
## Min. : 1 Min. :0.0000 Min. :1.000 Length:1309
## 1st Qu.: 328 1st Qu.:0.0000 1st Qu.:2.000 Class :character
## Median : 655 Median :0.0000 Median :3.000 Mode :character
## Mean : 655 Mean :0.3774 Mean :2.295
## 3rd Qu.: 982 3rd Qu.:1.0000 3rd Qu.:3.000
## Max. :1309 Max. :1.0000 Max. :3.000
##
## Sex Age SibSp Parch
## Length:1309 Min. : 0.17 Min. :0.0000 Min. :0.000
## Class :character 1st Qu.:21.00 1st Qu.:0.0000 1st Qu.:0.000
## Mode :character Median :28.00 Median :0.0000 Median :0.000
## Mean :29.88 Mean :0.4989 Mean :0.385
## 3rd Qu.:39.00 3rd Qu.:1.0000 3rd Qu.:0.000
## Max. :80.00 Max. :8.0000 Max. :9.000
## NA's :263
## Ticket Fare Cabin Embarked
## Length:1309 Min. : 0.000 Length:1309 Length:1309
## Class :character 1st Qu.: 7.896 Class :character Class :character
## Mode :character Median : 14.454 Mode :character Mode :character
## Mean : 33.295
## 3rd Qu.: 31.275
## Max. :512.329
## NA's :1
3 - Explorando el tipo de variables
## la columna PassengerId es del tipo double
## la columna Survived es del tipo double
## la columna Pclass es del tipo double
## la columna Name es del tipo character
## la columna Sex es del tipo character
## la columna Age es del tipo double
## la columna SibSp es del tipo double
## la columna Parch es del tipo double
## la columna Ticket es del tipo character
## la columna Fare es del tipo double
## la columna Cabin es del tipo character
## la columna Embarked es del tipo character
##
## Hay 7 variables cuantitativas y 5 variables cualitativas en el set de datos.
##
## Las variables cuantitativas son las siguientes:
## PassengerId, Survived, Pclass, Age, SibSp, Parch, Fare, .
##
## Las variables cualitativas son las siguientes:
## Name, Sex, Ticket, Cabin, Embarked, .
4 - Análisis de datos nulos
Data Frame
## La variable PassengerId tiene 0 registros nulos, el 0 % del total de sus registros.
## La variable Survived tiene 0 registros nulos, el 0 % del total de sus registros.
## La variable Pclass tiene 0 registros nulos, el 0 % del total de sus registros.
## La variable Name tiene 0 registros nulos, el 0 % del total de sus registros.
## La variable Sex tiene 0 registros nulos, el 0 % del total de sus registros.
## La variable Age tiene 263 registros nulos, el 20.09 % del total de sus registros.
## La variable SibSp tiene 0 registros nulos, el 0 % del total de sus registros.
## La variable Parch tiene 0 registros nulos, el 0 % del total de sus registros.
## La variable Ticket tiene 0 registros nulos, el 0 % del total de sus registros.
## La variable Fare tiene 1 registros nulos, el 0.08 % del total de sus registros.
## La variable Cabin tiene 687 registros nulos, el 52.48 % del total de sus registros.
## La variable Embarked tiene 2 registros nulos, el 0.15 % del total de sus registros.
Campo Age Data Frame
El campo Age tiene 20 % de registros nulos. Para completar este campo, ya que estimamos que va a ser de relevancia para nuestro modelo, buscamos hacer una segmentación del data set en función de distintas características para obtener la media de la edad de los pasajeros y luego completar los datos faltantes.
En base a si sobrevivió o no, si es mujer o hombre y la clase del pasaje se estima la media de la edad y se imputa a los valores faltantes que cumplan con las características de la segmentación. Para esta metodología se obtuvieron 12 medias ya que es el máximo de selecciones que se puede obtener en base a las variables escogidas.
df$Age[is.na(filter(df, Survived == 0, Sex == 'male', Pclass == 1)$Age)] = as.numeric(df %>% filter(Survived == 0, Sex == 'male', Pclass == 1) %>% summarise(media = mean(Age, na.rm = TRUE)))
df$Age[is.na(filter(df, Survived == 1, Sex == 'male', Pclass == 1)$Age)] = as.numeric(df %>% filter(Survived == 1, Sex == 'male', Pclass == 1) %>% summarise(media = mean(Age, na.rm = TRUE)))
df$Age[is.na(filter(df, Survived == 0, Sex == 'female', Pclass == 1)$Age)] = as.numeric(df %>% filter(Survived == 0, Sex == 'female', Pclass == 1) %>% summarise(media = mean(Age, na.rm = TRUE)))
df$Age[is.na(filter(df, Survived == 1, Sex == 'female', Pclass == 1)$Age)] = as.numeric(df %>% filter(Survived == 1, Sex == 'female', Pclass == 1) %>% summarise(media = mean(Age, na.rm = TRUE)))
df$Age[is.na(filter(df, Survived == 0, Sex == 'male', Pclass == 2)$Age)] = as.numeric(df %>% filter(Survived == 0, Sex == 'male', Pclass == 2) %>% summarise(media = mean(Age, na.rm = TRUE)))
df$Age[is.na(filter(df, Survived == 1, Sex == 'male', Pclass == 2)$Age)] = as.numeric(df %>% filter(Survived == 1, Sex == 'male', Pclass == 2) %>% summarise(media = mean(Age, na.rm = TRUE)))
df$Age[is.na(filter(df, Survived == 0, Sex == 'female', Pclass == 2)$Age)] = as.numeric(df %>% filter(Survived == 0, Sex == 'female', Pclass == 2) %>% summarise(media = mean(Age, na.rm = TRUE)))
df$Age[is.na(filter(df, Survived == 1, Sex == 'female', Pclass == 2)$Age)] = as.numeric(df %>% filter(Survived == 1, Sex == 'female', Pclass == 2) %>% summarise(media = mean(Age, na.rm = TRUE)))
df$Age[is.na(filter(df, Survived == 0, Sex == 'male', Pclass == 3)$Age)] = as.numeric(df %>% filter(Survived == 0, Sex == 'male', Pclass == 3) %>% summarise(media = mean(Age, na.rm = TRUE)))
df$Age[is.na(filter(df, Survived == 1, Sex == 'male', Pclass == 3)$Age)] = as.numeric(df %>% filter(Survived == 1, Sex == 'male', Pclass == 3) %>% summarise(media = mean(Age, na.rm = TRUE)))
df$Age[is.na(filter(df, Survived == 0, Sex == 'female', Pclass == 3)$Age)] = as.numeric(df %>% filter(Survived == 0, Sex == 'female', Pclass == 3) %>% summarise(media = mean(Age, na.rm = TRUE)))
df$Age[is.na(filter(df, Survived == 1, Sex == 'female', Pclass == 3)$Age)] = as.numeric(df %>% filter(Survived == 1, Sex == 'female', Pclass == 3) %>% summarise(media = mean(Age, na.rm = TRUE)))
df$Age[is.na(df$Age)] = mean(df$Age, na.rm = TRUE)Eliminamos y editamos algunas columnas y filas para limpiar el data set.
df$Cabin = NULL
df$Ticket = NULL
df$Name = NULL
df$PassengerId = NULL
df$Embarked[is.na(df$Embarked)] = names(sort(table(df$Embarked))[length(sort(table(df$Embarked)))])
df$Fare[is.na(df$Fare)] = mean(df$Fare, na.rm = TRUE)Convertimos en factor nuestra variable objetivo
df$Survived = as.factor(df$Survived)
5 - Distribucion de las Variables
hchart(density(df$Age), type = 'area', name = 'Edad', color = '#32a6c9') %>%
hc_yAxis(title = list(text = "Distribucion")) %>%
hc_title(text = "Distribucion de la variable Age")hchart(density(df$Fare), type = 'area', name = 'Fare', color = '#32c9a6') %>%
hc_yAxis(title = list(text = "Distribucion")) %>%
hc_title(text = "Distribucion de la variable Fare")df %>% count(Survived) %>% arrange(n) %>%
hchart('column', color = 'green', hcaes(x = Survived, y = n))df %>% count(Embarked) %>% arrange(n) %>%
hchart('column', color = 'red', hcaes(x = Embarked, y = n))df %>% count(Pclass) %>% arrange(n) %>%
hchart('column', hcaes(x = Pclass, y = n))df %>% count(Sex) %>% arrange(n) %>%
hchart('column', hcaes(x = Sex, y = n, color = c("red")))
6 - Árbol de Decisión
Los árboles de decisión utilizan varios algoritmos para dividir un nodo en subnodos. La creación de estos subnodos aumenta la homogeneidad de los mismos, es decir, podemos decir que la pureza del nodo aumenta con respecto a la variable objetivo.
El árbol divide los nodos en todas las variables disponibles y luego selecciona la división que da como resultado la mayoría de los subnodos homogéneos, es decir, maximiza la elección de subnodos homogéneos.
Para comenzar con el modelo seteamos una semilla para que los resultados sean siempre los mismos al ejecutar el modelo, también separamos nuestro conjunto de datos en test y training con un 70 % para entrenar.
set.seed(107)
data_train = sort(sample(nrow(df), nrow(df)*.7))
train = df[data_train,]
test = df[-data_train,]Ejecutamos el modelo utilizando rpart, colocamos inicialmente la variable dependiente Survived y que tome el conjunto del data set entero de entrenamiento, este paso se hace colocando un punto luego del signo ~, esto nos permite tomar todo el data set completo sin tener en cuenta la variable objetivo.
El método que vamos a utilizar es el class ya que contamos con una variable dependiente que es una clase.
El parámetro de complejidad (cp) se utiliza para controlar el tamaño del árbol de decisión y para seleccionar el tamaño de árbol óptimo. Si el costo de agregar otra variable al árbol de decisión desde el nodo actual está por encima del valor de cp, entonces la construcción del árbol no continúa. En definitiva, el CP es un Parámetro que detiene las divisiones de los nodos en un máximo. Para nuestro modelo vamos a programarlo con un valor casi de cero para luego ir recortando los nodos y achicar el árbol e ir mejorando nuestro modelo.
model_arbol = rpart(Survived ~ . , data = train, method = 'class', cp = 0.001)El árbol se encuentra dividido en nodos, donde la primera divison, el nodo principal, es entre el sexo de los pasajeros, se divide entre hombre y mujer, donde la división por hombre se hace por un 65 % de los datos y por mujer con un 35 %, estos valores se encuentran en el gráfico.
Cada Nodo contiene:
No Survived = 0, Survived = 1
Probabilidad de Supervivencia
Porcentaje de observaciones por Nodo
Seteamos el tipo de Árbol que deseamos, el tamaño de los valores y descripciones que aparecen en el árbol y con box.palette creamos una paleta de colores donde el mas claro es mas cerca de no sobrevivir y el rojo mas oscuro mas cerca de sobrevivir.
Una vez observado el Árbol podemos sacar algunas conclusiones:
Si el Sexo es Hombre y la variable Fare es menor a 26, hay 90 % de probabilidades de no sobrevivir.
Si el Sexo es Mujer y la variable Pclass es menor a 3, hay 97 % de probabilidades de sobrevivir.
rpart.plot(model_arbol, type = 4, fallen.leaves = F, cex = 0.62, box.palette=c("#F6C9C9", "#D47676", "#D52727"), branch.lty = 3, shadow.col = "gray")Importancia de las variables
Podemos observar que para el modelo la variable mas relevante es el sexo.
barplot(model_arbol$variable.importance,
main = "Importancia de las variables",
xlab = "Variables",
border = "black",
col = c("#EC8686"))El modelo nos da valores entre 0 y 1, el algoritmo toma como 0 si el output es menor a 0,5 y como 1 si es mayor a 0,5 para poder clasificar la variable.
predict = predict(model_arbol, test, type = 'class')Matriz de Confusión
cm = confusionMatrix(as.factor(predict), as.factor(test$Survived), positive = NULL, dnn = c("Prediction", "Reference"))
cm## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 220 39
## 1 25 109
##
## Accuracy : 0.8372
## 95% CI : (0.7969, 0.8723)
## No Information Rate : 0.6234
## P-Value [Acc > NIR] : <2e-16
##
## Kappa : 0.6466
##
## Mcnemar's Test P-Value : 0.1042
##
## Sensitivity : 0.8980
## Specificity : 0.7365
## Pos Pred Value : 0.8494
## Neg Pred Value : 0.8134
## Prevalence : 0.6234
## Detection Rate : 0.5598
## Detection Prevalence : 0.6590
## Balanced Accuracy : 0.8172
##
## 'Positive' Class : 0
##
cfm = as_tibble(table(tibble("target" = test$Survived,
"prediction" = predict)))
plot_confusion_matrix(cfm,
target_col = "target",
prediction_col = "prediction",
counts_col = "n",
palette = 'Reds') Accuracy:
El acurracy es el porcentaje total de elementos clasificados correctamente, es decir, la suma de los verdaderos positivos y los verdaderos negativos dividido el total.
## El Acurracy del modelo es de: 0.84
7 - Ajuste del Modelo
En base a lo mencionado anteriormente sobre el parámetro de complejidad, para este caso, queremos evaluar un modelo mas simple, por lo tanto vamos a tener que podar el árbol.
Para poder continuar tenemos que encontrar el parametro de complejidad que haga minimo el error de validación cruzada.
cp_2 = model_arbol$cptable[which.min(model_arbol$cptable[,"xerror"]),"CP"]
model_arbol_2 = rpart(Survived ~ . , data = train, method = 'class', cp = cp_2)
predict_2 = predict(model_arbol_2, test, type = 'class')rpart.plot(model_arbol_2, type = 4, fallen.leaves = F, cex = 0.62, box.palette=c("#F6C9C9", "#D47676", "#D52727"), branch.lty = 3, shadow.col = "gray")cm_2 = confusionMatrix(as.factor(predict_2), as.factor(test$Survived), positive = NULL, dnn = c("Prediction", "Reference"))
cm_2## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 220 31
## 1 25 117
##
## Accuracy : 0.8575
## 95% CI : (0.819, 0.8905)
## No Information Rate : 0.6234
## P-Value [Acc > NIR] : <2e-16
##
## Kappa : 0.6941
##
## Mcnemar's Test P-Value : 0.504
##
## Sensitivity : 0.8980
## Specificity : 0.7905
## Pos Pred Value : 0.8765
## Neg Pred Value : 0.8239
## Prevalence : 0.6234
## Detection Rate : 0.5598
## Detection Prevalence : 0.6387
## Balanced Accuracy : 0.8442
##
## 'Positive' Class : 0
##
cfm_2 = as_tibble(table(tibble("target" = test$Survived,
"prediction" = predict_2)))
plot_confusion_matrix(cfm_2,
target_col = "target",
prediction_col = "prediction",
counts_col = "n",
palette = 'Reds')Accuracy:
## El Acurracy del modelo es de: 0.86
## El Acurracy del modelo mejoro en 2 puntos porcentuales.
8 - Conclusiones
En primera instancia utilizamos los datos disponibles para estimar valores faltantes que nos parecieron importantes para el modelo a realizar, se imputaron las edades en base a segmentaciones por distintas variables que caracterizaban a los pasajeros, se eliminaron algunas columnas que no eran relevantes y luego se entreno el modelo con el 70 % de los datos.
Como primer paso, para el modelo de Árbol de Decisión, se utilizo un valor muy bajo de cp para poder observar una mayor cantidad de nodos y recalibar el mismo en base a el parámetro de complejidad que hace mínimo el error de validación cruzada.
Este método mejoro nuestras estimaciones ya que incremento el Acurracy del modelo, dándonos mejores resultados.