Introduction
If you think you know everything about the mlr package, you might be wrong. Although we have already explored most of this package's sophisticated features, such as model wrapping, imputation and grid-based tuning, some lesser-known functions are still worth learning. Among them, plotLearnerPrediction() might be the most interesting one. Apart from its reference page (?plotLearnerPrediction), this function has received little attention, even on mlr's official website.
In this tutorial, I will show you how to use this function to make attractive plots of a learner's predictions.
The plotLearnerPrediction() function performs an integrated procedure:
Training and cross-validation: the learner is trained on the whole classification task, restricted to the selected features, and its performance on this training data is computed. If cv > 0, a cv-fold cross-validation is then run and the performance is averaged over the folds (by default the mean misclassification error, mmce; the user can supply other or additional performance measures). The learner's name, its training performance and its cross-validation result are reported in the plot title.
Plot space: a two-dimensional data space is defined by the X and Y variables given in the features argument, and the observations are drawn as a scatter plot with ggplot2's geom_point() function.
Prediction boundaries: the background is drawn with ggplot2's geom_tile() function, using either the predicted probabilities (when the learner supports them) or the predicted classes. When probabilities are used, the fill color encodes the predicted class and the transparency (alpha) encodes the probability of that class.
The final output is a ggplot2 object, so it can be customised further, for example with scale_fill_manual() or ggplot2 themes such as theme_bw().
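To make the procedure above concrete, here is a rough manual reconstruction with standard mlr and ggplot2 calls. It is a sketch under our own assumptions (a CART learner, the two petal features, 10 folds, a 100 by 100 grid), not the actual source code of plotLearnerPrediction(); the object names task_petal, grid and p are hypothetical.

```r
library(mlr)
library(ggplot2)

# Sketch only: a manual reconstruction, not mlr's internal code
task_petal = makeClassifTask(id = "iris_petal",
                             data = iris[, c("Petal.Length", "Petal.Width", "Species")],
                             target = "Species")
lrn = makeLearner("classif.rpart", id = "CART", predict.type = "prob")

# Fit on all observations and compute the training error (mmce)
mod = train(lrn, task_petal)
perf_train = performance(predict(mod, task = task_petal), measures = mmce)

# 10-fold cross-validation; cv_res$aggr is the mmce averaged over the folds
cv_res = crossval(lrn, task_petal, iters = 10L, measures = mmce, show.info = FALSE)
perf_cv = cv_res$aggr

# Predict on a regular 100 x 100 grid spanning both features
grid = expand.grid(
  Petal.Length = seq(min(iris$Petal.Length), max(iris$Petal.Length), length.out = 100),
  Petal.Width = seq(min(iris$Petal.Width), max(iris$Petal.Width), length.out = 100)
)
pred_grid = predict(mod, newdata = grid)
grid$Species = getPredictionResponse(pred_grid)
# Probability of the predicted class, later mapped to transparency
grid$prob = apply(getPredictionProbabilities(pred_grid), 1, max)

# Background tiles (fill = class, alpha = probability) plus the observed points
p = ggplot(grid, aes(x = Petal.Length, y = Petal.Width)) +
  geom_tile(aes(fill = Species, alpha = prob)) +
  geom_point(data = iris, aes(shape = Species), size = 2) +
  ggtitle(sprintf("CART  train: mmce=%.3f; CV: mmce=%.3f", perf_train, perf_cv)) +
  scale_fill_manual(values = c("#ff0061", "#11a6fc", "#ffae00")) +
  theme_bw()
p
```

The built-in function does all of this (and marks misclassified points) in a single call, which is why it is so convenient.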
Example: The IRIS classification task
The iris dataset was introduced by the biologist and statistician Ronald Fisher in 1936. It contains 50 observations of each of three Iris species (setosa, versicolor and virginica). A multiclass classification task can be defined on these data: the goal is to classify the Iris species from the length and width of the sepals and petals, measured in centimetres.
Figure: iris data
#Data preparation
data(iris)
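The description above (three species, 50 observations each, four measurements in centimetres) can be checked quickly in base R:

```r
# Load the built-in iris data (datasets package)
data(iris)

# Number of observations per species
species_counts = table(iris$Species)
print(species_counts)

# The four numeric features: sepal and petal length/width (cm)
str(iris[, 1:4])
```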
Here is the association between sepal length and width, together with the two-dimensional distribution of each iris species in our dataset:
library(ggplot2)
ggplot(iris, aes(x = Sepal.Length, y = Sepal.Width)) + stat_density_2d(geom = "polygon", aes(fill = Species, alpha = after_stat(level))) +
  geom_point(aes(shape = Species), color = "black", size = 2) + theme_bw() + scale_fill_manual(values = c("#ff0061", "#11a6fc", "#ffae00"))
The same plot can be drawn for petal length and width:
ggplot(iris, aes(x = Petal.Length, y = Petal.Width)) + stat_density_2d(geom = "polygon", aes(fill = Species, alpha = after_stat(level))) +
  geom_point(aes(shape = Species), color = "black", size = 2) + theme_bw() + scale_fill_manual(values = c("#ff0061", "#11a6fc", "#ffae00"))
As usual, we first define a classification task in mlr and then create several learners:
# Make a multiclass classification task in mlr
library(mlr)
taskiris = makeClassifTask(id = "iris", data = iris, target = "Species")
# Make 7 different learners (algorithms)
learnerCART = makeLearner("classif.rpart", id = "CART", predict.type = "prob")
learnerRF = makeLearner("classif.randomForestSRC", id = "RF", predict.type = "prob")
learnerSVM = makeLearner("classif.svm", id = "SVM", predict.type = "prob")
learnerGBM = makeLearner("classif.gbm", id = "GBM", predict.type = "prob")
learnerGLMN = makeLearner("classif.glmnet", id = "Elasticnet", predict.type = "prob")
# classif.knn cannot predict probabilities, so it keeps the default predict.type ("response")
learnerKNN = makeLearner("classif.knn", id = "KNN")
learnerLDA = makeLearner("classif.lda", id = "LDA", predict.type = "prob")
Note: these seven algorithms are a decision tree (CART), a random forest (Random Forest SRC), a gradient boosting machine (GBM), an elastic net (GLMN), a support vector machine (SVM), k-nearest neighbours (KNN) and linear discriminant analysis (LDA). All of them can handle multiclass classification problems. KNN does not support probability predictions, so it is created with the default response prediction type.
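The properties mentioned in the note can also be checked programmatically. A sketch, assuming the underlying packages of all seven learners (rpart, randomForestSRC, e1071, gbm, glmnet, class, MASS) are installed, since makeLearner() may require them:

```r
library(mlr)

# The seven learner classes used in this tutorial
learner_classes = c("classif.rpart", "classif.randomForestSRC", "classif.svm",
                    "classif.gbm", "classif.glmnet", "classif.knn", "classif.lda")

# For each class: does it support multiclass tasks and probability predictions?
props = t(sapply(learner_classes, function(cl) {
  p = getLearnerProperties(makeLearner(cl))
  c(multiclass = "multiclass" %in% p, prob = "prob" %in% p)
}))
props
```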
The syntax of the plotLearnerPrediction() function (main arguments only):
plotLearnerPrediction(learner, task, features = NULL, measures, cv = 10L, ..., gridsize, pointsize = 2, prob.alpha = TRUE, se.band = TRUE, err.col = "white", greyscale = FALSE)
Where:
learner: the learner object.
task: the task object.
features: up to 2 features to plot. By default, the first 2 features are used.
measures: performance measure(s) to evaluate. Default is the default measure for the task (mmce for classification).
cv: number of cross-validation folds; the result is reported in the plot title. cv = 0 means no cross-validation. Default is 10.
...: hyperparameters passed to the learner.
gridsize: grid resolution per axis for the background predictions. Default is 100 for 2D.
pointsize: point size of the ggplot2 geom_point() layer for the data points. Default is 2.
prob.alpha: logical; should the alpha (transparency) of the background be set to the probability of the predicted class? This visualises the "confidence" of the prediction. If FALSE, only a constant color is displayed in the background for the predicted label. Default is TRUE.
se.band: for regression in 1D, show a band for the standard error estimate? Default is TRUE.
err.col: for classification, the color of misclassified data points. Default is "white".
greyscale: logical; should the plot be completely greyscale? Default is FALSE.
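Because the same call with the same colors is repeated for every learner, it can be wrapped in a small helper. plot_iris_learner() and iris_cols are hypothetical names of our own, not part of mlr; extra arguments such as measures or prob.alpha are forwarded to plotLearnerPrediction() through the helper's `...`.

```r
library(mlr)
library(ggplot2)

# Hypothetical helper (not part of mlr) for the call repeated in this tutorial
iris_cols = c("#ff0061", "#11a6fc", "#ffae00")
plot_iris_learner = function(learner, task, features, cv = 10L, ...) {
  plotLearnerPrediction(learner, task, features = features, cv = cv,
                        gridsize = 100, ...) +
    scale_fill_manual(values = iris_cols) +
    theme_bw()
}

taskiris = makeClassifTask(id = "iris", data = iris, target = "Species")
learnerLDA = makeLearner("classif.lda", id = "LDA", predict.type = "prob")

# Report two measures in the title and use a constant background color
p = plot_iris_learner(learnerLDA, taskiris, c("Petal.Length", "Petal.Width"),
                      measures = list(mmce, acc), prob.alpha = FALSE)
p
```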
CART algorithm
plotLearnerPrediction(learnerCART, taskiris, features = c("Sepal.Length", "Sepal.Width"), cv = 100L, gridsize = 100) +
  scale_fill_manual(values = c("#ff0061", "#11a6fc", "#ffae00")) + theme_bw()
plotLearnerPrediction(learnerCART, taskiris, features = c("Petal.Length", "Petal.Width"), cv = 100L, gridsize = 100) +
  scale_fill_manual(values = c("#ff0061", "#11a6fc", "#ffae00")) + theme_bw()
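The rectangular regions in the CART plots follow from how a classification tree works: every split thresholds a single feature, so each boundary is a vertical or horizontal line. A sketch that fits the same learner on the two petal features (subsetTask() restricts the task to those columns) and prints its splits:

```r
library(mlr)

taskiris = makeClassifTask(id = "iris", data = iris, target = "Species")
# Restrict the task to the two petal features shown in the second CART plot
task_petal = subsetTask(taskiris, features = c("Petal.Length", "Petal.Width"))
mod_cart = train(makeLearner("classif.rpart", id = "CART"), task_petal)

# The underlying rpart object: each split is a threshold on one feature
tree = getLearnerModel(mod_cart)
print(tree)

# Features actually used for splitting (leaves are marked "<leaf>")
split_vars = setdiff(unique(as.character(tree$frame$var)), "<leaf>")
split_vars
```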
Support vector machine algorithm
plotLearnerPrediction(learnerSVM, taskiris, features = c("Sepal.Length", "Sepal.Width"), cv = 100L, gridsize = 100) +
  scale_fill_manual(values = c("#ff0061", "#11a6fc", "#ffae00")) + theme_bw()
plotLearnerPrediction(learnerSVM, taskiris, features = c("Petal.Length", "Petal.Width"), cv = 100L, gridsize = 100) +
  scale_fill_manual(values = c("#ff0061", "#11a6fc", "#ffae00")) + theme_bw()
Gradient boosting machine algorithm
plotLearnerPrediction(learnerGBM, taskiris, features = c("Sepal.Length", "Sepal.Width"), cv = 100L, gridsize = 100) +
  scale_fill_manual(values = c("#ff0061", "#11a6fc", "#ffae00")) + theme_bw()
## Distribution not specified, assuming multinomial ...
plotLearnerPrediction(learnerGBM, taskiris, features = c("Petal.Length", "Petal.Width"), cv = 100L, gridsize = 100) +
  scale_fill_manual(values = c("#ff0061", "#11a6fc", "#ffae00")) + theme_bw()
## Distribution not specified, assuming multinomial ...
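The "Distribution not specified, assuming multinomial ..." note comes from the gbm package, which prints it each time a model is fitted without an explicit distribution; it is harmless. A sketch for hiding it, assuming the note may be sent either through message() or through cat() depending on the gbm version, so both channels are suppressed:

```r
library(mlr)
library(ggplot2)

taskiris = makeClassifTask(id = "iris", data = iris, target = "Species")
learnerGBM = makeLearner("classif.gbm", id = "GBM", predict.type = "prob")

# suppressMessages() drops notes sent through message();
# capture.output() swallows anything printed to standard output (e.g. via cat())
invisible(capture.output(
  p_gbm <- suppressMessages(
    plotLearnerPrediction(learnerGBM, taskiris,
                          features = c("Petal.Length", "Petal.Width"),
                          cv = 10L, gridsize = 100)
  )
))
p_gbm + scale_fill_manual(values = c("#ff0061", "#11a6fc", "#ffae00")) + theme_bw()
```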
Elastic net (logistic) algorithm
plotLearnerPrediction(learnerGLMN, taskiris, features = c("Sepal.Length", "Sepal.Width"), cv = 100L, gridsize = 100) +
  scale_fill_manual(values = c("#ff0061", "#11a6fc", "#ffae00")) + theme_bw()
plotLearnerPrediction(learnerGLMN, taskiris, features = c("Petal.Length", "Petal.Width"), cv = 100L, gridsize = 100) +
  scale_fill_manual(values = c("#ff0061", "#11a6fc", "#ffae00")) + theme_bw()
KNN algorithm
plotLearnerPrediction(learnerKNN, taskiris, features = c("Sepal.Length", "Sepal.Width"), cv = 100L, gridsize = 100) +
  scale_fill_manual(values = c("#ff0061", "#11a6fc", "#ffae00")) + theme_bw()
plotLearnerPrediction(learnerKNN, taskiris, features = c("Petal.Length", "Petal.Width"), cv = 100L, gridsize = 100) +
  scale_fill_manual(values = c("#ff0061", "#11a6fc", "#ffae00")) + theme_bw()
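The `...` argument of plotLearnerPrediction() passes hyperparameters to the learner, which makes it easy to see how they reshape the prediction boundaries. A sketch for KNN, where k is the number of neighbours used by the underlying class::knn() (default 1); the value 15 is an arbitrary choice for illustration:

```r
library(mlr)
library(ggplot2)

taskiris = makeClassifTask(id = "iris", data = iris, target = "Species")
learnerKNN = makeLearner("classif.knn", id = "KNN")

# k = 15L is not a plotLearnerPrediction() argument, so it goes into ...
# and is set as a learner hyperparameter; larger k usually smooths the boundaries
p_knn15 = plotLearnerPrediction(learnerKNN, taskiris,
                                features = c("Petal.Length", "Petal.Width"),
                                cv = 10L, gridsize = 100, k = 15L) +
  scale_fill_manual(values = c("#ff0061", "#11a6fc", "#ffae00")) +
  theme_bw()
p_knn15
```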
LDA algorithm
plotLearnerPrediction(learnerLDA, taskiris, features = c("Sepal.Length", "Sepal.Width"), cv = 100L, gridsize = 100) +
  scale_fill_manual(values = c("#ff0061", "#11a6fc", "#ffae00")) + theme_bw()
plotLearnerPrediction(learnerLDA, taskiris, features = c("Petal.Length", "Petal.Width"), cv = 100L, gridsize = 100) +
  scale_fill_manual(values = c("#ff0061", "#11a6fc", "#ffae00")) + theme_bw()
Random Forest algorithm
plotLearnerPrediction(learnerRF, taskiris, features = c("Sepal.Length", "Sepal.Width"), cv = 100L, gridsize = 100) +
  scale_fill_manual(values = c("#ff0061", "#11a6fc", "#ffae00")) + theme_bw()
plotLearnerPrediction(learnerRF, taskiris, features = c("Petal.Length", "Petal.Width"), cv = 100L, gridsize = 100) +
  scale_fill_manual(values = c("#ff0061", "#11a6fc", "#ffae00")) + theme_bw()
Conclusion
plotLearnerPrediction() is a lesser-known tool in the mlr package. This useful function makes it easy to generate attractive plots for illustration purposes. These plots convey a lot of information, such as:
A visual impression of the model's performance: its ability to classify the instances into two or more classes, including which points are correctly classified and which are misclassified.
A visual presentation of the model's underlying mechanism, through its prediction boundaries.
The association between the two features and their contribution to the model's prediction.
The numerical result of cross-validation: the model's averaged performance metrics.
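The cross-validated numbers shown in the plot titles can also be collected in a single table with mlr's benchmark(). A sketch restricted to the two petal features and three of the learners to keep it fast; the choice of learners, the 10 folds and the seed are our own assumptions:

```r
library(mlr)

set.seed(1)  # make the fold assignment reproducible
taskiris = makeClassifTask(id = "iris", data = iris, target = "Species")
task_petal = subsetTask(taskiris, features = c("Petal.Length", "Petal.Width"))

learners = list(
  makeLearner("classif.rpart", id = "CART"),
  makeLearner("classif.lda", id = "LDA"),
  makeLearner("classif.knn", id = "KNN")
)

# The same 10-fold CV scheme for every learner, scored with mmce
rdesc = makeResampleDesc("CV", iters = 10L)
bmr = benchmark(learners, task_petal, rdesc, measures = mmce, show.info = FALSE)

# One row per learner with the mmce averaged over the folds
perf = getBMRAggrPerformances(bmr, as.df = TRUE)
perf
```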
Thank you for joining us and see you in the next tutorial :)