Muchos modelos de aprendizaje automático complejos (redes neuronales, SVM, etc.), aunque pueden llegar a aportar muy buenas predicciones, son tratados como “cajas negras” al no aportarnos de forma directa una explicación de sus mecanismos internos o cómo han llegado a sus predicciones, es decir, modelos que no pueden entenderse a partir de los parámetros que aportan.
A menudo, una mayor precisión se produce a expensas de la pérdida interpretación, pero esta resulta crucial en numerosos casos, en los que es importante poder explicar o entender no solo el “qué”, sino el “por qué”. En este contexto, la interpretabilidad es muy importante, ya que nos puede proporcionar una explicación entendible de las predicciones de un modelo para ganar mayor entendimiento del problema a tratar, en lugar de aceptar sus resultados sin más.
Los métodos de interpretabilidad para modelos de aprendizaje automático pueden clasificarse en función de varios criterios. A continuación se exponen los más destacados:
El criterio es intrínseco si la interpretabilidad se consigue restringiendo la complejidad del modelo (regresión lineal, regresión logística, naive bayes, kNN, árbol de decisión simple, etc.), o post-hoc si se aplican métodos de interpretabilidad que analizan el modelo después de su entrenamiento.
Estadísticos para cada predictor (ej.: importancia de la variable, etc.), visualizaciones de los predictores (ej.: gráficos de dependencia parcial que muestran la predicción media de cada predictor, etc.).
Componentes propios del modelo: casos en los que los modelos son intrínsecamente interpretables, como los pesos o coeficientes aportados por los modelos lineales o la estructura o umbrales utilizados para la división de un árbol de decisión simple.
Puntos de datos: métodos que devuelven puntos de datos ya existentes o creados, aplicable a análisis de texto o imágenes sobretodo.
Aproximar un modelo “black box”, de forma global o local, con un modelo interpretable.
Estos se refieren a métodos de interpretación específicos a modelos particulares (ej.: coeficientes de regresión de un modelo lineal), o aplicables a cualquier modelo una vez este se ha entrenado (post-hoc). Los métodos agnósticos al modelo separan la explicación del tipo de modelo de aprendizaje automático del cual se quiere obtener interpretabilidad. Como ventaja frente a los métodos específicos, ofrecen flexibilidad, es decir, ofrecen la libertad al analista de escoger cualquier modelo. Comúnmente se escoge un conjunto de modelos, y no solo uno, para tratar un problema en cuestión y luego compararlos. Nota: En este documento se muestran ejemplos de métodos agnósticos al modelo.
Por un lado, los métodos globales explican el comportamiento completo del modelo sobre todo el conjunto de datos. La interpretación global ayuda a entender la distribución de la variable respuesta en función del conjunto de predictores (difícil de obtener en la práctica). Por otro lado, los métodos de interpretación locales explican predicciones de datos individuales o pequeños conjuntos de datos similares. Puede darse el caso, por ejemplo, que localmente la predicción de una observación tenga una dependencia lineal o monótona con algunas características, en lugar de una relación compleja. Los métodos locales pueden ser más precisos que los globales.
A continuación se describen ejemplos de algunos de los métodos interpretativos agnósticos al modelo, en concreto:
Otros:
La dependencia parcial es una medida de la predicción promedio de un modelo con respecto a una variable de entrada. Un gráfico de dependencia parcial (partial dependence plot) o PDP es un método de interpretabilidad global de los más simples y entendibles que muestra como cambia la predicción de la variable respuesta en función del efecto marginal de los valores de una variable de entrada de interés, mientras se toma en cuenta la no linealidad y se promedian los efectos de todas las demás variables de entrada. Graficando la dependencia entre la variable respuesta y un predictor específico podemos ver si la relación entre ambos es lineal, monótona o compleja. Así pues, obtenemos una función parcial que nos aporta, dados unos valores concretos de predictor/es, el efecto marginal promedio en la predicción.
En notación matemática, consideremos un subvector \(\small{X_S}\) de \(\small{l<p}\) predictores del vector de variables de entrada \(\small{X^T = (X_1, X_2, ..., X_p)}\), indexado por \(\small{S ⊂ C}\){\(\small{1, 2, ..., p}\)}. En este caso \(\small{X_S}\) contiene los predictores de los que queremos obtener la interpretación. Por otro lado, \(\small{C}\) es el set complementario, con \(\small{S∪C=}\){\(\small{1, 2, ..., p}\)}. Una función general \(\small{f(x)}\) dependerá, en principio, de todas las variables de entrada: \(\small{f(x)=f(X_S, X_C)}\). Por tanto, una forma de definir el promedio o la dependencia parcial de \(\small{f(x)}\) sobre \(\small{X_S}\) es
Una suposición es que los predictores en \(\small{X_S}\) no presentan correlación o interacciones fuertes con aquellos en \(\small{X_C}\).
La estimación de la función de dependencia parcial puede estimarse como
donde {\(\small{x_{1c}, x_{2c}, ..., x_{NC}}\)} son los valores de \(\small{X_C}\) en los datos de entrenamiento (\(\small{N}\).
Importante destacar que la estimación \(\small{f_S(X_S)}\) representa el efecto de \(\small{X_S}\) sobre \(\small{f(X)}\) después de tomar en cuenta el efecto (promedio) de los otros predictores en \(\small{X_C}\) sobre \(\small{f(X)}\). Por lo tanto, estas funciones NO representan el efecto de \(\small{X_S}\) sobre \(\small{f(X)}\) ignorando los efectos de \(\small{X_C}\). Este último viene dado por la expectativa condicional
que representa la mejor estimación a \(\small{f(X)}\) por mínimos cuadrados en función de únicamente \(\small{X_S}\).
Para problemas de clasificación con \(\small{K}\) clases en las que el modelo genera probabilidades, el PDP muestra la probabilidad de una determinada clase dados diferentes valores del predictor tomado en cuenta. Tendremos un modelo por cada clase, cada uno relacionado con las probabilidades respectivas mediante
Así, cada \(\small{f_k(X)}\) es una función creciente de su probabilidad respectiva en una escala logarítmica. Los PDP de cada \(\small{f_k(X)}\) respecto a un determinado predictor puede revelar como el log-odds de una determinada clase \(\small{K}\) depende del conjunto de variables de entrada.
Nota: En el caso de variables categóricas, se obtiene una estimación PDP para cada categoría.
Ventajas:
Cálculo intuitivo y fáciles de implementar
Interpretación clara de como un predictor influye en la predicción promedia del modelo si dicho predictor no presenta correlación con los demás. Nota: la interpretación puede ser causal para el modelo, pero no en la realidad.
Desventajas:
El número máximo práctico de predictores para una función de dependencia parcial es 2. Esto se debe a la limitación de no poder representar en más de tres dimensiones.
Importante complementar el gráfico con la representación de la distribución del predictor analizado, para tener en cuenta regiones con pocos datos.
Asumen independencia entre el predictor analizado y el resto. En muchos casos, esta asunción no es realista.
Pueden ocultar efectos o tendencias heterogéneas al mostrar solo los efectos marginales promedio. Una alternativa a esto puede ser la aplicación de curvas de expectativa condicional individual (ICE).
Un gráfico ICE (individual conditional expectation) visualiza la dependencia de la predicción en un predictor para cada observación por separado, lo que da como resultado una línea por observación, en comparación con una línea general en los gráficos de dependencia parcial. Se trata de un equivalente al PDP para datos individuales, ya que un PDP representa el promedio de las líneas de un diagrama ICE. Los valores para una línea (una observación) se pueden calcular manteniendo todos los otros predictores iguales, creando variantes de esta observación reemplazando el valor del predictor y obteniendo predicciones con el modelo “black box” para estas observaciones recién creadas. Otro aspecto es que, a diferencia de los PDP, los gráficos ICE proporcionan más información respecto a posibles interacciones en cuanto a la evolución de la predicción.
En notación matemática, para cada observación {\(\small{(X^{i}_S, X^{i}_C)}\)}\(\small{^N_{i=1}}\) se representa la curva \(\small{\hat{f}^{(i)}_S}\) frente a \(\small{x^{(i)}_S}\) mientras \(\small{x^{(i)}_C}\) permanece fijo.
Ventajas:
Desventajas:
Un inconveniente que puede surgir con un gráfico ICE estándar, es que puede llegar a ser difícil diferenciar si las curvas ICE individuales difieren entre observaciones ya que cada una se inicia en un valor distinto de predicción (eje y). Una solución supone centrar las curvas en un punto o valor concreto del predictor (por ejemplo, el valor mínimo de su rango) y mostrar la evolución de diferencia de la predicción respecto a este punto. Con ello obtenemos un gráfico ICE centrado (c-ICE). Las nuevas curvas vendrán definidas como
donde \(\small{1}\) es un vector de 1s con el número apropiado de dimensiones (una o dos), \(\small{\hat{f}}\) es el modelo ajustado y \(\small{x^a}\) es el punto “anclado” del predictor.
Otra forma de hacer que sea visualmente más fácil detectar la heterogeneidad es observar las derivadas individuales de la función de predicción con respecto a un predictor. Las derivadas de una función (o curva) pueden indicar si ocurren cambios y en qué dirección ocurren. Con el gráfico ICE derivado es fácil detectar rangos de valores en el predictor donde las predicciones del modelo cambian para (al menos algunas) las observaciones. Si no hay interacción entre el predictor analizado \(\small{x_S}\) y el resto de predictores \(\small{x_C}\), entonces la función de predicción se puede expresar como:
donde \(\small{g'(x_S)}\) es la derivada de la función de predicción con respecto a la función en \(\small{S}\).
Con ausencia de interacciones, las derivadas parciales individuales deberían ser las mismas para todas las observaciones, pero si existen interacciones, se hacen visibles en el gráfico d-ICE. Además de mostrar las curvas individuales para la derivada, mostrar su desviación estándar ayuda a resaltar regiones con heterogeneidad.
Nota: Un inconveniente de aplicar la gráfica ICE derivada es el tiempo de computación.
El método Local Interpretable Model Explanation o LIME es un método local aplicable para explicar predicciones individuales, obteniendo la contribución de cada predictor sobre la predicción de una observación individual. LIME permite interpretar qué ocurre con las predicciones de un modelo black box cuando los datos varían, esto es, a partir de una observación que se quiere interpretar, se generan nuevas observaciones vecinas permutadas modificadas a partir de la original tomando valores de una distribución normal con media y varianza respectivos de cada predictor. Sobre estas nuevas observaciones se entrena un modelo lineal interpretable (modelo local sustituto) que debe aportar una buena aproximación de la predicción de manera local (no tiene por qué ser una buena aproximación global). En resumen, LIME genera un nuevo set de datos formado por muestras permutadas y las correspondientes predicciones del modelo black box. Sobre estos nuevos datos LIME entrena un modelo interpretable el cual se pondera por la proximidad de las observaciones muestreadas a la observación de interés.
En notación matemática, la explicabilidad de una observación \(\small{x}\) viene dada por:
donde \(\small{G}\) es el conjunto de todas las posibles explicaciones (ej.: todos los posibles modelos de regresión lineal), \(\small{g}\) el modelo local sustituto y \(\small{L(f,g,\pi_x)}\) la medida de la fidelidad de la aproximación de \(\small{f}\) por \(\small{g}\) en la proximidad de \(\small{x}\) definida por \(\small{\pi_x}\).
Es decir, el modelo explicativo para la observación \(\small{x}\) es aquel que minimiza el error \(\small{L}\) (por ejemplo MSE en regresión), el cual mide la proximidad de la explicación y la predicción del modelo original black box, a la vez que se intenta mantener baja la complejidad del modelo sustituto (\(\small{\Omega(g)}\)). Con esto, método LIME cuenta con la suposición que todo modelo complejo es aproximable linealmente a escala muy local.
Se calcula la vecindad de cada observación sintética creada como la distancia a la observación original mediante una función kernel (radial), de manera que a las observaciones sintéticas más cercanas se les asigna una mayor importancia o peso. Sin embargo, no es trivial escoger el hiperparámetro de este kernel, la anchura, que es lo que determina el tamaño de la vecindad (un ancho de kernel pequeño significa que una observación sintética debe estar muy cerca de la original para influir en el modelo local).
De forma esquemática, los pasos seguidos por LIME son:
Ventajas:
Desventajas:
Para problemas de regresión mostraremos un ejemplo utilizando el set de datos bike
procesado por el autor del libro “Interpretable machine learning” Christoph Molnar, descargable en su repositorio GitHub que contiene datos sobre recuentos diarios de bicicletas alquiladas de la empresa Capital-Bikeshare en Washington D.C., junto con información meteorológica y estacional. Las variables disponibles (12) son:
season
: estación del añoyr
: añomnt
: mes del añoholiday
: indica dia festivo o no festivoweekday
: día de la semanaworkingday
: indica si el día es o no laborableweathersit
: situación metereológicatemp
: temperatura (ºC)hum
: humedadwindspeed
: velocidad del vientocnt
: número de bicicletas alquiladas (variable a predecir)days_since_2011
: días desde el primer día de toma de datosEl objetivo es interpretar la prediccón del número de alquiler de bicicletas para un determinado día.
# Cargamos el fichero desde CSV
bike <- read.csv("bike.csv", header = TRUE)
str(bike)
## 'data.frame': 731 obs. of 12 variables:
## $ season : chr "SPRING" "SPRING" "SPRING" "SPRING" ...
## $ yr : int 2011 2011 2011 2011 2011 2011 2011 2011 2011 2011 ...
## $ mnth : chr "JAN" "JAN" "JAN" "JAN" ...
## $ holiday : chr "NO HOLIDAY" "NO HOLIDAY" "NO HOLIDAY" "NO HOLIDAY" ...
## $ weekday : chr "SAT" "SUN" "MON" "TUE" ...
## $ workingday : chr "NO WORKING DAY" "NO WORKING DAY" "WORKING DAY" "WORKING DAY" ...
## $ weathersit : chr "MISTY" "MISTY" "GOOD" "GOOD" ...
## $ temp : num 8.18 9.08 1.23 1.4 2.67 ...
## $ hum : num 80.6 69.6 43.7 59 43.7 ...
## $ windspeed : num 10.7 16.7 16.6 10.7 12.5 ...
## $ cnt : int 985 801 1349 1562 1600 1606 1510 959 822 1321 ...
## $ days_since_2011: int 0 1 2 3 4 5 6 7 8 9 ...
Ajustaremos un modelo random forest con caret
a partir del cual poder mostrar los gráficos de interpretabilidad. Para más información sobre este tipo de modelo, visitar Árboles de decisión y métodos de ensemble.
library(caret)
library(dplyr)
library(parallel)
library(doParallel)
cluster <- makeCluster(detectCores() - 1)
registerDoParallel(cluster)
# Obviamos la variable "days_since_2011" y "yr"
bike <- select(bike, -c(days_since_2011, yr))
# Reservamos dos observaciones para que hagan el papel de observaciones nuevas
# sobre las que predecir el número de alquileres: un día de invierno y otro de
# verano del último año
indices_new <- c(531, 720)
bike_new <- bike[indices_new, ] %>% select(-cnt)
# Datos entrenamiento (se excluyen las dos observaciones anteriores)
bike_train <- bike[-indices_new, ]
# Método de validación cruzada (10-fold)
fitControl <- trainControl(method = "cv",
number = 10,
search = "grid",
allowParallel = TRUE)
# Hiperparámetro a optimizar: número de predictores aleatorios en cada ramificación.
grid_mtry <- expand.grid(mtry = c(2:8))
# Ajuste del modelo random forest
set.seed(356)
modelo_rf <- train(cnt ~ ., data = bike_train,
method = "rf",
metric = "RMSE",
tuneGrid = grid_mtry,
trControl = fitControl)
# Obtenemos la predicción sobre las observaciones o días "nuevos"
pred_modelo_rf <- data.frame(cnt = predict(modelo_rf, bike_new))
# Unimos las predicciones al resto de datos de los días nuevos
bike_new <- cbind(bike_new, pred_modelo_rf)
bike_new
## season mnth holiday weekday workingday weathersit temp hum
## 531 SUMMER JUN NO HOLIDAY THU WORKING DAY GOOD 22.47165 56.9583
## 720 WINTER DEZ NO HOLIDAY THU WORKING DAY MISTY 7.51000 66.7917
## windspeed cnt
## 531 17.000111 5969.876
## 720 8.875021 4686.378
Para obtener el gráfico PDP podemos aplicar la función partial()
del paquete pdp
. De entre los argumentos disponibles, utilizaremos:
object
: modelo ajustadopred.var
: predictores de interés (no más de 3)plot
: TRUE
para graficar directamente o FALSE
para obtener un df con los valores de dependencia parcialrug
: incluir marcas en el eje del predictorlibrary(pdp)
# Gráfico PDP (1 predictor)
partial(modelo_rf, pred.var = "temp", plot = TRUE, rug = TRUE)
El eje yhat del gráfico muestra el número de bicicletas alquiladas predichas. Las marcas azules en el eje inferior representan la distribución de la variable representada, lo cual muestra la relevancia de las regiones para su interpretación (extremar la precaución al interpretar regiones con pocos datos).
También podemos obtener la relación entre dos predictores mediante un heatmap:
# Gráfico PDP (2 predictores numéricos)
partial(modelo_rf, pred.var = c("temp","hum"), plot = TRUE, rug = TRUE)
Para calcular las curvas ICE está a nuestra disposición la función ice()
del paquete ICEbox
. Entre sus argumentos, utilizaremos:
object
: modeloX
: predictoresy
: variable de salida predichapredictor
: predictor sobre el que obtener las curvas ICEPara obtener el gráfico, aplicaremos la función plot()
con argumentos:
x
: objeto de clase ice
frac_to_plot
: proporción de curvas a graficar (1 = todas)plot_orig_pts_preds
: marca el valor de la predicción original del modelopts_preds_size
: tamaño de los puntos en plot_orig_pts_preds
color_by
: predictor por el que colorear las curvasrug_quantile
: vector de quantiles para especificar las marcas de distribución de valores del predictor en el eje xplot_pdp
: superponer la curva PDPlibrary(ICEbox)
# Grafico ICE para el predictor "temp"
ice_temp <- ice(object = modelo_rf,
X = bike_train[which(names(bike_train) != "cnt")],
y = bike_train$cnt,
predictor = "temp",
verbose = FALSE)
plot(x = ice_temp,
frac_to_plot = 0.7, #70% de observaciones
plot_orig_pts_preds = TRUE,
pts_preds_size = 1,
color_by = "weathersit",
rug_quantile = seq(from = 0, to = 1, by = 0.01),
plot_pdp = TRUE,
main = "Gráfico ICE: temperatura")
## ICE Plot Color Legend
## weathersit color
## GOOD firebrick3
## MISTY dodgerblue3
## RAIN/SNOW/STORM gold1
En este caso, para la temperatura podemos ver que todas las curvas (observaciones) siguen la misma tendencia: más alquiler de bicicletas en días con temperaturas en torno a 20ºC, llegando a un máximo de en torno a 8000 bicicletas en un día. Parece que para los días con mal clima (lluvia/nieve/tormenta), marcado por las curvas en amarillo, el alquiler de bicicletas es menor.
Para obtener un gráfico ICE centrado, especificamos el argumento centered = TRUE
:
# Grafico ICE centrado para el predictor "temp" a partir de su valor mínimo
plot(x = ice_temp,
frac_to_plot = 0.7, #70% de observaciones
plot_orig_pts_preds = TRUE,
pts_preds_size = 1,
color_by = "weathersit",
rug_quantile = seq(from = 0, to = 1, by = 0.01),
plot_pdp = TRUE,
centered = TRUE,
main = "Gráfico c-ICE: temperatura")
## ICE Plot Color Legend
## weathersit color
## GOOD firebrick3
## MISTY dodgerblue3
## RAIN/SNOW/STORM gold1
El gráfico centrado muestra un aumento en el alquiler con un aumento de la temperatura, manteniéndose más constante en torno a los 20ºC. Este aumento es mucho menos pronunciado respecto a los alquileres en días con mal tiempo.
Para obtener un gráfico ICE derivado está a nuestra disposición la función dice()
del paquete ICEbox
, que aplicaremos al objeto de clase ice
creado anteriormente. Al graficar, el argumento plot_orig_pts_preds
cambia por plot_orig_pts_deriv
. También superponemos la curva PDP derivada:
dice_temp <- dice(ice_obj = ice_temp)
# Gráfico ICE derivado
plot(x = dice_temp,
frac_to_plot = 0.7,
plot_orig_pts_deriv = TRUE,
pts_preds_size = 1,
#color_by = "weathersit",
rug_quantile = seq(from = 0, to = 1, by = 0.01),
plot_dpdp = TRUE,
main = "Gráfico d-ICE: temperatura")
## NULL
El gráfico muestra en la parte inferior la desviación estándar de las curvas derivadas. Con ausencia de interacciones, las derivadas parciales individuales deberían ser las mismas para todas las observaciones, pero parecen haber regiones donde la tendencia es más dispar (las derivadas se alejan de 0), más en torno a 5ºC y 12ºC. Esto podría ser indicativo de interacciones de la temperatura con otros predictores.
Podemos obtener medidas de interacción de los predictores disponibles, como valor de cuanta varianza en \(\small{f(x)}\) es explicada por la interacción. El rango de esta medida se encuentra entre 0 (sin interacción) y 1 (100% de varianza de \(\small{f(x)}\) debida a interacciones). Con el paquete iml
podemos obtener estos valores. Comenzamos creando un objeto de la clase Predictor
que contenga el modelo y los datos, para a continuación aplicar la clase Interaction
:
library(iml)
predictor <- Predictor$new(modelo_rf,
data = bike_train[which(names(bike) != "cnt")],
y = bike_train$cnt)
plot(iml::Interaction$new(predictor)) #especifico el paquete para evitar conflictos
Podemos también obtener las medidas de interacción entre pares escogiendo un predictor de interés, por ejemplo la interacción de la temperatura con todos los demás predictores:
plot(iml::Interaction$new(predictor, feature = "temp"))
Para interpretar prediccciones por el método LIME en R
, primero aplicamos la función lime()
del paquete lime
que nos devuelve un objeto con el modelo y estadísticos de la distribución de los predictores de los datos de entrenamiento, en concreto de la distribución del nivel de cada variable categórica y cada variable continua dividida en particiones o bins (4 por defecto). Estos estadísticos se utilizarán para la permutación de las observaciones a explicar. De entre los argumentos disponibles para los datos tabulares, encontramos:
x
: datos de entrenamiento (sin la variable respuesta)model
: modelo black boxpreprocess
: transformar un vector de tipo character
al formato esperado por el modelobin_continuous
: particionar variables contínuas a la hora de obtener la explicaciónn_bins
: número de particiones (bins) si bin_continuous = TRUE
A continuación con la función explain()
sobre el objeto lime
aplicamos el algoritmo LIME para obtener la interpretabilidad de la predicción (por defecto aplica Ridge regression). Entre los argumentos para los datos tabulares, encontramos:
x
: nuevas observaciones sobre las que obtener la explicabilidadexplainer
: toma el objeto de tipo explainer
para aplicar las permutaciones. En este caso el objeto lime
.label
: clases a explicar en caso de que el modelo sea un clasificadorn_permutations
: nº de permutaciones a crear para la observación a explicarn_features
: número de predictores a utilizar en la explicaciónfeature_select
: algoritmo para selección de predictores. Entre los disponibles: forward_selection
, highest_weights
para seleccionar las n_features con mayor peso absoluto, lasso_path
y tree
para selección de predictores mediante un árbol simple.dist_fun
: función para calcular la distancia entre las permutaciones y la observación original. Por defecto la distancia de Gower, pero también se puede apliar la euclídea, manhattan, etc.kernel_width
: anchura de kernel para convertir la distancia en medida de similitud. El valor predeterminado es 0,75 veces la raíz cuadrada del número de predictores. Valores más pequeños restringen el tamaño de la región local.library(lime)
# Creamos el objeto lime "explicador"
explainer_lime <- lime(x = bike_train[which(names(bike_train) != "cnt")],
model = modelo_rf,
bin_continuous = TRUE,
quantile_bins = FALSE)
summary(explainer_lime)
## Length Class Mode
## model 24 train list
## preprocess 1 -none- function
## bin_continuous 1 -none- logical
## n_bins 1 -none- numeric
## quantile_bins 1 -none- logical
## use_density 1 -none- logical
## feature_type 9 -none- character
## bin_cuts 9 -none- list
## feature_distribution 9 -none- list
# Obtenemos la explicación sobre las observaciones de test
set.seed(878)
lime_bike <- explain(x = bike_new[which(names(bike_new) != "cnt")],
explainer = explainer_lime,
n_permutations = 5000,
dist_fun = "manhattan",
kernel_width = 1,
n_features = 5,
feature_select = "lasso_path")
tibble::glimpse(lime_bike)
## Rows: 10
## Columns: 11
## $ model_type <chr> "regression", "regression", "regression", "regress...
## $ case <chr> "531", "531", "531", "531", "531", "720", "720", "...
## $ model_r2 <dbl> 0.4224099, 0.4224099, 0.4224099, 0.4224099, 0.4224...
## $ model_intercept <dbl> 2997.328, 2997.328, 2997.328, 2997.328, 2997.328, ...
## $ model_prediction <dbl> 5749.258, 5749.258, 5749.258, 5749.258, 5749.258, ...
## $ feature <chr> "mnth", "workingday", "weathersit", "temp", "hum",...
## $ feature_value <chr> "JUN", "WORKING DAY", "GOOD", "22.471651", "56.958...
## $ feature_weight <dbl> 215.3284, 207.8872, 364.3006, 1426.3322, 538.0825,...
## $ feature_desc <chr> "mnth = JUN", "workingday = WORKING DAY", "weather...
## $ data <list> [["SUMMER", "JUN", "NO HOLIDAY", "THU", "WORKING ...
## $ prediction <dbl> 5969.876, 5969.876, 5969.876, 5969.876, 5969.876, ...
El resultado del objeto explain
es un data frame que contiene información de las predicciones del modelo simple. Pero una manera más rápida de interpretar los resultados es visualizándolos con la función plot_features()
:
# Visualización del resultado
plot_features(lime_bike)
Para cada observación se proporciona el \(\small{R^2}\) (Explanation fit) del modelo simple ajustado a la localidad de la observación original. Podemos observar que para la observación 531 correspondiente a un día de verano, el modelo original RF predice 5969 alquileres de bicicletas. La característica del día que más propicia los alquileres es una temperatura de entre 15 y 22 ºC. Por otro lado, para la observación 720 correspondiente a un día de invierno el modelo original RF predice menos alquileres, 4686, con un \(\small{R^2}\) del modelo simple más alto. Destaca que una temperatura inferior a en torno 8ºC afecta negativamente al alquiler de bicicletas. Tener en cuenta en esta comparación que el coeficiente \(\small{R^2}\) para la observación 531 es menor que para la 720, por lo que el resultado de la explicabilidad puede ser menos fiable.
También podemos obtener la visualización en forma de mapa de calor con la función plot_explanations()
, donde podemos comparar observaciones y detectar predictores influyentes comunes:
plot_explanations(lime_bike)
NOTA: Tener en cuenta que los métodos y valores escogidos para aplicar el método LIME podrían tratarse como hiperparámetros a optimizar, ya que en función de unos parámetros u otros, los resultados pueden variar.
Para problemas de regresión mostraremos un ejemplo utilizando una versión del set de datos biopsy
del paquete MASS
sin valores nulos y con las variables renombradas. Este set contiene datos de biopsias procedentes de pacientes con cáncer de mama, con un conjunto de 9 características de la biopsia con valores en una escala de 1 a 10. La variable respuesta es el tipo de cáncer: benigno o maligno. Analizaremos también las predicciones de un modelo Random Forest sobre la probabilidad del tipo de cáncer dadas las características de la biopsia.
# Cargamos el fichero desde CSV
biopsy <- read.csv("biopsy.csv", header = TRUE)
str(biopsy)
## 'data.frame': 683 obs. of 10 variables:
## $ clump_thickness : int 5 5 3 6 4 8 1 2 2 4 ...
## $ uniformity_cell_size : int 1 4 1 8 1 10 1 1 1 2 ...
## $ uniformity_cell_shape : int 1 4 1 8 1 10 1 2 1 1 ...
## $ marginal_adhesion : int 1 5 1 1 3 8 1 1 1 1 ...
## $ single_epithelial_cell_size: int 2 7 2 3 2 7 2 2 2 2 ...
## $ bare_nuclei : int 1 10 2 4 1 10 10 1 1 1 ...
## $ bland_chromatin : int 3 3 3 3 3 9 3 3 1 2 ...
## $ normal_nucleoli : int 1 2 1 7 1 7 1 1 1 1 ...
## $ mitoses : int 1 1 1 1 1 1 1 1 5 1 ...
## $ class : chr "benign" "benign" "benign" "benign" ...
# Reservamos dos observaciones para que hagan el papel de observaciones nuevas
# a predecir: una biopsia benigna (680) y otra maligna (681)
indices_new <- c(680, 681)
biopsy_new <- biopsy[indices_new, ] %>% dplyr::select(-class)
biopsy_train <- biopsy[-indices_new, ]
# Método de validación cruzada (10-fold)
fitControl <- trainControl(method = "cv",
number = 10,
search = "grid",
summaryFunction = twoClassSummary,
classProbs = TRUE,
allowParallel = TRUE)
# Hiperparámetro a optimizar: número de predictores aleatorios en cada ramificación.
grid_mtry <- expand.grid(mtry = c(2:8))
# Ajuste del modelo random forest
set.seed(356)
modelo_rf <- caret::train(class ~ ., data = biopsy_train,
method = "rf",
metric = "ROC",
prob.model = TRUE,
tuneGrid = grid_mtry,
trControl = fitControl)
stopCluster(cluster)
registerDoSEQ()
# Obtenemos la predicción sobre las observaciones de test
pred_modelo_rf <- data.frame(class = predict(object = modelo_rf,
newdata = biopsy_new,
type = "prob"))
pred_modelo_rf
## class.benign class.malignant
## 680 1.000 0.000
## 681 0.038 0.962
# Unimos las predicciones de la probabilidad de malignidad a los datos de las
# biopsias "nuevas"
biopsy_new <- biopsy_new %>% cbind(class = pred_modelo_rf$class.malignant)
biopsy_new
## clump_thickness uniformity_cell_size uniformity_cell_shape
## 680 2 1 1
## 681 5 10 10
## marginal_adhesion single_epithelial_cell_size bare_nuclei bland_chromatin
## 680 1 2 1 1
## 681 3 7 3 8
## normal_nucleoli mitoses class
## 680 1 1 0.000
## 681 10 2 0.962
Para obtener el gráfico PDP aplicamos de nuevo la función partial()
del paquete pdp
. Debemos indicar en el argumento pred.fun
una función que devuelva las predicciones del modelo black box en forma de probabilidades:
# Función para obtener la probabilidad media de malignidad a partir del modelo RF
prob_fun_pdp <- function(object, newdata){
mean(predict(object, newdata, type = "prob")[, 2])
}
# PDP (1 predictor)
pdp_biopsy <- pdp::partial(modelo_rf,
pred.fun = prob_fun_pdp,
pred.var = "mitoses",
rug = TRUE)
plotPartial(pdp_biopsy, xlab = "Mitosis", ylab = "P(Maligno)")
El gráfico sugiere que la probabilidad de malignidad de un tumor aumenta ligeramente con el nivel de división celular.
También podemos obtener la relación entre dos predictores mediante un heatmap:
# PDP (2 predictores numéricos)
pdp_biopsy <- pdp::partial(modelo_rf,
pred.fun = prob_fun_pdp,
pred.var = c("mitoses", "bare_nuclei"),
rug = TRUE)
plotPartial(pdp_biopsy)
El resultado de comparar la mitosis con el estado de los núcleos celulares sugiere que la probabilidad de malignidad de un tumor aumenta con un mayor nivel de núcleos desnudos sin apenas efecto de cual sea el grado de mitosis.
Para calcular las curvas ICE aplicamos de nuevo la función ice()
del paquete ICEbox
. Para problemas de clasificación, entran en juego argumentos como:
logodds
: asignar TRUE
si queremos graficar el logaritmo de odds a partir de la probabilidadAdemás, las predicciones devueltas por el modelo RF tienen que ser en forma de probabilidades, por lo que indicamos con el argumento predictfcn
:
prob_fun_ice <- function(object, newdata){
predict(object, newdata, type = "prob")[, 2]
}
# Curvas ICE en función del estado de mitosis o división celular de la biopsia
ice_biopsy <- ice(object = modelo_rf,
X = biopsy_train[which(names(biopsy_train) != "class")],
predictor = "bare_nuclei",
#logodds = TRUE,
predictfcn = prob_fun_ice,
verbose = FALSE)
## y not passed, so range_y is range of ice curves and sd_y is sd of predictions on real observations
plot(x = ice_biopsy,
frac_to_plot = 0.7, #70% de observaciones
plot_orig_pts_preds = TRUE,
pts_preds_size = 1,
rug_quantile = seq(from = 0, to = 1, by = 0.01),
plot_pdp = TRUE,
main = "Gráfico ICE: mitosis",
xlab = "Mitosis",
ylab = "P(Maligno)")
Mientras que para algunas muestras no parece afectar el nivel de mitosis en la probabilidad de que el tumor sea maligno (curvas superiores), para otras este aumento en el nivel de división celular sí parece afectar la probabilidad, en algunos casos para niveles superiores a 5 especialmente..
Para obtener curvas ICE centradas, especificamos el argumento centered = TRUE
en el código anterior:
# Grafico ICE centrado para el predictor "temp" a partir de su valor mínimo
plot(x = ice_biopsy,
frac_to_plot = 0.7, #70% de observaciones
plot_orig_pts_preds = TRUE,
pts_preds_size = 1,
rug_quantile = seq(from = 0, to = 1, by = 0.01),
plot_pdp = TRUE,
centered = TRUE,
main = "Gráfico c-ICE: mitosis")
En comparación con los casos con menor división celular, la probabilidad de que el tumor sea maligno en algunos casos aumenta con una mayor división celular. En algunos otros, la probabilidad no cambia.
Para obtener un gráfico ICE derivado está a nuestra disposición la función dice()
del paquete ICEbox
, que aplicaremos al objeto de clase ice
creado anteriormente. Al graficar, el argumento plot_orig_pts_preds
cambia por plot_orig_pts_deriv
. También superponemos la curva PDP derivada:
dice_biopsy <- dice(ice_obj = ice_biopsy)
# Gráfico ICE derivado
plot(x = dice_biopsy,
frac_to_plot = 0.7,
plot_orig_pts_deriv = TRUE,
pts_preds_size = 1,
rug_quantile = seq(from = 0, to = 1, by = 0.01),
plot_dpdp = TRUE,
main = "Gráfico d-ICE: mitosis")
## NULL
predictor <- Predictor$new(modelo_rf,
data = biopsy_train[which(names(biopsy) != "class")],
y = biopsy_train$class)
plot(iml::Interaction$new(predictor))
# Creamos el objeto lime "explicador"
explainer_lime <- lime(x = biopsy_train[which(names(biopsy_train) != "class")],
model = modelo_rf,
bin_continuous = TRUE,
quantile_bins = FALSE)
summary(explainer_lime)
## Length Class Mode
## model 23 train list
## preprocess 1 -none- function
## bin_continuous 1 -none- logical
## n_bins 1 -none- numeric
## quantile_bins 1 -none- logical
## use_density 1 -none- logical
## feature_type 9 -none- character
## bin_cuts 9 -none- list
## feature_distribution 9 -none- list
# Obtenemos la explicación sobre las observaciones de test
lime_biopsy <- explain(x = biopsy_new[which(names(biopsy_new) != "class")],
explainer = explainer_lime,
n_labels = 1,
n_features = 4)
tibble::glimpse(lime_biopsy)
## Rows: 8
## Columns: 13
## $ model_type <chr> "classification", "classification", "classificatio...
## $ case <chr> "680", "680", "680", "680", "681", "681", "681", "...
## $ label <chr> "benign", "benign", "benign", "benign", "malignant...
## $ label_prob <dbl> 1.000, 1.000, 1.000, 1.000, 0.962, 0.962, 0.962, 0...
## $ model_r2 <dbl> 0.561036, 0.561036, 0.561036, 0.561036, 0.377859, ...
## $ model_intercept <dbl> 0.02568876, 0.02568876, 0.02568876, 0.02568876, 0....
## $ model_prediction <dbl> 0.7944718, 0.7944718, 0.7944718, 0.7944718, 0.7557...
## $ feature <chr> "bare_nuclei", "uniformity_cell_size", "uniformity...
## $ feature_value <int> 1, 1, 1, 1, 3, 10, 10, 10
## $ feature_weight <dbl> 0.26610602, 0.20602612, 0.16938649, 0.12726443, -0...
## $ feature_desc <chr> "bare_nuclei <= 3.25", "uniformity_cell_size <= 3....
## $ data <list> [[2, 1, 1, 1, 2, 1, 1, 1, 1], [2, 1, 1, 1, 2, 1, ...
## $ prediction <list> [[1, 0], [1, 0], [1, 0], [1, 0], [0.038, 0.962], ...
# Visualización del resultado
plot_features(lime_biopsy)
plot_explanations(lime_biopsy)
El modelo RF clasifica correctamente como benigna la biopsia 680 con una probabilidad del 100%, mientras que la biopsia 681 se clasifica correctamente como maligna con una probabilidad del 96%. Entre otros, la presencia de núcleos desnudos con un valor tomado igual o menor a 3,25 parece favorecer la probabilidad de que el cáncer sea benigno, al contrario que en los tumores malignos.
The Elements of Statistical Learning: Data Mining, Inference, and Prediction. New York: Springer, 2001.
https://www.h2o.ai/wp-content/uploads/2017/09/driverlessai/interpreting.html
Interpretable Machine Learning: A Guide for Making Black Box Models Explainable (2020). Christoph Molnar.
Marco Tulio Ribeiro, Sameer Singh, and Carlos Guestrin. 2016. “Why Should I Trust You?”: Explaining the Predictions of Any Classifier. In Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (KDD ’16). ACM, New York, NY, USA, 1135-1144. DOI: https://doi.org/10.1145/2939672.2939778
This work by Cristina Gil Martínez is licensed under a Creative Commons Attribution 4.0 International License.