The theory and implementation of SHAP (SHapley Additive exPlanations)

Theory

SHapley Additive exPlanations (SHAP) is a method for understanding how a machine-learning model arrived at a particular decision. For example, if you build an AI system for loan screening, you need to be able to explain why the model reached its conclusion, because a rejected applicant will be desperate to know why they failed. The idea of SHAP is to allocate credit for a prediction among the explanatory variables. The intuition comes from cooperative game theory: a player can choose to join or not join a game, and if every player except player A has joined, the difference between the score of the game with all players and the score without player A is credited to player A. SHAP applies a similar process to a model to help us understand its behavior.

Additive Feature Attribution Method

SHAP computes credit by calculating the difference between the average prediction and an individual prediction. Consider a model \(\hat{f}(X)\) with explanatory variables \(X = (X_1,...,X_J)\), so that the explanatory variables for observation \(i\) are \(x_i=(x_{i,1},...,x_{i,J})\). In SHAP, the credit is broken down into an additive form.

\[ \hat{f}(x_i) - \mathbb{E}\big[\hat{f}(X)\big] = \sum_{j = 1}^{J} \phi_{i,j} \]

Breaking the difference down into a summation like this is called an Additive Feature Attribution Method, and SHAP is one such method. Often, by setting \(\phi_0 = \mathbb{E}\big[\hat{f}(X)\big]\), SHAP is written as

\[ \hat{f}(x_i) = \phi_{0} + \sum_{j = 1}^{J} \phi_{i,j} \]

For example, let’s say our goal is to predict annual salary from education, job title, job position, and English skill. The prediction formula would be

\[ \hat{f}(x_i) = \phi_0 + \phi_{i,\mathrm{education}} + \phi_{i,\mathrm{title}} + \phi_{i,\mathrm{position}} + \phi_{i,\mathrm{english}} \]

In this particular observation, we have someone who holds a Master’s degree, works as a Data Scientist, is a Manager, and does not speak English. The predicted salary is 10 million yen.

According to our model, if he improves his English, he is likely to be predicted a higher salary.
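To make the additive decomposition concrete, here is a purely hypothetical breakdown of this prediction (the numbers are illustrative, not from a fitted model). Note the negative English term, which is what leaves room for a higher prediction if his English improves:

\[ \underbrace{10{,}000{,}000}_{\hat{f}(x_i)} = \underbrace{6{,}000{,}000}_{\phi_0} + \underbrace{2{,}000{,}000}_{\phi_{i,\mathrm{education}}} + \underbrace{1{,}500{,}000}_{\phi_{i,\mathrm{title}}} + \underbrace{1{,}000{,}000}_{\phi_{i,\mathrm{position}}} - \underbrace{500{,}000}_{\phi_{i,\mathrm{english}}} \]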

Shapley Value

“The Shapley value is a solution concept used in game theory that involves fairly distributing both gains and costs to several actors working in a coalition. Game theory is when two or more players or factors are involved in a strategy to achieve a desired outcome or payoff.” (Will Kenton, Investopedia)

Let’s say there are three players A, B, and C in some cooperative game, and we want to calculate the contribution of each player. Here is a table of final scores for each combination of players in the game.

Players    Score
A              6
B              4
C              2
A, B          20
A, C          15
B, C          10
A, B, C       24

To calculate each player's contribution, we use the marginal contribution: the difference between the game score with and without that player. Because this difference depends on who else is already playing, we will average it over all situations in which the player can join. For example, to calculate the marginal contributions of player A:

Game 1
Player A plays the game alone. A gets a score of 6. None -> A

\[ Contribution_A = 6 - 0 = 6 \]

Game 2
Player B plays the game alone (4), then A joins the game. A and B get a score of 20. B -> A,B
\[ Contribution_A = 20 - 4 = 16 \]

Game 3
Player C plays the game alone (2), then A joins the game. A and C get a score of 15. C -> A,C
\[ Contribution_A = 15 - 2 = 13 \]

Game 4
Players B and C play the game (10), then A joins the game. A, B, and C get a score of 24. B,C -> A,B,C

\[ Contribution_A = 24 - 10 = 14 \]

Pay attention to the fact that the contribution of A depends on who is already in the game. Therefore, we need to average the contributions over all participation orders.

Here is a table of marginal contributions with all possible participation orders.

Participation Order    Contribution A    Contribution B    Contribution C
A-B-C                   6                14                 4
A-C-B                   6                 9                 9
B-A-C                  16                 4                 4
B-C-A                  14                 4                 6
C-A-B                  13                 9                 2
C-B-A                  14                 8                 2

By taking the average of these marginal contributions, we can calculate the Shapley value:

\[ \phi_{A} = \dfrac{6+6+16+14+13+14}{6}=11.5 \]
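Applying the same averaging to the other columns of the table gives

\[ \phi_{B} = \dfrac{14+9+4+4+9+8}{6}=8 \qquad \phi_{C} = \dfrac{4+9+4+6+2+2}{6}=4.5 \]

Note that \(\phi_A + \phi_B + \phi_C = 11.5 + 8 + 4.5 = 24\), the score of the full game: the Shapley values distribute the total score exactly, which is the same additivity that lets SHAP decompose a prediction into per-variable credits.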

Formula for Shapley Value

\[ \phi_{j} = \dfrac{1}{|\mathbf{J}|!} \sum_{\mathbf{S}\subseteq\mathbf{J}\setminus\{j\}} |\mathbf{S}|!\,(|\mathbf{J}|-|\mathbf{S}|-1)!\,\big(v(\mathbf{S}\cup\{j\})-v(\mathbf{S})\big) \]

\(\mathbf{J} = \{1,...,J\}\) is a set of players.
- In the case above, \(\mathbf{J} = \{A,B,C\}\)

\(|\mathbf{J}|\) is the number of components in the set \(\mathbf{J}\).
- In the case above, \(|\mathbf{J}| =3\)

\(\mathbf{S}\) is any subset of \(\mathbf{J}\) that does not contain player \(j\).
- Thinking about the Shapley value of A, \(\mathbf{S}\) ranges over \(\emptyset,\{B\},\{C\},\{B,C\}\)

\(|\mathbf{S}|\) is the number of components in the set \(\mathbf{S}\).
- In the case \(\emptyset,\{B\},\{C\},\{B,C\}\), \(|\mathbf{S}|\) would be \(0,1,1,2\)

\(v\) is a game score.
- If A and B are in the game, the score would be notated as \(v(\{A,B\})\)

\(v(\mathbf{S}\cup\{j\}) - v(\mathbf{S})\) is the difference between the scores when player \(j\) is in and out of the game.
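Plugging the three-player game into this formula for player A recovers the value obtained by averaging over participation orders:

\[ \phi_A = \dfrac{1}{3!}\Big[0!\,2!\,(6-0) + 1!\,1!\,(20-4) + 1!\,1!\,(15-2) + 2!\,0!\,(24-10)\Big] = \dfrac{12+16+13+28}{6} = 11.5 \]

For readers who prefer code, here is a minimal brute-force sketch of the formula in base R. The names v, score, and shapley are ours for illustration, not from any package:

# Coalition scores from the table above, keyed by sorted player names
v = list("A" = 6, "B" = 4, "C" = 2,
         "A,B" = 20, "A,C" = 15, "B,C" = 10, "A,B,C" = 24)

# v(S): look up the score of coalition S; the empty coalition scores 0
score = function(S) {
  if (length(S) == 0) return(0)
  v[[paste(sort(S), collapse = ",")]]
}

# Shapley value of player j: weighted marginal contributions over all subsets S
shapley = function(j, players = c("A", "B", "C")) {
  others = setdiff(players, j)
  n = length(players)
  total = 0
  for (k in 0:length(others)) {
    # All subsets S of the other players with |S| = k
    subsets = if (k == 0) list(character(0)) else combn(others, k, simplify = FALSE)
    for (S in subsets) {
      weight = factorial(k) * factorial(n - k - 1) / factorial(n)
      total = total + weight * (score(c(S, j)) - score(S))
    }
  }
  total
}

sapply(c("A", "B", "C"), shapley)
##    A    B    C 
## 11.5  8.0  4.5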

Get Dataset

# Set up
library(mlbench)
library(tidymodels)
library(DALEX)
library(ranger)
library(Rcpp)
library(corrplot)
library(ggplot2)
library(gridExtra)
library(SHAPforxgboost)
library(xgboost)

data("BostonHousing")
df = BostonHousing %>% select(where(is.numeric))

Overview of the Dataset

Here is an overview of the dataset.

head(df)
##      crim zn indus   nox    rm  age    dis rad tax ptratio      b lstat medv
## 1 0.00632 18  2.31 0.538 6.575 65.2 4.0900   1 296    15.3 396.90  4.98 24.0
## 2 0.02731  0  7.07 0.469 6.421 78.9 4.9671   2 242    17.8 396.90  9.14 21.6
## 3 0.02729  0  7.07 0.469 7.185 61.1 4.9671   2 242    17.8 392.83  4.03 34.7
## 4 0.03237  0  2.18 0.458 6.998 45.8 6.0622   3 222    18.7 394.63  2.94 33.4
## 5 0.06905  0  2.18 0.458 7.147 54.2 6.0622   3 222    18.7 396.90  5.33 36.2
## 6 0.02985  0  2.18 0.458 6.430 58.7 6.0622   3 222    18.7 394.12  5.21 28.7

medv, the median home value in thousands of dollars, is our response variable; this is what we predict.

hist(df$medv, breaks = 20, main = 'Histogram of medv', xlab = 'medv ($ in 1,000s)')

Build a Model

We won’t cover model building in detail in this article; I used an XGBoost model.

split = initial_split(df, prop = 0.8)
train = training(split)
test = testing(split)

params = list(objective = "reg:squarederror")

mod = xgboost::xgboost(data = as.matrix(train %>% select(-medv)), 
                       label = train$medv, 
                       params = params, nrounds = 200,
                       verbose = FALSE, 
                       early_stopping_rounds = 8)

fit = predict(mod, as.matrix(test %>% select(-medv)))

Predict medv

result = test %>% 
  select(medv) %>% 
  mutate(pred = predict(mod, as.matrix(test %>% select(-medv))))

metrics = metric_set(rmse, rsq)

result %>% 
  metrics(truth = medv, estimate = pred)
## # A tibble: 2 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 rmse    standard       2.72 
## 2 rsq     standard       0.898

Interpret SHAP

Use the function explain to create an explainer object that helps us to interpret the model.

explainer = DALEX::explain(
  model = mod,
  data = as.matrix(test %>% select(-medv)),
  y = test$medv,
  predict_function = function(model, data) predict(model, data)
)
## Preparation of a new explainer is initiated
##   -> model label       :  xgb.Booster  (  default  )
##   -> data              :  102  rows  12  cols 
##   -> target variable   :  102  values 
##   -> predict function  :  function(model, data) predict(model, data) 
##   -> model_info        :  package Model of class: xgb.Booster package unrecognized , ver. Unknown , task regression (  default  ) 
##   -> predicted values  :  numerical, min =  6.020366 , mean =  22.38712 , max =  46.61555  
##   -> residual function :  difference between y and yhat (  default  )
##   -> residuals         :  numerical, min =  -8.354686 , mean =  0.1805289 , max =  10.60421  
##   A new explainer has been created!

Use the predict_parts function to get a SHAP plot.

  • new_observation: the observation for which you want to calculate SHAP values (in this case, the 1st)
  • type: the method to use ("shap" here)
  • B: the number of random participation orders sampled to estimate the SHAP values

In this plot, we are explaining the 1st observation of our test set, whose explanatory variables and response variable are shown below.

test %>% dplyr::slice(1)
##       crim   zn indus   nox    rm age    dis rad tax ptratio     b lstat medv
## 13 0.09378 12.5  7.87 0.524 5.889  39 5.4509   5 311    15.2 390.5 15.71 21.7
shap = predict_parts(
  explainer,
  new_observation = as.matrix(test %>% select(-medv) %>% dplyr::slice(1)),
  type = "shap",
  B = 25)
plot(shap)

Looking at the plot, we can tell that lstat makes the largest contribution to this prediction; it pushes medv in the positive direction, and the next largest contributor is rm.

FYI

Method                                     Function
Permutation Feature Importance (PFI)       model_parts()
Partial Dependence (PD)                    model_profile()
Individual Conditional Expectation (ICE)   predict_profile()
SHAP                                       predict_parts()
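As a quick sketch of how the other functions are called (assuming the explainer created above; the choice of lstat and the first observation are just examples):

pfi = model_parts(explainer)                         # Permutation Feature Importance
pd = model_profile(explainer, variables = "lstat")   # Partial Dependence for lstat
ice = predict_profile(explainer,                     # ICE for one observation
                      new_observation = as.matrix(test %>% select(-medv) %>% dplyr::slice(1)))
plot(pd)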

Global Feature Importance

So far, we have discussed instance-level explanations. Now, let’s briefly touch on dataset-level explanations. While we won’t go into detail on this page, the core idea is to analyze the explanatory variables at the global level. Some methods, such as running SHAP for the entire dataset, can achieve this, but they are computationally expensive and time-consuming.

The key concept is to assess how a model’s performance changes when the influence of a specific explanatory variable, or a group of variables, is removed. This removal is achieved through perturbations, such as resampling from the empirical distribution or permuting the variable’s values.
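Here is a minimal manual sketch of the permutation idea for a single variable (lstat), using the objects defined above. Because the permutation is random, the exact value varies from run to run:

X = as.matrix(test %>% select(-medv))
base_rmse = sqrt(mean((test$medv - predict(mod, X))^2))

# Permuting lstat breaks its relationship with medv while keeping its distribution
Xp = X
Xp[, "lstat"] = sample(Xp[, "lstat"])
perm_rmse = sqrt(mean((test$medv - predict(mod, Xp))^2))

perm_rmse - base_rmse   # importance of lstat: the increase in RMSE

DALEX automates this computation, averaged over many permutations, via model_parts: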

loss_root_mean_square(observed = test$medv, 
                      predicted = predict(mod, as.matrix(test %>% select(-medv))))
# equals the test RMSE computed earlier (about 2.72)
vip.50 = model_parts(explainer = explainer, 
                     loss_function = loss_root_mean_square,
                     B = 50,
                     type = "difference")

plot(vip.50) +
  ggtitle("Mean variable-importance over 50 permutations", "") 

Just as with that single observation, lstat and rm are the most important features across the whole dataset.

Conclusion

SHAP is an effective method for observing variable importance for each individual observation. At the same time, SHAP can also help us get an overview of the whole model. However, SHAP cannot explain the shape of the relationship between an explanatory variable and the response variable, which can be explored with ICE.