Se comienza el análisis propuesto cargando los paquetes necesarios, leyendo el conjunto de datos de entrenamiento y mostrando la estructura del mismo a fin de conocer los atributos sobre los que se van a trabajar.
# Borrar variables del ambiente
rm(list = objects())
# Carga de paquetes necesarios para hacer los gráficos
require(Cairo)
require(caret)
require(cowplot)
require(GGally)
require(ggpubr)
require(ggrepel)
require(highcharter)
require(Hmisc)
require(knitr)
require(magrittr)
require(modelr)
require(pROC)
require(tidyverse)
# Uso de Cairo para renderizar los gráficos.
options(bitmapType = "cairo")
# Cargar conjunto de datos de sobrevivientes del Titanic
datos_train <- readr::read_csv(file = "titanic_complete_train.csv")
# Mostrar estructura
head(datos_train)
## # A tibble: 6 x 12
## PassengerId Survived Pclass Name Sex Age SibSp Parch Ticket Fare
## <dbl> <dbl> <dbl> <chr> <chr> <dbl> <dbl> <dbl> <chr> <dbl>
## 1 1 0 3 Brau… male 22 1 0 A/5 2… 7.25
## 2 2 1 1 Cumi… fema… 38 1 0 PC 17… 71.3
## 3 3 1 3 Heik… fema… 26 0 0 STON/… 7.92
## 4 4 1 1 Futr… fema… 35 1 0 113803 53.1
## 5 5 0 3 Alle… male 35 0 0 373450 8.05
## 6 6 0 3 Mora… male 26.5 0 0 330877 8.46
## # … with 2 more variables: Cabin <chr>, Embarked <chr>
Luego se selecciona un subconjunto de features, algunos de los cuales son categóricos y son transformados a factor de R:
datos_train %<>%
# Seleccionar PassengerId, Survived, Pclass, Sex, Age, SibSp, Parch, Fare y Embarked
dplyr::select(PassengerId, Survived, Pclass, Sex, Age, SibSp, Parch, Fare, Embarked) %>%
# Transformar Survived, Pclass, Sex y Embarked a factor
dplyr::mutate(Survived = as.factor(Survived),
Pclass = as.factor(Pclass),
Sex = as.factor(Sex),
Embarked = as.factor(Embarked))
A continuación se efectúa un gráfico con ggpairs para observar las asociaciones entre variables incluyendo la clase. Lo primero que se destaca es la gran cantidad de fallecidos correspondientes a pasajeros de sexo masculino y que compraron pasajes de tercera clase. Esto indicaría que son variables con alto poder predictivo. También se observa que los pasajeros fallecidos tienen una mediana de edad un poco mayor que los sobrevivientes. Por último, también se observa que hay una cantidad levemente mayor de sobrevivientes entre los pasajeros que pagaron tarifas más altas.
GGally::ggpairs(data = datos_train,
mapping = ggplot2::aes(col = Survived),
columns = which(colnames(datos_train) %in%
c("Pclass", "Sex", "Age", "Fare"))) +
ggplot2::theme_bw() +
ggplot2::theme(
legend.position = 'bottom',
axis.text.x = element_text(angle = 90, hjust = 1)
)
Para finalizar la etapa de preparación de los datos, tomamos el conjunto de datos de entrenamiento y separamos 70% de las observaciones para entrenamiento y el 30% restante con fines de validación. El gráfico interactivo de barras que se muestra a continuación fue confeccionado para corroborar que se mantenga la proporción de clases a predecir en los conjuntos de datos resultantes. Pueden visualizarse los porcentajes de cada clase deslizando el puntero del mouse sobre cada una de las barras.
# Dividir el conjunto de entrenamiento en 70% para entrenamiento y 30% para validacion.
# Se define una semilla para que el ejemplo sea reproducible.
set.seed(0)
entrenamiento_validacion <- datos_train %>%
modelr::resample_partition(c(train = 0.7, test = 0.3))
datos_entrenamiento <- entrenamiento_validacion$train %>%
as_tibble()
datos_validacion <- entrenamiento_validacion$test %>%
as_tibble()
# Analizar distribución de clase para verificar que se preserven las proporciones
datos.grafico.proporcion.clase <- dplyr::bind_rows(
datos_train %>% dplyr::mutate(Conjunto = "Original de entrenamiento"),
datos_entrenamiento %>% dplyr::mutate(Conjunto = "Entrenamiento (70%)"),
datos_validacion %>% dplyr::mutate(Conjunto = "Validación (30%)")) %>%
dplyr::mutate(Survived = dplyr::case_when(Survived == 0 ~ "No sobrevivió",
TRUE ~ "Sobrevivió"),
Conjunto = factor(Conjunto,
levels = c("Original de entrenamiento", "Entrenamiento (70%)", "Validación (30%)"))) %>%
dplyr::group_by(Conjunto) %>%
dplyr::mutate(Total = dplyr::n()) %>%
dplyr::group_by(Conjunto, Survived, Total) %>%
dplyr::summarise(Cantidad = dplyr::n()) %>%
dplyr::mutate(Porcentaje = 100 * Cantidad / Total)
suppressWarnings(highcharter::highchart() %>%
highcharter::hc_add_series(data = datos.grafico.proporcion.clase, type = "column",
mapping = highcharter::hcaes(x = Conjunto, y = Porcentaje, group = Survived)) %>%
highcharter::hc_colors(c("#e31a1c", "#1f78b4")) %>%
highcharter::hc_xAxis(title = list(text = "Conjunto de datos"), type = "category",
categories = unique(datos.grafico.proporcion.clase$Conjunto)) %>%
highcharter::hc_yAxis(title = list(text = "Porcentaje de datos en la clase"),
min = 0, max = 100) %>%
highcharter::hc_tooltip(valueDecimals = 2, valueSuffix = ' %') %>%
highcharter::hc_title(text = "Distribución de clases por conjunto de datos"))
Se realiza un primer modelo de regresión logística de acuerdo a los atributos propuestos: Pclass, Sex y Age. A continuación se presentan los coeficientes de cada una de las covariables. Para el caso de las variables categóricas Pclass y Sex se consideraron como categorías basales las correspondientes a primera clase y sexo femenino respectivamente.
# Se realiza el primer modelo con las variables Pclass, Sex y Age
modelo.1 <- glm(formula = Survived ~ Pclass + Sex + Age,
family = 'binomial', data = datos_entrenamiento)
knitr::kable(
broom::tidy(modelo.1)
)
| term | estimate | std.error | statistic | p.value |
|---|---|---|---|---|
| (Intercept) | 3.7362104 | 0.4665808 | 8.007639 | 0.0000000 |
| Pclass2 | -1.1575912 | 0.3168201 | -3.653781 | 0.0002584 |
| Pclass3 | -2.6680361 | 0.3159941 | -8.443309 | 0.0000000 |
| Sexmale | -2.5467037 | 0.2254785 | -11.294665 | 0.0000000 |
| Age | -0.0376345 | 0.0092944 | -4.049163 | 0.0000514 |
De la tabla anterior observamos que todos los coeficientes de las covariables propuestas son negativos y altamente significativos. Que un coeficiente sea negativo, indica que un incremento en el valor de la covariable disminuye la probabilidad de supervivencia esperada del pasajero ceteris paribus. Para el caso de la edad, esto es consistente con el análisis previo, donde se había observado que la edad de los fallecidos tenía una mediana mayor que la de los sobrevivientes.
Para el caso de las variables categórias, se observa que ser de sexo masculino disminuye la probabilidad esperada de supervivencia respecto del sexo femenino. Lo mismo sucede cuando empeora la categoría en la que viaja del pasajero. Es decir, los de segunda categoría tienen menos probabilidad esperada de superviviencia que los de primera y lo mismo ocurre con los de tercera respecto de los de segunda. Hasta aquí todos los hallazgos han sido consistentes con lo observado en el gráfico de ggpairs.
Finalmente, y a modo de ejemplo, se muestra la probabilidad esperada de supervivencia para dos pasajeros con las siguientes características:
# Ahora se reponde a la pregunta: Quien tiene mas chances de sobrevivir?
data.rose <- data.frame(Nombre = "Rose",
Pclass = factor(1, levels = levels(datos_entrenamiento$Pclass)),
Sex = factor("female", levels = levels(datos_entrenamiento$Sex)),
Age = 17)
prob.rose <- predict(object = modelo.1, newdata = data.rose, type = "response")
data.jack <- data.frame(Nombre = "Jack",
Pclass = factor(3, levels = levels(datos_entrenamiento$Pclass)),
Sex = factor("male", levels = levels(datos_entrenamiento$Sex)),
Age = 20)
prob.jack <- predict(object = modelo.1, newdata = data.jack, type = "response")
knitr::kable(
rbind(dplyr::mutate(data.rose, Prob = prob.rose), dplyr::mutate(data.jack, Prob = prob.jack)),
digits = 3, col.names = c("Pasajero", "Clase", "Sexo", "Edad", "Prob. esperada de sobrevivir")
)
| Pasajero | Clase | Sexo | Edad | Prob. esperada de sobrevivir |
|---|---|---|---|---|
| Rose | 1 | female | 17 | 0.957 |
| Jack | 3 | male | 20 | 0.097 |
Se observa que la probabilidad de supervivencia esperada de Rose es prácticamente 10 veces mayor que la de Jack, lo cual es razonable debido principalmente a las diferencias de sexo y clase en la que viajan ambos.
Se generan otros 3 nuevos modelos, los cuales incluyen las siguientes covariables:
# Se defines las formulas para los 3 modelos a analizar
formulas.modelos <- modelr::formulas(.response = ~Survived,
sexo_tarifa_edad = ~Sex + Fare + Age,
tarifa_lugar = ~Fare + Embarked,
sexo_clase_lugar_edad = ~ Sex + Pclass + Embarked + Age)
# Generar modelos en base a las formulas
modelos <- dplyr::tibble(formulas.modelos) %>%
# Generar features con el modelo y la expresion del mismo
dplyr::mutate(Expresion = paste(formulas.modelos),
Nombre = names(formulas.modelos),
Modelo = purrr::map(formulas.modelos, ~glm(.,family = 'binomial', data = datos_entrenamiento))) %>%
# Seleccionar Expresion y Modelo
dplyr::select(Expresion, Nombre, Modelo) %>%
# Agregar modelo inicial
dplyr::bind_rows(tibble::tibble(Expresion = "Survived ~ Pclass + Sex + Age", Modelo = list(modelo.1)))
# Mostrar los modelos ordenados por el deviance de cada uno de ellos
modelos %>%
dplyr::mutate(glance = purrr::map(Modelo, broom::glance)) %>%
tidyr::unnest(glance) %>%
dplyr::select(Expresion, deviance, null.deviance, logLik, AIC, BIC) %>%
dplyr::arrange(deviance) %>%
knitr::kable(col.names = c("Modelo", "Deviance", "Null deviance", "Log-likelihood", "AIC", "BIC"),
digits = 2)
| Modelo | Deviance | Null deviance | Log-likelihood | AIC | BIC |
|---|---|---|---|---|---|
| Survived ~ Sex + Pclass + Embarked + Age | 549.65 | 823.80 | -274.82 | 563.65 | 594.66 |
| Survived ~ Pclass + Sex + Age | 554.12 | 827.68 | -277.06 | 564.12 | 586.30 |
| Survived ~ Sex + Fare + Age | 618.17 | 827.68 | -309.08 | 626.17 | 643.91 |
| Survived ~ Fare + Embarked | 765.88 | 823.80 | -382.94 | 773.88 | 791.61 |
En la tabla previa se muestran datos de bondad de ajuste para los 3 modelos anteriores y el modelo inicial. En particular, la tabla está ordenada por deviance en forma creciente. La deviance es la generalización de la idea de la suma de los cuadrados de los residuos que se aplica en mínimos cuadrados ordinarios para los casos donde el ajuste se realiza por el método de máxima verosimilitud. Cuanto mayor sea la diferencia entre la null deviance y la deviance del modelo, mayor será el porcentaje de variabilidad explicada por éste. En este caso, el modelo que capta el mayor porcentaje de variabilidad es el correspondiente a las covariables sexo, clase de ticket, lugar de embarcación y edad.
En esta sección se realizará una primera evaluación de la performance del mejor modelo (basado en sexo, clase de ticket, lugar de embarcación y edad) sobre el conjunto de entrenamiento. El modelo predice la probabilidad esperada de supervivencia. Sin embargo, a fin de poder evaluar el poder de predicción, es necesario transformar esas probabilidades en eventos esperados de supervivencia. Es razonable pensar que una alta probabilidad está asociada a un evento de supervivencia, mientras que una baja probabilidad está asociada a un evento de fallecimiento.
Una forma de medir la concordancia entre los eventos esperados y los predichos es a través de la curva ROC. Esta curva se construye generando múltiples escenarios de eventos predichos a partir de distintos puntos de corte de la probabilidad de supervivencia. Probabilidades mayores al punto de corte se considerarán eventos de supervivencia y probabilidades menores, eventos de deceso. Luego, para cada escenario definido por un punto de corte, se calcularán dos indicadores de concordancia: la tasa de verdaderos positivos (cuántos eventos de supervivencia fueron realmente clasificados como tales) y la tasa de falsos positivos (cuántos de los eventos de deceso fueron erróneamente clasificados como de supervivencia).
Si la predicción fuera hecha completamente al azar, sería de esperar que la tasa de verderos positivos (TPR) sea igual a la de falsos positivos (FPR). Esta asignación nos produciría una curva ROC igual a la función identidad. En un modelo con mayor capacidad predictiva, es de esperar que para una cierta tasa de falsos positivos, la tasa de verdaderos positivos sea mayor. Es por eso que un buen modelo presenta una curva ROC con valores mayores en las ordenadas que la ROC para un modelo aleatorio. La forma de cuantificar cuán mayor es la ROC de un modelo en comparación con la ROC de un modelo aleatorio es mediante al área bajo la curva ROC o ROC AUC. En el caso de un modelo alreatorio cuya curva ROC es la función identidad, la ROC AUC vale 0.5. Para el modelo perfecto, la ROC AUC vale 1.
# Generamos la curva ROC y calculamos ROC AUC para el modelo elegido (sexo_clase_lugar_edad)
prediccion.mejor.modelo <- modelos %>%
dplyr::filter(Nombre == "sexo_clase_lugar_edad") %>%
dplyr::mutate(pred = purrr::map(Modelo, broom::augment, type.predict = "response")) %>%
tidyr::unnest(pred)
# Calcular ROC
roc.mejor.modelo <- pROC::roc(response = prediccion.mejor.modelo$Survived, predictor = prediccion.mejor.modelo$.fitted)
En el siguiente gráfico se observa a la curva ROC del modelo completamente aleatorio con líneas punteadas. La línea sólida que está por encima es la curva ROC del mejor modelo de los probados anteriormente. Se observa que para este modelo, la ROC AUC es de 0.858, lo cual a priori indicaría que es un modelo bastante bueno considerando la simpleza del mismo.
# Graficar ROC
pROC::ggroc(roc.mejor.modelo, size = 1, legacy.axes = TRUE) +
ggplot2::geom_abline(slope = 1, intercept = 0, linetype = 'dashed') +
ggplot2::geom_label(data = data.frame(x = 0.9, y = 0.1),
mapping = ggplot2::aes(x = x, y = y, label = sprintf("AUC: %.3f", roc.mejor.modelo$auc))) +
ggplot2::labs(title = 'Curva ROC', subtitle = 'Modelo basado en sexo, clase, lugar de embarcación y edad',
y = 'Tasa de verdaderos positivos (TPR)', x = "Tasa de falsos positivos (FPR)") +
ggplot2::theme_bw() +
ggplot2::theme(
plot.title = ggplot2::element_text(hjust = 0.5),
plot.subtitle = ggplot2::element_text(hjust = 0.5)
)
En este próximo gráfico observamos dos superficies correspondientes a los eventos reales de supervivencia y fallecimiento. En el eje de las ordenadas se muestran las posibles probabilidades de predicción que serán candidatas a tomarse como puntos de corte. Para cada superficie, se observa la densidad de casos del evento correspondiente en relación a la probabilidad de supervivencia esperada. Se observa que los casos de supervivencia tienen mayor densidad de casos en probabilidades altas, mientras que los eventos de deceso tiene mucha mayor densidad en las probabilidades bajas.
Esto es consistente con el planteo inicial que se hizo al comienzo de la sección. Dado que los casos de decesos tienen mucha mayor densidad en probabilidades bajas que los de supervivencia en probabilidades altas, sería de esperar - a priori - que el punto de corte para separar eventos de deceso y supervivencia esté en valores más cercanos a 0 que a 1. A continuación se realizará un análisis más exhaustivo del punto de corte.
# Violin plot
ggplot2::ggplot(data = prediccion.mejor.modelo,
mapping = ggplot2::aes(x = Survived, y = .fitted, group = Survived, fill = Survived)) +
ggplot2::geom_violin() +
ggplot2::scale_fill_manual(values = c("0" = "tomato", "1" = "darkslategray4"),
labels = c("0" = "No sobrevivió", "1" = "Sobrevivió"), name = "") +
ggplot2::labs(title = 'Gráfico de violín', subtitle = 'Modelo basado en sexo, clase, lugar de embarcación y edad',
y = 'Probabilidad predicha', x = "") +
ggplot2::theme_bw() +
ggplot2::theme(
legend.position = 'bottom',
plot.title = ggplot2::element_text(hjust = 0.5),
plot.subtitle = ggplot2::element_text(hjust = 0.5),
axis.ticks.x = ggplot2::element_blank(),
axis.text.x = ggplot2::element_blank()
)
A continuación se utilizarán las 4 métricas propuestas para hacer una selección adecuada del punto de corte. En adelante, se denominará predicción positiva a la predicción asociada a un evento de supervivencia. Por el contrario, una predicción de deceso se denominará predicción negativa. Luego, los vedaderos positivos (TP) y verdaderos negativos (TN) serán aquellos eventos de supervivencia y deceso clasificados como tales, respectivamente. Finalmente, los falsos positivos (FP) y falsos negativos (FN) serán aquellos eventos de supervivencia y deceso clasificados erróneamente como el evento contrario, respectivamente.
Para cada punto de corte evaluado entre 0 y 1, se calcularán los valores de cada una de las métricas indicadas. El valor ideal para todas estas métricas es 1, siendo 0 el peor indicador. Se deberá buscar un punto de corte haciendo un trade-off entre todos los valores de las métricas.
# Prediccion utilizando conjunto de validacion
prediccion.validacion <- modelos %>%
dplyr::filter(Nombre == "sexo_clase_lugar_edad") %>%
dplyr::mutate(pred = purrr::map(Modelo, predict, newdata = dplyr::select(datos_validacion, -Survived),
type = "response")) %>%
tidyr::unnest(pred) %>%
dplyr::pull(pred)
clases.validacion <- datos_validacion %>%
dplyr::mutate(Clase = as.integer(as.character(Survived))) %>%
dplyr::pull(Clase)
# Definir metricas de prediccion en funcion de un umbral de corte
MetricasPrediccion <- function(probabilidades, clase, umbral.corte) {
predicciones <- ifelse(probabilidades > umbral.corte, 1, 0)
matriz.confusion <- caret::confusionMatrix(
table(as.character(predicciones), as.character(clase)),
positive = "1")
broom::tidy(matriz.confusion) %>%
dplyr::select(term, estimate) %>%
dplyr::filter(term %in% c('accuracy', 'recall', 'specificity', 'precision')) %>%
dplyr::mutate(umbral = umbral.corte)
}
# Calcular metricas en funcion del umbral
puntos.corte <- list(p1 = 0, p2 = 0)
distancias <- list(p1 = 1, p2 = 1)
metricas <- purrr::map_dfr(
.x = seq(from = 0.03, to = 0.96, by = 0.001),
.f = function(umbral.corte) {
metricas.umbral <- MetricasPrediccion(prediccion.validacion, clases.validacion,
umbral.corte = umbral.corte)
recall.umbral <- dplyr::filter(metricas.umbral, term == "recall") %>%
dplyr::pull(estimate)
precision.umbral <- dplyr::filter(metricas.umbral, term == "precision") %>%
dplyr::pull(estimate)
accuracy.umbral <- dplyr::filter(metricas.umbral, term == "accuracy") %>%
dplyr::pull(estimate)
distancia.prec.rec <- abs(recall.umbral - precision.umbral)
distancia.acc.rec <- abs(recall.umbral - accuracy.umbral)
if (distancia.prec.rec < distancias$p1) {
puntos.corte$p1 <<- umbral.corte
distancias$p1 <<- distancia.prec.rec
}
if (distancia.acc.rec < distancias$p2) {
puntos.corte$p2 <<- umbral.corte
distancias$p2 <<- distancia.acc.rec
}
return (metricas.umbral)
}
)
# Graficamos metricas en funcion del umbral
ggplot2::ggplot(data = metricas) +
ggplot2::geom_line(mapping = ggplot2::aes(x = umbral, y = estimate, group = term, color = term), size = 0.7) +
ggplot2::geom_vline(xintercept = puntos.corte$p1, linetype = "dotted") +
ggplot2::geom_vline(xintercept = puntos.corte$p2, linetype = "dotted") +
ggrepel::geom_label_repel(data = data.frame(x = puntos.corte$p1, y = 0),
mapping = ggplot2::aes(x = x, y = y, label = sprintf("P1 = %.3f", x)),
colour = 'black', nudge_x = 0.1, nudge_y = 0.1) +
ggrepel::geom_label_repel(data = data.frame(x = puntos.corte$p2, y = 0),
mapping = ggplot2::aes(x = x, y = y, label = sprintf("P2 = %.3f", x)),
colour = 'black', nudge_x = -0.1, nudge_y = 0.1) +
ggplot2::labs(title = 'Accuracy, Sensitivity (o Recall), Specificity y Precision',
subtitle = 'Modelo basado en sexo, clase, lugar de embarcación y edad',
col = "Métrica", x = "Punto de corte", y = "Valor de la métrica") +
ggplot2::theme_bw() +
ggplot2::theme(
legend.position = 'bottom',
plot.title = ggplot2::element_text(hjust = 0.5),
plot.subtitle = ggplot2::element_text(hjust = 0.5)
)
Del gráfico anterior se observa que conforme aumenta el punto de corte, también aumentan las métricas specificity y precision por efecto de la fuerte disminución de los falsos positivos. Del mismo modo, al aumentar el punto de corte, también aumentan fuertemente los falsos negativos, haciendo que la métrica recall disminuya. Finalmente, se observa que el accuracy no tiene un comportamiento monótono, lo cual es interesante dado que es posible encontrar un máximo. Esta métrica, al considerar tanto los falsos positivos como los falsos negativos, depende de cuánto peso tengan ambas cantidades las cuales tienen comportamientos opuestos al cambiar el punto de corte.
Como primera aproximación, se podría argumentar que un buen punto de corte es P1, dado que el valor de accuracy es máximo y las demás métricas tienen todas un valor relativamente alto y ninguna es demasiado buena a expensas de que otra sea demasiado mala. Considerar valores mayores a P1 (hasta el punto en que el accuracy comience a disminuir) implicaría aumentar tres de las las métricas a expensas de empeorar la métrica recall. Sin embargo, este enfoque está basado solamente en el valor de las métricas sin considerar las mismas en función de su significado.
Sería bastante razonable suponer que un modelo como este se podría utilizar para predecir el número de supervivientes y en función a dicha cantidad, asignar recursos para una operación de rescate. Sabiendo que los recursos son siempre finitos y en general escasos, se debería hacer una análisis que contemple tal situación. Por tal motivo, no tiene la misma importancia dismuir la cantidad de falsos positivos que la cantidad de falsos negativos. Un falso positivo representa una persona clasificada como superviviente cuando en realidad está fallecida y un falso negativo es otra catalogada como fallecida cuando en realidad ha sobrevivido. Sobre la base del razonamiento inicial, si una persona es clasificada como fallecida, no se asignan recursos para ir al rescate de ésta. Por tal motivo el foco debe ponerse en disminuir los falsos negativos o FN.
Observando la definición de las métricas, se observa que recall es la que mejor penaliza los falsos negativos. Por tal motivo, si solamente nos focalizamos en estas 4 métricas, debemos priorizar un poco más a recall que a las demás (existen otras métricas como F1 o F\(_{beta}\) que consideran precision y recall simultáneamente pero están fuera del foco de este trabajo). Se asume como hipótesis adicional que resultaría imposible considerar el punto de corte 0 a fin de maximizar recall dado que implicaría asignar recursos para el rescate de todos los pasajeros.
Con la consideración anterior y teniendo en cuenta la premisa de no mejorar una métrica a expensas de empeorar mucho todas las demás, elegimos el punto P2 que es muy cercano a la intersección de accuracy, recall y specificity (dado que la selección del umbral está discretizada). En este punto el valor accuracy está bastante cerca del máximo y las otras métricas son levemente menores que para el punto P1. Desde luego, podrían existir otros criterios igualmente válidos para seleccionar algún otro punto de corte.
# MdC para primer punto de corte
predicciones.1 <- ifelse(prediccion.validacion > puntos.corte$p1, 1, 0)
matriz.confusion.1 <- caret::confusionMatrix(
table(as.character(predicciones.1), as.character(clases.validacion), dnn = c("predicciones", "observaciones")),
positive = "1")$table
rownames(matriz.confusion.1) <- c("Fallecidos predichos", "Sobrevivientes predichos")
colnames(matriz.confusion.1) <- c("Fallecidos reales", "Sobrevivientes reales")
# MdC para segundo punto de corte
predicciones.2 <- ifelse(prediccion.validacion > puntos.corte$p2, 1, 0)
matriz.confusion.2 <- caret::confusionMatrix(
table(as.character(predicciones.2), as.character(clases.validacion), dnn = c("predicciones", "observaciones")),
positive = "1")$table
rownames(matriz.confusion.2) <- c("Fallecidos predichos", "Sobrevivientes predichos")
colnames(matriz.confusion.2) <- c("Fallecidos reales", "Sobrevivientes reales")
knitr::kable(matriz.confusion.1)
| Fallecidos reales | Sobrevivientes reales | |
|---|---|---|
| Fallecidos predichos | 137 | 26 |
| Sobrevivientes predichos | 26 | 79 |
knitr::kable(matriz.confusion.2)
| Fallecidos reales | Sobrevivientes reales | |
|---|---|---|
| Fallecidos predichos | 128 | 23 |
| Sobrevivientes predichos | 35 | 82 |
Las tablas previas muestran las matrices de confusión resultantes para cada punto de corte. La suma de los valores de las columnas representan la cantidad real de sobrevivientes y fallecidos. A su vez, la suma de los valores de las filas representan la cantidad predicha de sobrevivientes y fallecidos. Finalmente, cada celda indica la cantidad de personas con su clasificación real y su clasificación de acuerdo al modelo. Lo primero que es importante destacar es que hay muchos más aciertos que errores para ambos puntos de corte. Además, se observa que para el punto de corte P1 hay 26 casos de FN, mientras que para P2 hay 23. Esas son 3 personas que se podrían rescatar a expensas de aumentar en 9 la cantidad de falsos positivos. Más abajo se muestran los valores de las 4 métricas para cada punto de corte.
comparacion.metricas.validacion <- MetricasPrediccion(prediccion.validacion, clases.validacion, puntos.corte$p1) %>%
dplyr::bind_rows(MetricasPrediccion(prediccion.validacion, clases.validacion, puntos.corte$p2)) %>%
tidyr::pivot_wider(names_from = "term", values_from = "estimate") %>%
dplyr::rename("Punto de corte" = umbral, Accuracy = accuracy, Recall = recall, Precision = precision, Specificity = specificity)
knitr::kable(comparacion.metricas.validacion, digits = 3)
| Punto de corte | Accuracy | Specificity | Precision | Recall |
|---|---|---|---|---|
| 0.396 | 0.806 | 0.840 | 0.752 | 0.752 |
| 0.345 | 0.784 | 0.785 | 0.701 | 0.781 |
Por último, se aplica el modelo al conjunto de test. Se leen los datos, se aplican las mismas transformaciones indicadas en la primera sección del documento, y se realiza la predicción. A continuación se muestra la matriz de confusión para el punto de corte P2 y los valores de las 4 métricas para ambos puntos de corte a fin de poder compararlos.
# Cargar conjunto de datos de test y aplicar las mismas transformaciones que para train
datos_test <- readr::read_csv(file = "titanic_complete_test.csv") %>%
# Seleccionar las variables PassengerId, Survived, Pclass, Sex, Age, SibSp,Parch, Fare y Embarked
dplyr::select(PassengerId, Survived, Pclass, Sex, Age, SibSp,Parch, Fare, Embarked) %>%
# Transformar las variables Survived, Pclass, Sex y Embarked a factor
dplyr::mutate(Survived = as.factor(Survived),
Pclass = as.factor(Pclass),
Sex = as.factor(Sex),
Embarked = as.factor(Embarked))
# Aplicar el modelo a conjunto de test
prediccion.test <- modelos %>%
dplyr::filter(Nombre == "sexo_clase_lugar_edad") %>%
dplyr::mutate(pred = purrr::map(Modelo, predict, newdata = dplyr::select(datos_test, -Survived),
type = "response")) %>%
tidyr::unnest(pred) %>%
dplyr::pull(pred)
clases.test <- datos_test %>%
dplyr::mutate(Clase = as.integer(as.character(Survived))) %>%
dplyr::pull(Clase)
# Realizar predicciones en base a mejor al punto de corte P2
predicciones.3 <- ifelse(prediccion.test > puntos.corte$p1, 1, 0)
matriz.confusion.3 <- caret::confusionMatrix(
table(as.character(predicciones.3), as.character(clases.test), dnn = c("predicciones", "observaciones")),
positive = "1")$table
rownames(matriz.confusion.3) <- c("Fallecidos predichos", "Sobrevivientes predichos")
colnames(matriz.confusion.3) <- c("Fallecidos reales", "Sobrevivientes reales")
knitr::kable(matriz.confusion.3)
| Fallecidos reales | Sobrevivientes reales | |
|---|---|---|
| Fallecidos predichos | 199 | 40 |
| Sobrevivientes predichos | 62 | 117 |
Al igual que con el conjunto de validación, se observa que la cantidad de aciertos (TP y TN) es bastante superior a la cantidad de errores. Las métricas son un poco peores que para el conjunto de validación a excepción de recall, que se mantiene bastante parecida (incluso para ambos puntos de corte). También se observa que el valor de recall es mayor para P2 que para P1. Finalmente, también es importante destacar que la diferencia entre los valores de las métricas para ambos puntos es prácticamente la misma que se observa para el conjunto de validación.
comparacion.metricas.test <- MetricasPrediccion(prediccion.test, clases.test, puntos.corte$p1) %>%
dplyr::bind_rows(MetricasPrediccion(prediccion.test, clases.test, puntos.corte$p2)) %>%
tidyr::pivot_wider(names_from = "term", values_from = "estimate") %>%
dplyr::rename("Punto de corte" = umbral, Accuracy = accuracy, Recall = recall, Precision = precision, Specificity = specificity)
knitr::kable(comparacion.metricas.test, digits = 3)
| Punto de corte | Accuracy | Specificity | Precision | Recall |
|---|---|---|---|---|
| 0.396 | 0.756 | 0.762 | 0.654 | 0.745 |
| 0.345 | 0.725 | 0.697 | 0.605 | 0.771 |