# counterfactuals
library("counterfactuals")
library("iml")
library("randomForest")
library("mlr3")
library("mlr3learners")
library("data.table")
# Load the `german` data set from the rchallenge package and keep seven
# predictor columns plus the `credit_risk` target for this example.
data(german, package = "rchallenge")
credit = german[, c("duration", "amount", "purpose", "age",
  "employment_duration", "housing", "number_credits",
  "credit_risk")]
# Preview the first rows of the reduced data set.
head(credit)
## duration amount purpose age employment_duration housing number_credits
## 1 18 1049 car (used) 21 < 1 yr for free 1
## 2 9 2799 others 36 1 <= ... < 4 yrs for free 2-3
## 3 12 841 retraining 23 4 <= ... < 7 yrs for free 1
## 4 12 2122 others 39 1 <= ... < 4 yrs for free 2-3
## 5 12 2171 others 38 1 <= ... < 4 yrs rent 2-3
## 6 10 2241 others 48 < 1 yr for free 2-3
## credit_risk
## 1 good
## 2 good
## 3 good
## 4 good
## 5 good
## 6 good
# Report the size of the reduced data set as (rows, columns); the eight
# selected columns correspond to the second value in the recorded output.
dim(credit)
## [1] 1000 8
# Row 998 is held out as the observation to explain; the forest below is
# fitted on every other row of `credit`.
x_interest = credit[998L, ]

# Fix the RNG seed before fitting so the random forest is reproducible.
set.seed(20210816)
rf = randomForest::randomForest(
  credit_risk ~ .,
  data = credit[-998L, ]
)

# Wrap the fitted forest in an iml Predictor requesting "prob" predictions,
# then look at the prediction for the held-out observation.
predictor = iml::Predictor$new(model = rf, type = "prob")
predictor$predict(x_interest)
## bad good
## 1 0.618 0.382
# Set up the MOC counterfactual search on `predictor`, listing `age` and
# `employment_duration` as fixed features and running without console
# output (quiet = TRUE).
moc_classif = MOCClassif$new(
  predictor, epsilon = 0, fixed_features = c("age", "employment_duration"),
  quiet = TRUE, termination_crit = "genstag", n_generations = 10L
)
# Search for counterfactuals of `x_interest` with desired class "good" and a
# desired predicted probability range of [0.6, 1].
cfactuals = moc_classif$find_counterfactuals(
  x_interest, desired_class = "good", desired_prob = c(0.6, 1)
)
# The search result is an R6 object of class "Counterfactuals".
class(cfactuals)
## [1] "Counterfactuals" "R6"
# Print the result summary (number of counterfactuals, desired class and
# probability range, and the head of the counterfactual data).
print(cfactuals)
## 82 Counterfactual(s)
##
## Desired class: good
## Desired predicted probability range: [0.6, 1]
##
## Head:
## duration amount purpose age employment_duration housing number_credits
## 1: 21 7460 others 30 >= 7 yrs own 1
## 2: 21 7054 others 30 >= 7 yrs own 1
## 3: 21 6435 others 30 >= 7 yrs own 1
# Predicted class probabilities for the first three counterfactuals.
head(cfactuals$predict(), n = 3L)
## bad good
## 1 0.322 0.678
## 2 0.318 0.682
## 3 0.296 0.704

# Evaluate the counterfactuals on four measures with show_diff = TRUE and
# display the first three rows of the result.
head(
  cfactuals$evaluate(
    show_diff = TRUE,
    measures = c("dist_x_interest", "dist_target", "no_changed", "dist_train")
  ),
  n = 3L
)
## duration amount purpose age employment_duration housing number_credits
## 1: NA -5220 <NA> NA <NA> <NA> <NA>
## 2: NA -5626 <NA> NA <NA> <NA> <NA>
## 3: NA -6245 <NA> NA <NA> <NA> <NA>
## dist_x_interest no_changed dist_train dist_target
## 1: 0.04103193 1 0.04215022 0
## 2: 0.04422330 1 0.03895885 0
## 3: 0.04908897 1 0.03409318 0
# Restrict the result to the valid counterfactuals (via subset_to_valid())
# and count the rows that remain in the `data` field.
cfactuals$subset_to_valid()
nrow(cfactuals$data)
## [1] 40

# Plot the frequency of feature changes (called with subset_zero = TRUE).
cfactuals$plot_freq_of_feature_changes(subset_zero = TRUE)

# Parallel plot over the features named by get_freq_of_feature_changes().
cfactuals$plot_parallel(
  feature_names = names(cfactuals$get_freq_of_feature_changes()),
  digits_min_max = 2L
)

# Surface plot over the two features `duration` and `amount`.
cfactuals$plot_surface(feature_names = c("duration", "amount"))
