# 1 Counterfactuals ----

library("counterfactuals")
library("iml")
library("randomForest")
library("mlr3")
library("mlr3learners")
library("data.table")

# Load the `german` data set from the rchallenge package and keep only the
# columns used in this example; `credit_risk` is the column modelled below.
data(german, package = "rchallenge")
credit <- german[
  ,
  c(
    "duration",
    "amount",
    "purpose",
    "age",
    "employment_duration",
    "housing",
    "number_credits",
    "credit_risk"
  )
]

# Peek at the first rows of the reduced data set.
head(credit, n = 6L)
##   duration amount    purpose age employment_duration  housing number_credits
## 1       18   1049 car (used)  21              < 1 yr for free              1
## 2        9   2799     others  36    1 <= ... < 4 yrs for free            2-3
## 3       12    841 retraining  23    4 <= ... < 7 yrs for free              1
## 4       12   2122     others  39    1 <= ... < 4 yrs for free            2-3
## 5       12   2171     others  38    1 <= ... < 4 yrs     rent            2-3
## 6       10   2241     others  48              < 1 yr for free            2-3
##   credit_risk
## 1        good
## 2        good
## 3        good
## 4        good
## 5        good
## 6        good

# Rows and columns of the reduced data set.
dim(credit)
## [1] 1000    8
# Seed the RNG before the first stochastic step so the run is repeatable.
set.seed(20210816)

# Fit the forest on every row except row 998, which is held out as the
# observation to explain.
rf <- randomForest::randomForest(credit_risk ~ ., data = credit[-998L, ])

# Wrap the forest so that it returns class probabilities.
# NOTE(review): no `data` is passed here, so iml presumably recovers the
# training data from `rf` itself -- confirm before changing the fit call above.
predictor <- iml::Predictor$new(model = rf, type = "prob")

# The held-out observation and the model's prediction for it.
x_interest <- credit[998L, ]
predictor$predict(x_interest)
##     bad  good
## 1 0.618 0.382
# Set up the MOC counterfactual search on top of the wrapped model.
# `age` and `employment_duration` are passed as `fixed_features`
# (presumably: the search must leave them unchanged -- confirm in the docs).
moc_classif <- MOCClassif$new(
  predictor,
  epsilon = 0,
  fixed_features = c("age", "employment_duration"),
  quiet = TRUE,
  termination_crit = "genstag",
  n_generations = 10L
)

# Search for counterfactuals of `x_interest` whose predicted probability for
# class "good" lies in [0.6, 1].
cfactuals <- moc_classif$find_counterfactuals(
  x_interest,
  desired_class = "good",
  desired_prob = c(0.6, 1)
)

# The search result is an R6 object of class "Counterfactuals".
class(cfactuals)
## [1] "Counterfactuals" "R6"

# Summary of the result: number of counterfactuals, target, first rows.
print(cfactuals)
## 82 Counterfactual(s) 
##  
## Desired class: good 
## Desired predicted probability range: [0.6, 1] 
##  
## Head: 
##    duration amount purpose age employment_duration housing number_credits
## 1:       21   7460  others  30            >= 7 yrs     own              1
## 2:       21   7054  others  30            >= 7 yrs     own              1
## 3:       21   6435  others  30            >= 7 yrs     own              1

# Model predictions for the first three counterfactuals.
head(cfactuals$predict(), n = 3L)
##     bad  good
## 1 0.322 0.678
## 2 0.318 0.682
## 3 0.296 0.704

# Quality measures for the first three counterfactuals; with
# `show_diff = TRUE` the feature columns show differences (NA in the output
# below where nothing is reported for a feature).
head(
  cfactuals$evaluate(
    show_diff = TRUE,
    measures = c(
      "dist_x_interest",
      "dist_target",
      "no_changed",
      "dist_train"
    )
  ),
  n = 3L
)
##    duration amount purpose age employment_duration housing number_credits
## 1:       NA  -5220    <NA>  NA                <NA>    <NA>           <NA>
## 2:       NA  -5626    <NA>  NA                <NA>    <NA>           <NA>
## 3:       NA  -6245    <NA>  NA                <NA>    <NA>           <NA>
##    dist_x_interest no_changed dist_train dist_target
## 1:      0.04103193          1 0.04215022           0
## 2:      0.04422330          1 0.03895885           0
## 3:      0.04908897          1 0.03409318           0
# Filter the result object in place: the return value is not assigned, yet
# the row count drops from the 82 printed earlier to 40.
cfactuals$subset_to_valid()
nrow(cfactuals$data)
## [1] 40

# Frequency of feature changes across the remaining counterfactuals.
cfactuals$plot_freq_of_feature_changes(subset_zero = TRUE)

# Parallel plot over the features named in the change-frequency result.
cfactuals$plot_parallel(
  feature_names = names(cfactuals$get_freq_of_feature_changes()),
  digits_min_max = 2L
)

# Surface plot over the two features `duration` and `amount`.
cfactuals$plot_surface(feature_names = c("duration", "amount"))