1. Introduction


knitr::include_graphics("https://bigwhalelearning.files.wordpress.com/2014/11/titanic_heuristic.png")


1.1 What is Classification?

In statistics and machine learning, the goal of classification is to predict which category, or class, an observation belongs to, using observations whose categories are already known. For example, if we wanted to know whether a person has blood type O, we could use a classifier trained on individuals whose blood types are known to predict that person's unknown blood type.

Classification attempts to predict categories; this is opposed to regression, which tries to predict continuous variables. A regression would be used to estimate the price of a house, while a classifier would try to determine whether a building is residential or commercial.

Classification is also distinct from clustering. Classification only works if the training data includes labels for the different possible categories. For example, if you wanted to know whether a person will vote Democrat or Republican, you need past data where people’s voting behavior is recorded. If you wanted to group voters who have similar beliefs, but there are no clearly identified categories, clustering is the preferred method. In other words, clustering is unsupervised learning while classification is supervised.


1.2 Goals:

My hope is that, whatever level of knowledge you come in with, you leave one level more comfortable: if you come in with no knowledge, you should feel comfortable with at least the level 1 goals; if you come in with level 1 knowledge, you should be comfortable with at least the level 2 goals; and so on.

  • Level 1: Know when a classifier is appropriate for a given problem
  • Level 1: Understand how common classifiers in data science work
  • Level 2: Understand the assumptions and shortfalls each classifier has
  • Level 2: Be able to implement classifiers in your work
  • Level 3: Adopt a process oriented approach to model selection
  • Level 3: Evaluate the quality of a classifier


1.3 What’s Next:

Time allowing, my plan is to write similar series on other machine learning fields:

  • Regressions: Predicting continuous data
  • Clustering: Grouping data based on distributions when classes are unknown
  • Time Series: Making predictions off of trends


2. Classification Algorithms Covered in Series

2.1: Logistic Regression

knitr::include_graphics("https://littleml.files.wordpress.com/2016/06/lr_boundary_linear.png")

A logistic regression (or logit) is similar to a linear regression: it uses other variables to predict an output in an equation-type format. The main difference between the two is that the output of a linear regression is a continuous variable, while a logistic regression is used to estimate a binary classification. Several of the assumptions behind a linear regression carry over to the logit in modified form - for example, the predictors are assumed to be linearly related to the log-odds of the outcome, and observations are assumed to be independent.

An advantage a logistic regression has over other classifiers is interpretability. The coefficients in a logistic regression can be exponentiated into odds ratios, and the model’s output can be expressed as a probability. That means you can describe each of the variables in your model in terms of how it changes the odds (or probability) that your chosen outcome occurs.
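
For a concrete (if simplified) illustration, here is a minimal sketch of fitting a logit in base R with glm(). The data (the built-in mtcars set) and predictors are my own illustrative choices, not the series’ example:

# Load Data
data(mtcars)

# Fit Logistic Regression (am: 0 = automatic, 1 = manual transmission)
logit_model <- glm(am ~ wt + hp, data = mtcars, family = binomial)

# Exponentiate coefficients into odds ratios
exp(coef(logit_model))

# Predicted probability of a manual transmission for each car
head(predict(logit_model, type = "response"))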


2.2: Classification Tree

knitr::include_graphics("https://www.edvancer.in/wp-content/uploads/2015/10/bdeb9fa398c892cfaf03c3da43ddbd6a.png")

A classification tree attempts to classify an observation by making a series of splits in the data. In our Titanic example, the first cut the tree makes is whether the passenger is male. If the passenger is male, the second cut is whether his age is less than 18; if not, the passenger is predicted to have died. You can follow the branches of the tree to the terminal points to see how the tree classifies each individual.

The classification tree, like the logistic regression, is fairly easy to explain to a non-technical audience. Its main advantage is that it doesn’t assume a linear separator between classes: in the decision space above, you can see the boundary is closer to a square than a line. That flexibility can become a problem as more branches are added, since an unpruned tree tends to overfit.
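
As a quick sketch (my own illustration on the built-in iris data, not the series’ Titanic example), a classification tree can be fit with the rpart package:

# Load Package
library(rpart)

# Load Data
data(iris)

# Fit Classification Tree
tree_model <- rpart(Species ~ ., data = iris, method = "class")

# Inspect the splits the tree chose
print(tree_model)

# Predicted class for each training observation
head(predict(tree_model, type = "class"))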


2.3: Support Vector Machine

knitr::include_graphics("http://www.statsoft.com/textbook/graphics/SVMIntro3.gif")

Support vector machines look for a boundary in the data. In the above example, the input space is transformed into the feature space, where a linear boundary is drawn. Observations are classified as red or green based on which side of the line they fall; thus, the white dot would be classified as red. The line is drawn so that it is the same distance from the closest red and green dot(s), which are called the support vectors, while keeping that distance as large as possible.

Support vector machines tend to produce better predictions than trees or logits; however, they also have longer run times and are less interpretable to those without a technical background.
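
As a hedged sketch of what this looks like in R, the e1071 package (one common SVM implementation, assumed installed) can fit a kernel SVM on the built-in iris data:

# Load Package
library(e1071)

# Load Data
data(iris)

# Fit Support Vector Machine with a radial kernel
svm_model <- svm(Species ~ ., data = iris, kernel = "radial")

# Predicted class for each observation
head(predict(svm_model, iris))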


2.4: K-Nearest Neighbors

knitr::include_graphics("http://bdewilde.github.io/assets/images/2012-10-26-knn-example-ks.png")

KNN, or k-nearest neighbors, can be thought of as a voting system. Say we were trying to predict purple or yellow and we set k = 1. For any point in the feature space, we look at the closest known observation in the training data: if that observation is purple, the unknown value is classified as purple; if it is yellow, the unknown observation is classified as yellow.

Easy enough, but what if k = 5? Then an unknown observation is classified as yellow if 3 or more of its 5 closest known observations are yellow. The majority of the “votes” closest to the unknown observation determines how it’s classified.

You can see in the above example that when k is small the boundary hugs the training data closely, so training accuracy is higher than with a larger k, but the model is more prone to overfitting; larger values of k produce a smoother, more general boundary.
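
Here is a minimal sketch of KNN in R using the class package’s knn() function; the train/test split and k = 5 are my own illustrative choices:

# Load Package
library(class)

# Setting Seed for Reproducibility
set.seed(3671205)

# Load Data and split into training and test sets
data(iris)
train_idx <- sample(nrow(iris), 120)
train_x <- iris[train_idx, 1:4]
test_x  <- iris[-train_idx, 1:4]
train_y <- iris$Species[train_idx]

# Classify each test observation by a vote of its 5 nearest neighbors
knn_pred <- knn(train = train_x, test = test_x, cl = train_y, k = 5)
head(knn_pred)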


2.5: Naive Bayes

knitr::include_graphics("https://i.stack.imgur.com/rVlab.png")

Naive Bayes classifies an observation based on conditional probabilities. In a situation where we are trying to classify email as spam or not spam, we could look at how frequently the words “win” and “buy” appear in spam versus regular email. If we know that:

  • 80% of emails are spam (75% of them have the word “buy”, 40% of them have the word “win”)
  • 20% are regular email (12% of them have the word “buy”, 7% of them have the word “win”).

Then we can calculate the probability that an email containing both “win” and “buy” is spam - which, given our data and the naive assumption that the words occur independently within each class, works out to roughly a 99% chance (a short R calculation is sketched after the images below).


knitr::include_graphics("https://cdn-images-1.medium.com/max/1600/0*XCL_XAUjoO9xmt9W.gif")

knitr::include_graphics("https://cdn-images-1.medium.com/max/1600/0*0BBK3bcs6BrGnM9z.gif")


Naive Bayes tends to predict more poorly on large datasets than some of the other algorithms we’ve explored, but its run times are quicker and it needs less data than other algorithms. It, too, can be difficult to explain to non-technical audiences.


2.6: Neural Net

knitr::include_graphics("https://i.pinimg.com/originals/d0/a7/11/d0a711f2e95d7159b6193a0982c09bfa.jpg")

Neural nets take in input variables (the far left of the diagram) and assign weights to them to create a second layer of nodes (which have no intuitive meaning on their own). If your network has only two layers, one more set of weights is then applied to the second-layer nodes in order to predict the class of your target variable.

Neural nets tend to have high accuracy, but it can be difficult to see what is going on inside them, and there is still work to be done on the statistical side of evaluating neural nets. They also tend to have long run times and require a lot of data to work well.
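
As a rough illustration (not the series’ code), a small single-hidden-layer network can be fit with the nnet package on the built-in iris data; the number of hidden nodes here is an arbitrary choice:

# Load Package
library(nnet)

# Setting Seed for Reproducibility
set.seed(3671205)

# Load Data
data(iris)

# Fit a neural network with one hidden layer of 4 nodes
nn_model <- nnet(Species ~ ., data = iris, size = 4, trace = FALSE)

# Predicted class for each observation
head(predict(nn_model, iris, type = "class"))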


3. Content Covered for Each Algorithm

For each of the above algorithms I will cover the following:


4. When to use Each Algorithm

Now that we’ve covered the different classifiers this series will address, you may be wondering how to decide between them. Below is a flow chart describing when to use each classifier depending on your project’s needs. You’ll notice that some of the terminal points of the flow diagram list multiple algorithms; this is because sometimes a model will perform too poorly for your needs, and the flow ensures you have a few algorithms to try out.

# Loading Package
library(DiagrammeR)

# Creating Flow Diagram
grViz("
digraph boxes_and_circles {
      
      # a 'graph' statement
      graph [overlap = false, fontsize = 20]
      
      # several 'node' statements
      node [shape = box,
      fontname = Helvetica]

     
      # several 'edge' statements

'START: Are you predicting a labeled category' -> 'NO: These classifiers will be no help'
'START: Are you predicting a labeled category' -> 'YES: Good, you are on the right path'
'YES: Good, you are on the right path' -> 'Is speed or accuracy more important?'
'Is speed or accuracy more important?' -> 'ACCURACY: Try kernel SVM, tree with meta-alg., or neural net'
'Is speed or accuracy more important?' -> 'SPEED: Do you need to explain model to lay people?'
'SPEED: Do you need to explain model to lay people?' -> 'YES: Try logit or tree'
'SPEED: Do you need to explain model to lay people?' -> 'NO: How big is your data?'
'NO: How big is your data?' -> 'HUGE: Try Naive Bayes'
'NO: How big is your data?' -> 'SMALLER: Try Naive Bayes, linear SVM, or KNN'
      }
      ")
# Flow Diagram Based on https://steemit.com/machine-learning/@idril/machine-learning-cheat-sheets-flowcharts-and-emojis

5. What Makes a Classifier Good?

In this section, we will discuss how to assess how well our model classifies our selected categories.


5.1 Data Splitting

This is likely the method most people will already be familiar with.

The procedure goes something like this:

  • Split your data into two parts, the training and test data
  • Train a model using your training data
  • Evaluate how strongly your model performed against the test data

The metrics we will be using to evaluate performance include:

  • Accuracy: Percent of the data classified correctly
  • Specificity: For each class, the proportion of observations not in that class that are correctly identified as not belonging to it (the true negative rate)
  • Sensitivity: For each class, the proportion of observations in that class that are correctly identified (the true positive rate)

For other metrics reported in the output, type ?confusionMatrix after loading the caret package.

Below is an example using a Naive Bayes classifier:

# Load Packages
library(caret)
library(klaR)

# Setting Seed for Reproducibility
set.seed(3671205)

# Load Data
data(iris)

# Define Training Control
split <- 0.80
trainIndex <- createDataPartition(iris$Species, p=split, list=FALSE)
data_train <- iris[trainIndex,]
data_test <- iris[-trainIndex,]

# Train Model
model <- NaiveBayes(Species~., data=data_train)

# Make Predictions
x_test <- data_test[,1:4]
y_test <- data_test[,5]
predictions <- predict(model, x_test)

# Assess Model
confusionMatrix(predictions$class, y_test)
## Confusion Matrix and Statistics
## 
##             Reference
## Prediction   setosa versicolor virginica
##   setosa         10          0         0
##   versicolor      0          9         0
##   virginica       0          1        10
## 
## Overall Statistics
##                                           
##                Accuracy : 0.9667          
##                  95% CI : (0.8278, 0.9992)
##     No Information Rate : 0.3333          
##     P-Value [Acc > NIR] : 2.963e-13       
##                                           
##                   Kappa : 0.95            
##  Mcnemar's Test P-Value : NA              
## 
## Statistics by Class:
## 
##                      Class: setosa Class: versicolor Class: virginica
## Sensitivity                 1.0000            0.9000           1.0000
## Specificity                 1.0000            1.0000           0.9500
## Pos Pred Value              1.0000            1.0000           0.9091
## Neg Pred Value              1.0000            0.9524           1.0000
## Prevalence                  0.3333            0.3333           0.3333
## Detection Rate              0.3333            0.3000           0.3333
## Detection Prevalence        0.3333            0.3000           0.3667
## Balanced Accuracy           1.0000            0.9500           0.9750


5.2 K-Fold Cross Validation

K-fold cross validation is similar to data splitting. However, instead of separating our data into 2 parts (test and train), we split our data into k parts, where k is a number between 2 and the number of observations in your data (k = 5 or 10 is typical; taking k equal to the number of observations gives the leave-one-out method described below).

After your data is split into k “folds” or parts, you train your model on all but one fold and validate on the held-out fold. This process is repeated until every fold has been left out once. For example, if we were using 3-fold cross validation, our process would look something like this:

  • Split data into three parts; data1, data2, data3
  • Build estimate1 using data1 and data2, test how well it predicts data in data3
  • Build estimate2 using data1 and data3, test how well it predicts data in data2
  • Build estimate3 using data2 and data3, test how well it predicts data in data1
  • Create a confusion matrix comparing the predicted vs. actual values calculated in the three steps above

This will have a longer run time than the data splitting version but will better assess the quality of our model.

# Load Packages
library(caret)
library(rpart)

# Setting Seed for Reproducibility
set.seed(3671205)

# Load Data
data(iris)

# Define Training Control
train_control <- trainControl(method="cv", number=10, savePredictions = TRUE)

# Train Model
model <- train(Species ~ ., data=iris, trControl=train_control, method = "nb")

# Assess Model
confusionMatrix(model$pred$pred, model$pred$obs)
## Confusion Matrix and Statistics
## 
##             Reference
## Prediction   setosa versicolor virginica
##   setosa        100          0         0
##   versicolor      0         94         7
##   virginica       0          6        93
## 
## Overall Statistics
##                                          
##                Accuracy : 0.9567         
##                  95% CI : (0.927, 0.9767)
##     No Information Rate : 0.3333         
##     P-Value [Acc > NIR] : < 2.2e-16      
##                                          
##                   Kappa : 0.935          
##  Mcnemar's Test P-Value : NA             
## 
## Statistics by Class:
## 
##                      Class: setosa Class: versicolor Class: virginica
## Sensitivity                 1.0000            0.9400           0.9300
## Specificity                 1.0000            0.9650           0.9700
## Pos Pred Value              1.0000            0.9307           0.9394
## Neg Pred Value              1.0000            0.9698           0.9652
## Prevalence                  0.3333            0.3333           0.3333
## Detection Rate              0.3333            0.3133           0.3100
## Detection Prevalence        0.3333            0.3367           0.3300
## Balanced Accuracy           1.0000            0.9525           0.9500


5.3 Leave-One-Out Cross Validation

This version goes one step further than k-fold cross validation. In “leave one out,” as the name implies, we fit the model on all of the data except a single observation, and we do this once for each observation. You can think of this as k-fold validation where the number of folds equals your sample size. Naturally, this increases the computation time even more and may not be feasible for complex algorithms or large datasets, but it assesses the quality of your model even better than the other two methods.

# Load Packages
library(caret)

# Setting Seed for Reproducibility
set.seed(3671205)

# Load Data
data(iris)

# Define Training Control
train_control <- trainControl(method="LOOCV")

# Train Model
model <- train(Species~., data=iris, trControl=train_control, method="nb")

# Assess Model
confusionMatrix(model$pred$pred, model$pred$obs)
## Confusion Matrix and Statistics
## 
##             Reference
## Prediction   setosa versicolor virginica
##   setosa        100          0         0
##   versicolor      0         94         7
##   virginica       0          6        93
## 
## Overall Statistics
##                                          
##                Accuracy : 0.9567         
##                  95% CI : (0.927, 0.9767)
##     No Information Rate : 0.3333         
##     P-Value [Acc > NIR] : < 2.2e-16      
##                                          
##                   Kappa : 0.935          
##  Mcnemar's Test P-Value : NA             
## 
## Statistics by Class:
## 
##                      Class: setosa Class: versicolor Class: virginica
## Sensitivity                 1.0000            0.9400           0.9300
## Specificity                 1.0000            0.9650           0.9700
## Pos Pred Value              1.0000            0.9307           0.9394
## Neg Pred Value              1.0000            0.9698           0.9652
## Prevalence                  0.3333            0.3333           0.3333
## Detection Rate              0.3333            0.3133           0.3100
## Detection Prevalence        0.3333            0.3367           0.3300
## Balanced Accuracy           1.0000            0.9525           0.9500


For other cross validation methods, see bootstrapping and repeated k-fold cross validation here