Introduction

The code below works through the iml package. The code and examples are taken from CRAN's website: C-RAN-Link.

Machine learning models usually perform very well for prediction, but are not easily interpretable. The iml package provides tools for analysing any black-box machine learning model:

Feature importance: Which were the most important features?

Feature effects: How does a feature influence the prediction? (Accumulated local effects, partial dependence plots and individual conditional expectation curves)

Explanations for single predictions: How did the feature values of a single data point affect its prediction? (LIME and Shapley value)

Surrogate trees: Can we approximate the underlying black box model with a short decision tree?

The iml package works for any classification and regression machine learning model: random forests, linear models, neural networks, xgboost, etc.

This document shows you how to use the iml package to analyse machine learning models.

If you want to learn more about the technical details of all the methods, read the corresponding chapters of Christoph Molnar's Interpretable Machine Learning book on GitHub.

Data

This example uses the Boston housing dataset.

data("Boston", package = "MASS")

head(Boston)
##      crim zn indus chas   nox    rm  age    dis rad tax ptratio  black lstat
## 1 0.00632 18  2.31    0 0.538 6.575 65.2 4.0900   1 296    15.3 396.90  4.98
## 2 0.02731  0  7.07    0 0.469 6.421 78.9 4.9671   2 242    17.8 396.90  9.14
## 3 0.02729  0  7.07    0 0.469 7.185 61.1 4.9671   2 242    17.8 392.83  4.03
## 4 0.03237  0  2.18    0 0.458 6.998 45.8 6.0622   3 222    18.7 394.63  2.94
## 5 0.06905  0  2.18    0 0.458 7.147 54.2 6.0622   3 222    18.7 396.90  5.33
## 6 0.02985  0  2.18    0 0.458 6.430 58.7 6.0622   3 222    18.7 394.12  5.21
##   medv
## 1 24.0
## 2 21.6
## 3 34.7
## 4 33.4
## 5 36.2
## 6 28.7

Fitting the ML Model

Use a randomForest to predict the Boston median housing value (medv):

set.seed(42)

library(iml) #interpretable machine learning package
## Warning: package 'iml' was built under R version 4.2.3
library(randomForest)
##Random Forest Train
rf <- randomForest(medv ~ ., data = Boston, ntree = 10)

Using the iml Predictor() container

We create a Predictor object that holds the model and the data. The iml package uses R6 classes: new objects can be created by calling Predictor$new().

## Create a dataframe that has all of the features except the target
## In this case, remove the median house value (medv)
x = Boston[which(names(Boston) != "medv")]

## Store the data and medv in the Predictor() container
## parameters are the model - in this case rf
## the data - in this case the new data frame without medv - x
## and the target - the medv in the Boston data frame
predictor = Predictor$new(model = rf,
                          data = x,
                          y = Boston$medv)
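
The same Predictor container also wraps classification models. Below is a minimal sketch (not run in this document): turning 'chas' into a factor target is an illustrative choice, and the type = "prob" argument for class probabilities is taken from the Predictor documentation.

## Sketch only: wrap a classification model in the same Predictor container
Boston_cls <- Boston
Boston_cls$chas <- as.factor(Boston_cls$chas)           # treat 'chas' as a class label
rf_cls <- randomForest(chas ~ ., data = Boston_cls, ntree = 10)
predictor_cls <- Predictor$new(model = rf_cls,
                               data = Boston_cls[which(names(Boston_cls) != "chas")],
                               y = Boston_cls$chas,
                               type = "prob")            # ask for class probabilities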

Feature Importance

We can measure how important each feature was for the predictions with FeatureImp. The feature importance measure works by shuffling each feature and measuring how much the performance drops. For this regression task we measure the loss in performance with the root mean squared error ('rmse'); the mean absolute error ('mae') or mean squared error ('mse') would be other choices.
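
To make the shuffling idea concrete, here is a rough, hand-rolled sketch of a single permutation round (illustration only; FeatureImp repeats the shuffling and does the bookkeeping for us).

## Sketch: one round of permutation importance by hand
rmse <- function(actual, predicted) sqrt(mean((actual - predicted)^2))

perm_importance_once <- function(model, data, y, feature) {
  base_error <- rmse(y, predict(model, data))        # error on the original data
  shuffled <- data
  shuffled[[feature]] <- sample(shuffled[[feature]]) # break the feature-target link
  perm_error <- rmse(y, predict(model, shuffled))    # error after shuffling
  perm_error / base_error                            # ratio > 1: the feature mattered
}

perm_importance_once(rf, x, Boston$medv, "lstat")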

Once we create a new object of FeatureImp, the importance is automatically computed. We can call the plot() function of the object or look at the results in a data.frame.

## Store the features in a FeatureImp object
## loss argument specifies the performance measure for error
imp = FeatureImp$new(predictor = predictor,
                     loss = "rmse")
## Visualize the importance using ggplot2
library(ggplot2)

##Plot
plot(imp)

## View the feature importance results with their 5% and 95% quantiles
imp$results
##    feature importance.05 importance importance.95 permutation.error
## 1    lstat      4.043260   4.089973      4.244381          6.876696
## 2       rm      2.767377   2.784413      2.951286          4.681585
## 3      dis      1.742154   1.792560      1.803989          3.013929
## 4  ptratio      1.643280   1.712529      1.774571          2.879368
## 5      nox      1.487913   1.572376      1.603722          2.643722
## 6     crim      1.503800   1.527006      1.593206          2.567439
## 7    indus      1.295132   1.320279      1.347632          2.219858
## 8      age      1.216319   1.236589      1.269378          2.079144
## 9      tax      1.171525   1.193611      1.213924          2.006883
## 10   black      1.154155   1.174407      1.195662          1.974595
## 11     rad      1.032841   1.044418      1.055014          1.756037
## 12    chas      1.015911   1.033316      1.035738          1.737371
## 13      zn      1.021592   1.029905      1.037617          1.731636

Feature Effects

Besides knowing which features were important, we are interested in how the features influence the predicted outcome. The FeatureEffect class implements accumulated local effect plots, partial dependence plots and individual conditional expectation curves. The following plot shows the accumulated local effects (ALE) for the feature 'lstat'. ALE shows how the prediction changes locally when the feature is varied. The marks on the x-axis indicate the distribution of the 'lstat' feature, showing how relevant a region is for interpretation (few or no points mean that we should not over-interpret this region).

## Reminder: the rug marks let us judge which areas of the curve are better supported by data than others
## grid.size sets the number of grid intervals (quantile-based for ALE) at which the effect is evaluated

ale = FeatureEffect$new(predictor = predictor,
                        feature = "lstat",
                        grid.size = 10)

ale$plot()
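
FeatureEffect defaults to ALE. If we want the classical partial dependence plot plus ICE curves for the same feature, the method argument can be switched; this is a sketch assuming the "pdp+ice" option described in the iml documentation.

## Sketch: PDP with ICE curves for lstat instead of ALE
pdp_ice = FeatureEffect$new(predictor = predictor,
                            feature = "lstat",
                            method = "pdp+ice",
                            grid.size = 10)
pdp_ice$plot()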

If we want to compute the accumulated local effects for another feature, we can simply reset the feature:

##Reset the feature to number of rooms per dwelling (rm)
ale$set.feature("rm")

##Replot
ale$plot()

Measure Interactions

We can also measure how strongly features interact with each other. The interaction measure captures how much of the variance of f(x) is explained by the interaction. The measure is between 0 (no interaction) and 1 (= 100% of the variance of f(x) is due to interactions). For each feature, we measure how much it interacts with any other feature:

##Set up the interactions wrapper
##Play around with grid.size

interact = Interaction$new(predictor = predictor,
                           grid.size = 15)
## 
## Attaching package: 'withr'
## The following objects are masked from 'package:rlang':
## 
##     local_options, with_options
## The following object is masked from 'package:tools':
## 
##     makevars_user
##Plot the features to see how they interact with any other feature in the data
plot(interact)
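
The numbers behind the plot live in the object's results field. A small sketch to sort them, assuming the strength column is called .interaction as in the iml documentation:

## Sketch: inspect and sort the interaction strengths
int_res <- interact$results
head(int_res[order(-int_res$.interaction), ])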

We can also specify a feature and measure all its 2-way interactions with all other features:

## See how the features interact with crim

interact <- Interaction$new(predictor = predictor,
                            feature = "crim",
                            grid.size = 15)

# Visualize
plot(interact)

You can also plot the feature effects for all features at once:

##Plot all of the feature effects at once

effs = FeatureEffects$new(predictor = predictor,
                          grid.size = 10)

# Visualize
plot(effs)
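
If thirteen panels are too busy, FeatureEffects accepts (to the best of my knowledge) a features argument to restrict the overview to a subset:

## Sketch: plot effects for a few selected features only
effs_sub = FeatureEffects$new(predictor = predictor,
                              features = c("lstat", "rm", "ptratio"),
                              grid.size = 10)
plot(effs_sub)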

Surrogate Model

Another way to make the models more interpretable is to replace the black box with a simpler model - a decision tree. We take the predictions of the black box model (in our case the random forest) and train a decision tree on the original features and the predicted outcome. The plot shows the terminal nodes of the fitted tree. The maxdepth parameter controls how deep the tree can grow and therefore how interpretable it is.

tree = TreeSurrogate$new(predictor = predictor,
                         maxdepth = 2)
## Loading required package: partykit
## Loading required package: libcoin
## Loading required package: mvtnorm
# Visualize
plot(tree)

# Use this code to plot the fitted surrogate decision tree itself (a partykit object)
plot(tree$tree)

We can use the tree to make predictions:

head(tree$predict(Boston))
## Warning in self$predictor$data$match_cols(data.frame(newdata)): Dropping
## additional columns: medv
##     .y.hat
## 1 28.78541
## 2 21.91311
## 3 28.78541
## 4 28.78541
## 5 28.78541
## 6 28.78541
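
To see how faithful this shallow tree is to the random forest, we can compare its fit and predictions against the black box directly. This is a sketch; r.squared is the approximation measure described in the TreeSurrogate documentation.

## Sketch: how closely does the surrogate track the random forest?
tree$r.squared                     # R squared of the surrogate vs. the black box
cor(tree$predict(Boston)$.y.hat,   # surrogate predictions
    predict(rf, x))                # random forest predictions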

PDPs with iml

# Testing the iml package to draw pdps

#lstat variable as an example
#store it in an object
rf.lstat = Partial$new(predictor = predictor,
                        feature = "lstat",
                        aggregation = "pdp",
                        ice = TRUE)
## Warning: 'Partial' is deprecated.
## Use 'FeatureEffect' instead.
## See help("Deprecated")
# center() anchors the predictions at a chosen feature value
# here we center at the minimum Boston lstat, so each curve shows the deviation of the prediction from that starting point
rf.lstat$center(min(Boston$lstat))

# plot
p1 = plot(rf.lstat) + ggtitle("Random Forest")

p1

The yellow line is the aggregate of the ICE curves, i.e. the partial dependence curve for lstat (the PDP is the average of the individual ICE curves).
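
Since Partial is deprecated, the same centered ICE/PDP figure can be rebuilt with FeatureEffect, as the warning above suggests. A sketch, assuming FeatureEffect's method and center.at arguments:

## Sketch: the FeatureEffect replacement for the deprecated Partial class
rf.lstat2 = FeatureEffect$new(predictor = predictor,
                              feature = "lstat",
                              method = "pdp+ice",
                              center.at = min(Boston$lstat))
plot(rf.lstat2) + ggtitle("Random Forest (FeatureEffect)")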

Explain Single Predictions with a local model

A global surrogate model can improve our understanding of the global model behaviour. We can also fit a model locally to understand an individual prediction better. The local model fitted by LocalModel is a linear regression model, and the data points are weighted by how close they are to the data point for which we want to explain the prediction.

##Looking at the first row of our dataframe
lime.explain = LocalModel$new(predictor = predictor,
                              x.interest = x[1,])
## Loading required package: glmnet
## Loading required package: Matrix
## Loaded glmnet 4.1-7
## Loading required package: gower
# View results
lime.explain$results
##               beta x.recoded    effect x.original feature feature.value
## rm       4.5149417     6.575 29.685741      6.575      rm      rm=6.575
## ptratio -0.5696891    15.300 -8.716243       15.3 ptratio  ptratio=15.3
## lstat   -0.4592951     4.980 -2.287290       4.98   lstat    lstat=4.98
plot(lime.explain)
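
Only three features appear in the table above because LocalModel keeps k features in the local regression (k = 3 by default, if I read the documentation correctly). Widening the local model is a one-argument change:

## Sketch: let the local surrogate use five features instead of three
lime.explain5 = LocalModel$new(predictor = predictor,
                               x.interest = x[1, ],
                               k = 5)
lime.explain5$results
plot(lime.explain5)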

Explain single predictions with game theory

An alternative for explaining individual predictions is a method from coalitional game theory named Shapley value. Assume that for one data point, the feature values play a game together, in which they get the prediction as a payout. The Shapley value tells us how to fairly distribute the payout among the feature values.

shapley = Shapley$new(predictor = predictor,
                      x.interest = x[1,],
                      sample.size = 50)

# Visualize
shapley$plot()

##Reuse the shapley object to explain other data points
shapley$explain(x.interest = x[2,])

#visualize
shapley$plot()

The results in data.frame form can be extracted like this:

results = shapley$results

head(results)
##   feature         phi    phi.var feature.value
## 1    crim  0.11943195 1.09943747  crim=0.02731
## 2      zn -0.03799000 0.03052437          zn=0
## 3   indus -0.37371333 0.86417831    indus=7.07
## 4    chas -0.02722333 0.01467811        chas=0
## 5     nox  0.17806333 1.23947911     nox=0.469
## 6      rm  0.18101333 0.59793068      rm=6.421
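
A useful sanity check (an illustration, not part of the original example): the phi values should roughly add up to the difference between this point's prediction and the average prediction, because that difference is the payout the Shapley values distribute.

## Sketch: Shapley values roughly sum to (prediction - average prediction)
sum(results$phi)
predictor$predict(x[2, ])[[1]] - mean(predictor$predict(x)[[1]])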