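Let's look at an example from the DALEX package: a classification model for the survival problem on the Titanic dataset. Here we use the titanic_imputed dataset available in the DALEX package. Note that this data was copied from the "stablelearner" package and modified for practical use.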
library("DALEX")
head(titanic_imputed)
We use a random forest model.
# prepare model
library("ranger")
model_titanic_rf <- ranger(survived ~ gender + age + class + embarked +
fare + sibsp + parch,
data = titanic_imputed, probability = TRUE)
model_titanic_rf
Ranger result
Call:
ranger(survived ~ gender + age + class + embarked + fare + sibsp + parch, data = titanic_imputed, probability = TRUE)
Type: Probability estimation
Number of trees: 500
Sample size: 2207
Number of independent variables: 7
Mtry: 2
Target node size: 10
Variable importance mode: none
Splitrule: gini
OOB prediction error (Brier s.): 0.1420171
Create a DALEX explainer for the random forest model.
library("DALEX")
explain_titanic_rf <- explain(model_titanic_rf,
data = titanic_imputed[,-8],
y = titanic_imputed[,8],
label = "Random Forest")
Preparation of a new explainer is initiated
-> model label : Random Forest
-> data : 2207 rows 7 cols
-> target variable : 2207 values
-> predict function : yhat.ranger will be used ( default )
-> predicted values : No value for predict function target column. ( default )
-> model_info : package ranger , ver. 0.12.1 , task classification ( default )
-> predicted values : numerical, min = 0.01407589 , mean = 0.3216433 , max = 0.9894214
-> residual function : difference between y and yhat ( default )
-> residuals : numerical, min = -0.7870102 , mean = 0.0005134651 , max = 0.8832034
A new explainer has been created!
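The printout above lists what an explainer bundles: the model, the predictor data, the target, a predict function, and residuals computed as the difference between y and yhat. The following is a minimal base-R sketch of that idea, not DALEX's implementation. It assumes a hypothetical helper make_explainer(), synthetic data, and a logistic regression standing in for the random forest.

```r
# Hypothetical helper, NOT DALEX::explain(): it only bundles the pieces an
# explainer needs and computes residuals as y - yhat.
make_explainer <- function(model, data, y, predict_function, label) {
  y_hat <- predict_function(model, data)
  list(model = model, data = data, y = y, label = label,
       predict_function = predict_function,
       y_hat = y_hat,
       residuals = y - y_hat)   # "difference between y and yhat"
}

# Synthetic binary data (assumption: stands in for titanic_imputed)
set.seed(1)
n <- 200
df <- data.frame(x1 = rnorm(n), x2 = runif(n))
df$y <- rbinom(n, 1, plogis(1.5 * df$x1 - df$x2))

# Logistic regression in place of the random forest
fit <- glm(y ~ x1 + x2, data = df, family = binomial)

expl <- make_explainer(
  fit,
  data = df[, c("x1", "x2")],
  y = df$y,
  predict_function = function(m, d) predict(m, newdata = d, type = "response"),
  label = "Logistic regression"
)

summary(expl$y_hat)      # predicted probabilities, cf. "predicted values"
summary(expl$residuals)  # y - yhat, cf. "residuals"
```

For a logistic regression with an intercept, the mean residual is essentially zero. The random forest's mean residual reported above (about 0.0005) is small but not guaranteed to vanish.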
# draw one random passenger, dropping the target column (survived)
passanger <- titanic_imputed[sample(nrow(titanic_imputed), 1), -8]
passanger
Use the feature_importance() explainer to present the importance of particular features. Feature importance shows the importance of all of the model's variables. Because it is a global explanation technique, no passenger needs to be specified. Note that with type = "difference" the dropout losses are normalized, so they all start at 0.
library("ingredients")
fi_rf <- feature_importance(explain_titanic_rf)
head(fi_rf)
plot(fi_rf)
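Feature importance of this kind is based on permutation. Each variable is shuffled in turn, and the resulting increase in the model's loss is recorded. The sketch below shows the idea in base R, along with the "raw", "ratio" and "difference" normalizations. It assumes a hypothetical helper permutation_importance(), RMSE as the loss, and synthetic regression data. It mirrors the idea behind feature_importance(), not its exact code.

```r
# Hypothetical sketch of permutation feature importance (NOT ingredients code)
set.seed(42)
n <- 500
df <- data.frame(x1 = rnorm(n), x2 = rnorm(n), noise = rnorm(n))
df$y <- 3 * df$x1 + 1 * df$x2 + rnorm(n, sd = 0.5)   # noise has no true effect
fit <- lm(y ~ x1 + x2 + noise, data = df)

rmse <- function(y, yhat) sqrt(mean((y - yhat)^2))

permutation_importance <- function(model, data, y, vars, B = 10,
                                   type = c("raw", "ratio", "difference")) {
  type <- match.arg(type)
  full_loss <- rmse(y, predict(model, newdata = data))
  # Average loss over B random permutations of each variable
  dropout <- sapply(vars, function(v) {
    mean(replicate(B, {
      perturbed <- data
      perturbed[[v]] <- sample(perturbed[[v]])   # break the link with y
      rmse(y, predict(model, newdata = perturbed))
    }))
  })
  losses <- c(`_full_model_` = full_loss, dropout)
  switch(type,
         raw = losses,                     # plain losses
         ratio = losses / full_loss,       # relative to the full model
         difference = losses - full_loss)  # full model starts at 0
}

vars <- c("x1", "x2", "noise")
imp_raw  <- permutation_importance(fit, df, df$y, vars)
imp_diff <- permutation_importance(fit, df, df$y, vars, type = "difference")
sort(imp_diff, decreasing = TRUE)
```

With type = "difference", the full model's entry is exactly 0. Every other bar then shows how much worse the model gets when that variable is scrambled.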
The function describe() makes it easy to see which variables are the most important. As usual, the argument nonsignificance_treshold sets the level above which variables are considered significant. The higher the threshold, the fewer variables are described as significant.
describe(fi_rf)
The number of important variables for Random Forest's prediction is 4 out of 7.
Variables gender, class, fare have the highest importantance.
As we can see, the most important feature is gender. The next three important features are class, age and fare. Let's look at the link between the model response and these features.
Such a univariate relation can be calculated with partial_dependence().
Children aged 5 years and younger have a much higher survival probability.
pp_age <- partial_dependence(explain_titanic_rf, variables = c("age", "fare"))
head(pp_age)
plot(pp_age)
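A partial dependence value at a grid point z is the average prediction over the dataset after setting the chosen variable to z for every row. The base-R sketch below illustrates this. It assumes a hypothetical helper partial_dependence_sketch(), synthetic data in which children aged 5 or younger are more likely to survive, and a logistic regression in place of the random forest.

```r
# Synthetic data (assumption): survival odds jump for age <= 5
set.seed(7)
n <- 400
df <- data.frame(age = runif(n, 0, 70), fare = rexp(n, 1 / 30))
df$survived <- rbinom(n, 1, plogis(1.5 - 2.5 * (df$age > 5) + 0.01 * df$fare))
fit <- glm(survived ~ I(age > 5) + fare, data = df, family = binomial)

# Hypothetical helper, NOT ingredients::partial_dependence()
partial_dependence_sketch <- function(model, data, variable, grid) {
  sapply(grid, function(z) {
    data_z <- data
    data_z[[variable]] <- z   # set the variable to z for every row
    mean(predict(model, newdata = data_z, type = "response"))
  })
}

age_grid <- c(1, 3, 5, 10, 30, 60)
pd_age <- partial_dependence_sketch(fit, df, "age", age_grid)
round(setNames(pd_age, age_grid), 3)
```

Because this toy model only distinguishes age <= 5 from age > 5, the profile is flat on each side of 5. A random forest produces a more irregular curve.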
Ceteris Paribus profiles show how the model's prediction changes when the value of a specified variable changes while all other variables are held fixed.
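The mechanics can be sketched in base R. Copy one observation once per candidate value, change only the chosen variable, and predict. The code assumes a hypothetical helper ceteris_paribus_sketch(), synthetic data with a three-level class factor, and a logistic regression in place of the random forest.

```r
set.seed(3)
n <- 300
class_effect <- c("1st" = 1, "2nd" = 0.3, "3rd" = -1)   # synthetic logit shifts
df <- data.frame(class = factor(sample(names(class_effect), n, replace = TRUE)),
                 age = runif(n, 1, 70))
df$survived <- rbinom(n, 1, plogis(class_effect[as.character(df$class)] - 0.02 * df$age))
fit <- glm(survived ~ class + age, data = df, family = binomial)

# Hypothetical helper, NOT ingredients::ceteris_paribus()
ceteris_paribus_sketch <- function(model, observation, variable, values) {
  profile <- observation[rep(1, length(values)), , drop = FALSE]  # copies of one row
  profile[[variable]] <- values                                   # only this variable changes
  data.frame(value = values,
             yhat = unname(predict(model, newdata = profile, type = "response")))
}

one_passenger <- data.frame(class = factor("3rd", levels = levels(df$class)), age = 30)
class_values <- factor(levels(df$class), levels = levels(df$class))
cp_class <- ceteris_paribus_sketch(fit, one_passenger, "class", class_values)
cp_class
```

The row for the passenger's own class ("3rd") reproduces the ordinary prediction for that passenger. The other rows answer "what if only the class were different?".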
perturbed_variable <- "class"
cp_rf <- ceteris_paribus(explain_titanic_rf,
passanger,
variables = perturbed_variable)
plot(cp_rf, variable_type = "categorical")
For a user with no experience, interpreting the above plot may not be straightforward. Thus we generate a natural language description to make it easier.
describe(cp_rf)
For the selected instance, prediction estimated by Random Forest is equal to 0.145.
Model's prediction would increase substantially if the value of class variable would change to "1st", "2nd".
The largest change would be marked if class variable would change to "1st".
Other variables are with less importance and they do not change prediction by more than 0.09%.
Natural language descriptions should be flexible in order to provide the desired level of complexity and specificity. Thus various parameters can modify the description being generated.
describe(cp_rf,
display_numbers = TRUE,
label = "the probability that the passanger will survive")
Random Forest predicts that for the selected instance, the probability that the passanger will survive is equal to 0.145
The most important change in Random Forest's prediction would occur for class = "1st". It increases the prediction by 0.593.
The second most important change in the prediction would occur for class = "2nd". It increases the prediction by 0.571.
The third most important change in the prediction would occur for class = "victualling crew". It increases the prediction by 0.081.
Other variable values are with less importance. They do not change the the probability that the passanger will survive by more than 0.075.
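The sentences printed above follow a simple pattern. Compute the change in prediction for each level of the variable, rank the changes by size, and fill a sentence template. The base-R sketch below illustrates that pattern. describe_categorical_profile() is a hypothetical helper, not the ingredients implementation. Its inputs are the current prediction (0.145) and the three increases (0.593, 0.571, 0.081) reported in the output above. Other class levels are omitted.

```r
# Hypothetical helper, NOT ingredients::describe(): turns a categorical
# Ceteris Paribus profile into two template sentences.
describe_categorical_profile <- function(values, yhat, current_yhat, variable,
                                         model_label = "the model", digits = 3) {
  change <- yhat - current_yhat                    # effect of switching to each level
  top <- order(abs(change), decreasing = TRUE)[1]  # level with the largest change
  direction <- if (change[top] > 0) "increases" else "decreases"
  c(sprintf("For the selected instance, prediction estimated by %s is equal to %s.",
            model_label, format(round(current_yhat, digits), nsmall = digits)),
    sprintf("The largest change would occur for %s = \"%s\". It %s the prediction by %s.",
            variable, values[top], direction,
            format(round(abs(change[top]), digits), nsmall = digits)))
}

levels_shown <- c("1st", "2nd", "victualling crew")
out <- describe_categorical_profile(levels_shown,
                                    yhat = 0.145 + c(0.593, 0.571, 0.081),
                                    current_yhat = 0.145,
                                    variable = "class",
                                    model_label = "Random Forest")
cat(out, sep = "\n")
```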
Please note that describe() can handle only one variable at a time, so it is recommended to specify which variables should be described.
describe(cp_rf,
display_numbers = TRUE,
label = "the probability that the passanger will survive",
variables = perturbed_variable)
Random Forest predicts that for the selected instance, the probability that the passanger will survive is equal to 0.145
The most important change in Random Forest's prediction would occur for class = "1st". It increases the prediction by 0.593.
The second most important change in the prediction would occur for class = "2nd". It increases the prediction by 0.571.
The third most important change in the prediction would occur for class = "victualling crew". It increases the prediction by 0.081.
Other variable values are with less importance. They do not change the the probability that the passanger will survive by more than 0.075.
Continuous variables are described as well.
perturbed_variable_continuous <- "age"
cp_rf <- ceteris_paribus(explain_titanic_rf,
passanger)
plot(cp_rf, variables = perturbed_variable_continuous)
describe(cp_rf, variables = perturbed_variable_continuous)
Random Forest predicts that for the selected instance, prediction is equal to 0.145
The highest prediction occurs for (age = 0.1666666667), while the lowest for (age = 42).
Breakpoint is identified at (age = 4).
Average model responses are *higher* for variable values *lower* than breakpoint (= 4).
Ceteris Paribus profiles describe only a single observation. If we want to assess the influence of a variable across more than one observation, we need to describe dependence profiles.
pdp <- aggregate_profiles(cp_rf, type = "partial")
plot(pdp, variables = "fare")
describe(pdp, variables = "fare")
Random Forest's mean prediction is equal to 0.145.
The highest prediction occurs for (fare = 0), while the lowest for (fare = 27.18).
Breakpoint is identified at (fare = 27.18).
Average model responses are *lower* for variable values *higher* than breakpoint (= 27.18).
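A partial dependence profile obtained with type = "partial" can be read as the point-wise average of Ceteris Paribus profiles, one per observation. The base-R sketch below shows that aggregation step. It uses synthetic data and a logistic regression in place of the random forest, and it is not the aggregate_profiles() implementation. It also shows why, when only one observation's profile is aggregated, the "mean prediction" equals that observation's own prediction.

```r
set.seed(11)
n <- 250
df <- data.frame(fare = rexp(n, 1 / 30), age = runif(n, 1, 70))
df$survived <- rbinom(n, 1, plogis(-0.5 + 0.02 * df$fare - 0.03 * df$age))
fit <- glm(survived ~ fare + age, data = df, family = binomial)

fare_grid <- c(0, 10, 25, 50, 100)

# One Ceteris Paribus profile per observation:
# rows = observations, columns = grid points
cp_matrix <- sapply(fare_grid, function(z) {
  d <- df
  d$fare <- z
  predict(fit, newdata = d, type = "response")
})

# Aggregating the profiles = averaging them point by point
pd_fare <- colMeans(cp_matrix)
round(setNames(pd_fare, fare_grid), 3)

# With a single observation, the "average" is just its own profile
single_pd <- colMeans(cp_matrix[1, , drop = FALSE])
```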
pdp <- aggregate_profiles(cp_rf, type = "partial", variable_type = "categorical")
plot(pdp, variables = perturbed_variable)
describe(pdp, variables = perturbed_variable)
Random Forest's mean prediction is equal to 0.145.
Model's prediction would increase substantially if the value of class variable would change to "1st", "2nd", "victualling crew", "restaurant staff", "engineering crew".
The largest change would be marked if class variable would change to "victualling crew".
Other variables are with less importance and they do not change prediction by more than 0.09%.
cp_age <- conditional_dependence(explain_titanic_rf, variables = c("age", "fare"))
plot(cp_age)
ap_age <- accumulated_dependence(explain_titanic_rf, variables = c("age", "fare"))
plot(ap_age)
Let's look at explanations of the model prediction for an 8-year-old boy travelling in 1st class who embarked in Southampton.
First, Ceteris Paribus profiles for the numerical variables:
new_passanger <- data.frame(
class = factor("1st", levels = c("1st", "2nd", "3rd", "deck crew", "engineering crew", "restaurant staff", "victualling crew")),
gender = factor("male", levels = c("female", "male")),
age = 8,
sibsp = 0,
parch = 0,
fare = 72,
embarked = factor("Southampton", levels = c("Belfast", "Cherbourg", "Queenstown", "Southampton"))
)
sp_rf <- ceteris_paribus(explain_titanic_rf, new_passanger)
plot(sp_rf) +
show_observations(sp_rf)
Then for selected categorical variables. Note that sibsp is numerical but is presented here as a categorical variable.
plot(sp_rf,
variables = c("class", "embarked", "gender", "sibsp"),
variable_type = "categorical")
It looks like the most important features for this passenger are age and gender. Overall, his odds of survival are higher than those of the average passenger, mainly because of his young age and despite his being male.