Which Model to Use?

Data Mining Models and Code

| model | response variable (y) type | predictors | missing value handling | dummify predictors needed? | code |
|---|---|---|---|---|---|
| multiple linear regression | quantitative | any | yes | yes | lm(y ~ ., data = D) |
| binary logistic | binary (0/1) | any | yes | yes | glm(y ~ ., data = D, family = "binomial") |
| ordinal logistic | factor with ordered levels | any | yes | yes | MASS::polr(factor(y, ordered = TRUE) ~ ., data = D) |
| nominal logistic | factor with unordered levels | any | yes | yes | nnet::multinom(factor(y, ordered = FALSE) ~ ., data = D) |
| knn for classification or regression | any | use all dummies if needed | yes | yes | knn(X of train, X of valid, cl = y, k = # of neighbors) |
| classification tree | categorical (factor) | any | yes | yes | rpart(y ~ ., data = D, method = "class") |
| regression tree | quantitative | any | yes | yes | rpart(y ~ ., data = D, method = "anova") |
| bagging or boosting | categorical (factor) | any | yes | yes | randomForest::randomForest(factor(y) ~ ., data = D, ntree = 500, importance = TRUE) or adabag::boosting(factor(y) ~ ., data = D) |
| nn | all dummies for y | dummies except one | yes | yes | neuralnet::neuralnet(all dummies ~ ., data = D, hidden = c(k1, k2, …), linear.output = FALSE) for categorical y; use linear.output = TRUE for quantitative y |
| hierarchical and k-means clustering | no y needed | no predictor | yes | yes | hclust(dist(D), method = "complete", "average", or "ward.D") or kmeans(x = numeric df) or VarSelLCM::VarSelCluster(x = D, gvals = # of clusters, vbleSelec = TRUE) |
Note: Refer to the textbook for examples.

Multiple Linear Regression Model

We give examples for each model, starting with the multiple linear regression model, which has a quantitative response.

This was demonstrated in Section 1.1 of https://rpubs.com/scsustat/Performance.

Binary Logistic Regression

This was demonstrated in Section 3 of https://rpubs.com/scsustat/Performance.

The following provides another example using the Attrition data from https://web.stcloudstate.edu/szhang/

First, let’s take a look at the data:

library(dplyr)  # provides the pipe %>%

Attrition = data.table::fread("/Users/home/downloads/Attrition Data.csv") %>% as.data.frame()

summary(Attrition)
##       Age         Attrition          Department        DistanceFromHome
##  Min.   :18.00   Length:1470        Length:1470        Min.   : 1.000  
##  1st Qu.:30.00   Class :character   Class :character   1st Qu.: 2.000  
##  Median :36.00   Mode  :character   Mode  :character   Median : 7.000  
##  Mean   :36.92                                         Mean   : 9.193  
##  3rd Qu.:43.00                                         3rd Qu.:14.000  
##  Max.   :60.00                                         Max.   :29.000  
##    Education     EducationField     EnvironmentSatisfaction JobSatisfaction
##  Min.   :1.000   Length:1470        Min.   :1.000           Min.   :1.000  
##  1st Qu.:2.000   Class :character   1st Qu.:2.000           1st Qu.:2.000  
##  Median :3.000   Mode  :character   Median :3.000           Median :3.000  
##  Mean   :2.913                      Mean   :2.722           Mean   :2.729  
##  3rd Qu.:4.000                      3rd Qu.:4.000           3rd Qu.:4.000  
##  Max.   :5.000                      Max.   :4.000           Max.   :4.000  
##  MaritalStatus      MonthlyIncome   NumCompaniesWorked WorkLifeBalance
##  Length:1470        Min.   : 1009   Min.   :0.000      Min.   :1.000  
##  Class :character   1st Qu.: 2911   1st Qu.:1.000      1st Qu.:2.000  
##  Mode  :character   Median : 4919   Median :2.000      Median :3.000  
##                     Mean   : 6503   Mean   :2.693      Mean   :2.761  
##                     3rd Qu.: 8379   3rd Qu.:4.000      3rd Qu.:3.000  
##                     Max.   :19999   Max.   :9.000      Max.   :4.000  
##  YearsAtCompany  
##  Min.   : 0.000  
##  1st Qu.: 3.000  
##  Median : 5.000  
##  Mean   : 7.008  
##  3rd Qu.: 9.000  
##  Max.   :40.000
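# Plot every variable: a boxplot if it is numeric; otherwise convert it to a factor and draw a bar chart
for (v in colnames(Attrition)){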
    z = Attrition[, v]
    if (is.numeric(z)) {
        boxplot(z, xlab = v)
    } else {
        Attrition[, v] = factor(z)
        barplot(table(z), xlab = v)
    }
}

Note that I have converted non-numeric variables to factors. It’s often a good idea to convert categorical variables to factors (especially the response, and when plotting data).

Now, we train a logistic regression model with the “Attrition” variable as the response.

# 1. Partition data to training and validation sets
set.seed(123)

D = Attrition

D$Attrition = ifelse(D$Attrition=="Yes", 1, 0)
n = nrow(D)

ind = sample(1:n, n*0.7)

train = D[ind, ]
valid = D[-ind, ]  # "-" means to use other indices

# 2. Train models with training data
model = glm(Attrition ~ ., 
            data = train, 
            family = binomial)

# 3. Predict the responses in validation set
predicted.prob = predict.glm(model, newdata = valid, type = "response")

predicted.label = ifelse(predicted.prob > 0.5, # in practice, you might choose a cutoff other than 0.5
                                               # that optimizes some performance measure
                         1, 
                         0)

observed.label = valid$Attrition

# 4. Evaluate the predictive performance
M.valid = caret::confusionMatrix(factor(predicted.label), factor(observed.label), positive = "1")
M.valid
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction   0   1
##          0 360  66
##          1   6   9
##                                         
##                Accuracy : 0.8367        
##                  95% CI : (0.7989, 0.87)
##     No Information Rate : 0.8299        
##     P-Value [Acc > NIR] : 0.3805        
##                                         
##                   Kappa : 0.1519        
##                                         
##  Mcnemar's Test P-Value : 3.57e-12      
##                                         
##             Sensitivity : 0.12000       
##             Specificity : 0.98361       
##          Pos Pred Value : 0.60000       
##          Neg Pred Value : 0.84507       
##              Prevalence : 0.17007       
##          Detection Rate : 0.02041       
##    Detection Prevalence : 0.03401       
##       Balanced Accuracy : 0.55180       
##                                         
##        'Positive' Class : 1             
## 

Ordinal Logistic Regression

The following page discusses how to use the polr() function from R’s MASS package to perform a proportional odds logistic regression for ordinal data: https://stats.oarc.ucla.edu/r/dae/ordinal-logistic-regression/

To interpret the coefficients in an ordinal logistic regression in R, refer to https://stats.oarc.ucla.edu/r/faq/ologit-coefficients/
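As a minimal, self-contained sketch of the MASS::polr() call listed in the table at the top of these notes, the code below fits a proportional odds model to the housing data set that ships with MASS. There, the satisfaction level Sat is an ordered factor (Low < Medium < High) and Freq counts the residents in each cell. The data and predictors follow the example on the polr() help page and are only an illustration, not part of the original notes.

```r
library(MASS)  # polr() and the housing data

# Sat is an ordered factor; Freq is the number of residents in each combination of predictors
fit.ord = polr(Sat ~ Infl + Type + Cont, weights = Freq, data = housing, Hess = TRUE)
summary(fit.ord)   # slope coefficients plus the two cutpoints Low|Medium and Medium|High

# Fitted probability of each satisfaction level (one column per level)
prob.ord = predict(fit.ord, type = "probs")
head(prob.ord)
```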

Nominal Logistic Regression
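These notes do not give an example for this model. As a minimal sketch of the nnet::multinom() call listed in the table at the top, the code below fits a nominal (multinomial) logistic regression to R's built-in iris data, whose response Species has three unordered levels. Using iris is an assumption for illustration, and the predictions are made on the training data only to show the calls.

```r
library(nnet)  # multinom()

# Species is an unordered factor with 3 levels; the first level (setosa) is the baseline
fit.nom = multinom(Species ~ ., data = iris, trace = FALSE)
coef(fit.nom)   # one row of log-odds coefficients per non-baseline level, relative to setosa

pred.class = predict(fit.nom, newdata = iris)                  # predicted labels
pred.prob  = predict(fit.nom, newdata = iris, type = "probs")  # one probability column per level
table(predicted = pred.class, observed = iris$Species)
```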

The k-Nearest Neighbors (kNN) Classification or Regression

The k-nearest neighbors method can be used either to classify a new observation or to predict the value of its response, depending on whether the response is categorical or quantitative.

The Idea

To classify a new record or to predict the value of a quantitative response for a new observation, one idea is to use similar observations (called neighbors) in the training data. For classification, the new observation is classified by voting; that is, if the majority (or more than a certain percentage) of the neighbors belong to class X, then the new observation is assigned to class X. For prediction, the response of the new observation is predicted by averaging the responses of its neighbors (possibly with weights inversely proportional to their distances from the new observation). Here is a very nice video about the ideas of k-NN: https://www.youtube.com/watch?v=0p0o5cmgLdE (watch the first 1:20).
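To make the voting and averaging rules concrete, here is a minimal from-scratch sketch in base R. The helper names dist.to(), knn.vote(), and knn.avg() and the toy data are hypothetical, made up for illustration; the sketch uses unweighted Euclidean distance and breaks voting ties by taking the alphabetically first class. In practice you would use a package such as FNN or class, as later in these notes.

```r
# Euclidean distance from the point new.x to every row of train.x
dist.to = function(train.x, new.x) sqrt(rowSums(sweep(as.matrix(train.x), 2, new.x)^2))

# Classification: majority vote among the k nearest training rows
knn.vote = function(train.x, train.y, new.x, k = 3) {
  nn = order(dist.to(train.x, new.x))[1:k]  # indices of the k closest rows
  votes = table(train.y[nn])                # class counts among the neighbors
  names(votes)[which.max(votes)]            # a tie goes to the first name alphabetically
}

# Prediction: unweighted average of the k nearest responses
knn.avg = function(train.x, train.y, new.x, k = 3) {
  nn = order(dist.to(train.x, new.x))[1:k]
  mean(train.y[nn])
}

# Toy data: two well-separated groups
X   = data.frame(x1 = c(1, 2, 3, 10, 11, 12), x2 = c(1, 2, 1, 10, 11, 10))
cls = c("A", "A", "A", "B", "B", "B")
y   = c(1, 2, 3, 10, 11, 12)

knn.vote(X, cls, new.x = c(2, 2), k = 3)  # "A"
knn.avg(X, y, new.x = c(11, 11), k = 3)   # 11, the mean of 10, 11, 12
```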

Although the method is simple and non-parametric (i.e., there are no parameters to estimate), finding the nearest neighbors in a large training set can be prohibitively slow. One way to overcome this difficulty is to first reduce the dimension using a dimension reduction technique such as principal component analysis (PCA) in Chapter 4.

Two issues need to be addressed:

  1. How is similarity of observations determined?

  2. How is the number of neighbors chosen?

Each observation that consists of p characteristics (features) can be viewed as a p-vector, which can be plotted as a point in a p-dimensional space. The similarity of two observations can be measured by the distance between the corresponding two points: the closer the points, the more similar the observations. The most popular distance metric is the Euclidean distance, induced by the 2-norm. Specifically, the 2-norm of a p-dimensional vector \(x = (x_1, x_2, \cdots, x_p)\) is defined as \[||x||_2 = \sqrt{x_1^2+x_2^2+\cdots+x_p^2}\]

The Euclidean distance between two observations \(x=(x_1, x_2, \cdots, x_p)\) and \(u=(u_1, u_2, \cdots, u_p)\) in a p-dimensional space is

\[d(x,u)=\sqrt{(x_1-u_1)^2+(x_2-u_2)^2+\cdots+(x_p-u_p)^2}\]

Since \(x-u=(x_1-u_1, x_2-u_2, \cdots, x_p-u_p)\) is itself a vector, the distance \(d(x,u)\) is just the 2-norm of \(x-u\); that is,

\[d(x,u) = ||x-u||_2\]
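A quick numeric check of the formula, using two made-up 3-dimensional observations: the distance computed by hand as the 2-norm of \(x-u\) matches what base R's dist() returns, since dist() uses the Euclidean distance by default.

```r
x = c(1, 2, 3)
u = c(4, 6, 3)

x.minus.u = x - u                          # (-3, -4, 0)
d.manual  = sqrt(sum(x.minus.u^2))         # sqrt(9 + 16 + 0) = 5
d.dist    = as.numeric(dist(rbind(x, u)))  # Euclidean distance by default

c(d.manual, d.dist)  # both are 5
```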

In data mining, since variables often have different scales, they should be pre-processed (i.e., standardized or normalized) before computing a Euclidean distance. This prevents variables with large values from dominating the distance and keeps the results from depending on the units in which the variables happen to be measured.
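Both kinds of pre-processing take only a line or two of base R. The sketch below borrows the first four rows of the RidingMowers data shown later purely as example numbers. In a real analysis, the means, standard deviations, minimums, and maximums should be computed from the training data only and then reused for the validation and new data, which is what caret's preProcess() does later in these notes.

```r
D = data.frame(Income = c(60.0, 85.5, 64.8, 61.5), Lot_Size = c(18.4, 16.8, 21.6, 20.8))

# Standardization (z-scores): subtract each column's mean and divide by its standard deviation
D.z = as.data.frame(scale(D))

# Normalization (min-max): rescale each column to the interval [0, 1]
D.range = as.data.frame(lapply(D, function(v) (v - min(v)) / (max(v) - min(v))))

D.z
D.range
```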

To choose the number of neighbors (denoted by \(k\)), we try values of \(k\) from 1 up to some limit, say 20. For a two-class problem, \(k\) can be chosen to be odd to avoid tied votes. The optimal \(k\) is the one with the lowest validation error (for instance, the lowest classification error for classification and the smallest RMSE for prediction). If \(k\) is too small, the method tends to have small bias but large variance (over-fitting); if \(k\) is too large, it tends to have small variance but large bias (under-fitting).

In the following example, the binary response variable is whether an animal is a horse (-1) or a mule (1). The two features (i.e., predictors) are Height and Weight.

library(ggplot2)

animals = data.frame(
  Height = c(44.0, 52.1, 57.1, 33.0, 27.8, 27.2, 32.0, 45.1, 56.7, 56.9, 122.1, 123.9, 122.9, 101.1, 128.9, 137.1, 127.0, 103.0, 141.6, 102.4, 93.7),
  Weight = c(126.3, 136.9, 109.2, 148.3, 110.4, 107.8, 128.4, 120.2, 140.2, 139.2, 154.1, 170.8, 183.1, 164.0, 193.6, 181.7, 164.8, 174.6, 185.8, 176.9, 152.4),
  Animal = c(rep(-1,10), rep(1, 10), -1)
)

p = ggplot(animals, aes(x = Height, y = Weight, color = factor(Animal, levels = c(-1,1), labels = c("Horse", "Mule")))) +
  geom_point() +
  labs(color = "Animal")

p

p + 
  geom_point(aes(x = 105, y = 170), color = "red") +
  geom_point(aes(x = 75, y = 140), color = "pink") +
  geom_point(aes(x = 50, y = 130), color = "yellow")

Questions:

  1. If k = 3, how would each of the 3 new observations be classified?

  2. If k = 21, how would each of the 3 new observations be classified?
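One way to check your answers is with class::knn(). The sketch below re-creates the animals data so that it runs on its own and, like the plot, works on the raw (unnormalized) Height and Weight. The coordinates of the red, pink, and yellow points are taken from the plotting code; note that class::knn() breaks tied votes at random.

```r
library(class)  # knn()

animals = data.frame(
  Height = c(44.0, 52.1, 57.1, 33.0, 27.8, 27.2, 32.0, 45.1, 56.7, 56.9, 122.1, 123.9, 122.9, 101.1, 128.9, 137.1, 127.0, 103.0, 141.6, 102.4, 93.7),
  Weight = c(126.3, 136.9, 109.2, 148.3, 110.4, 107.8, 128.4, 120.2, 140.2, 139.2, 154.1, 170.8, 183.1, 164.0, 193.6, 181.7, 164.8, 174.6, 185.8, 176.9, 152.4),
  Animal = c(rep(-1, 10), rep(1, 10), -1)   # -1 = horse, 1 = mule
)
new.pts = data.frame(Height = c(105, 75, 50), Weight = c(170, 140, 130))  # red, pink, yellow points

pred3  = knn(train = animals[, c("Height", "Weight")], test = new.pts,
             cl = factor(animals$Animal), k = 3)
pred21 = knn(train = animals[, c("Height", "Weight")], test = new.pts,
             cl = factor(animals$Animal), k = 21)
pred3
pred21   # with k = 21, all 21 training animals vote: 11 horses vs. 10 mules
```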

There are a few packages that do k-NN, including “class”, “FNN”, and “caret”. Here are a few more details:

  • When using the class package, use the knn() function to do classification.

  • When using the FNN package, use the knn() function to do classification and knn.reg() to do prediction.

  • When using the caret package, use the knn3() function to do classification and knnreg() to do prediction.

All three packages let users specify the value of k. The generic function train() from the caret package can also do k-NN, and it can search for the optimal k.
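A minimal sketch of the two caret functions mentioned above, using R's built-in iris data as an assumed example. Predictions are made on the training rows only to show the calls; they are not an honest estimate of performance.

```r
library(caret)  # knn3(), knnreg()

# Classification: predict Species from the four measurements with k = 5
fit.cls   = knn3(Species ~ ., data = iris, k = 5)
cls.pred  = predict(fit.cls, newdata = iris[, 1:4], type = "class")
prob.pred = predict(fit.cls, newdata = iris[, 1:4], type = "prob")  # vote proportions per class

# Regression: predict Sepal.Length from the other three measurements with k = 5
fit.reg  = knnreg(x = iris[, 2:4], y = iris$Sepal.Length, k = 5)
reg.pred = predict(fit.reg, newdata = iris[, 2:4])

head(cls.pred)
head(reg.pred)
```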

A Case Study

A riding-mower manufacturer would like to find a way of classifying families into those likely to purchase a riding mower and those not likely to buy one. The data set from the textbook is “RidingMowers.csv”.

Since variables are often normalized or standardized before modeling, we use the preProcess() function from the caret package. This function learns the transformation from the training data, and the same transformation can then be applied to the validation data and to any new data.

Read the data into R.

mower.df <- read.csv("/Users/home/documents/zhang/stat415.515.615/DMBA-R-datasets/RidingMowers.csv")
head(mower.df)
##   Income Lot_Size Ownership
## 1   60.0     18.4     Owner
## 2   85.5     16.8     Owner
## 3   64.8     21.6     Owner
## 4   61.5     20.8     Owner
## 5   87.0     23.6     Owner
## 6  110.1     19.2     Owner

Partition data:

set.seed(1)
n = nrow(mower.df)
shuffle.idx <- sample(1:n)
train.idx = shuffle.idx[1:round(0.6*n)]
train.df = mower.df[train.idx, ]
valid.df = mower.df[-train.idx, ]

Normalize data (do this only when necessary):

# use preProcess() from the caret package to normalize Income and Lot_Size in the training data.
# This function creates a recipe for normalizing. 
# The same recipe will be used for normalizing the predictors in validation set as well.
norm.values = caret::preProcess(train.df, method = "range") # "range" rescales numeric columns to [0, 1]
                                                             # The default is standardization (z-scores)

# Note that the function predict() is the one that does the normalizing job. 
# The way of normalization is a bit strange!
train.norm.df <- predict(norm.values, train.df)  # Use the recipe "norm.values" to normalize the training data
valid.norm.df <- predict(norm.values, valid.df) # Use the same recipe "norm.values" to normalize the validation data

Train a model using the FNN package. We first consider \(k=5\).

# Classify a set of observations using a training set with a 5-Nearest Neighbors classifier
knn.pred <- FNN::knn(train = train.norm.df[, 1:2], # Use the function knn() from package FNN on training data
                  test = valid.norm.df[, 1:2],     # Prediction is on the normalized validation data
                  cl = train.norm.df[, 3],         # cl = true classifications of training set.
                  k = 5,                           # Use 5 neighbors
                  prob = TRUE                      # Proportions of the votes for the winning class are returned
)

#knn.pred  # Show classes, proportions (probs), nn.index (neighbors of each observation in test data), nn.dist (distances)

as.vector(knn.pred)        # Show classes
##  [1] "Owner"    "Owner"    "Owner"    "Owner"    "Owner"    "Owner"   
##  [7] "Owner"    "Owner"    "Nonowner" "Nonowner"
attr(knn.pred, "prob")     # Show proportions of the votes for the winning classes
##  [1] 0.6 0.8 0.8 0.8 0.8 0.6 0.8 0.6 0.8 0.6
#attr(knn.pred, "nn.index") # For each observation in test data, show its neighbors in training data 
#row.names(train.df)[attr(knn.pred,"nn.index")[1,]] # Shows the row names of the first set of 5 neighbors in train data. 

#attr(knn.pred, "nn.dist")  # Show the distances between observations in test data and 
                           # their neighbors in training data

We can convert the vote proportions into predicted probabilities of being an Owner for the validation data:

pred.label = as.vector(knn.pred)
prob = attr(knn.pred, "prob")
p = c()
for (i in 1:length(pred.label)){
  p[i] = ifelse(pred.label[i] == "Owner", prob[i], 1 - prob[i])
}
p
##  [1] 0.6 0.8 0.8 0.8 0.8 0.6 0.8 0.6 0.2 0.4
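For two classes, the loop above can also be written as a single vectorized ifelse(), since ifelse() works element by element. A minimal sketch with made-up labels and vote proportions:

```r
pred.label = c("Owner", "Owner", "Nonowner", "Nonowner")  # predicted classes
prob = c(0.6, 0.8, 0.8, 0.6)                              # vote share of the winning class

# Probability of "Owner": the vote share if Owner won, otherwise 1 minus it
p.owner = ifelse(pred.label == "Owner", prob, 1 - prob)
p.owner  # 0.6 0.8 0.2 0.4
```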

The following code does k-NN for each k between 1 and 14 and reports the accuracy on the validation set. The optimal value of k is the one with the highest accuracy.

accuracy = NULL
for(i in 1:14) {
  knn.pred <- FNN::knn(train = train.norm.df[, 1:2], # Use the function knn from package FNN
                  test = valid.norm.df[, 1:2], 
                  cl = train.norm.df[, 3],      # cl = true classifications of training set.
                  k = i,
                  prob = TRUE)   
  accuracy[i] <- caret::confusionMatrix(factor(knn.pred, levels = c("Owner", "Nonowner")), factor(valid.norm.df[, 3], levels = c("Owner", "Nonowner")))$overall[1] 
}


plot(accuracy, type = "l", xlab = "k")

The code shows an optimal value for \(k\). Let’s apply \(k = 6\).

# Normalize the original WHOLE data, since we are classifying a future observation.
mower.norm.df <- mower.df # Initialize
mower.norm.df[, 1:2] <- predict(norm.values, mower.df[, 1:2]) # Apply the same recipe to normalize whole data

## Data for two new households
new.df <- data.frame(Income = c(60, 120), Lot_Size = c(20, 35))
# Apply the recipe to normalize the new data
new.norm.df <- predict(norm.values, new.df) 

knn.pred.new <- FNN::knn(train = mower.norm.df[, 1:2], 
                    test = new.norm.df, 
                    cl = mower.norm.df[, 3], 
                    k = 6, 
                    prob = TRUE)
knn.pred.new
## [1] Owner Owner
## attr(,"prob")
## [1] 0.6666667 1.0000000
## attr(,"nn.index")
##      [,1] [,2] [,3] [,4] [,5] [,6]
## [1,]    4    9   14    1    3   20
## [2,]    5    8   10    3   11   12
## attr(,"nn.dist")
##            [,1]      [,2]      [,3]      [,4]      [,5]      [,6]
## [1,] 0.08557424 0.1167315 0.1251609 0.1666667 0.1779148 0.1839399
## [2,] 1.26228110 1.3983752 1.5200560 1.5687383 1.6231722 1.6423388
## Levels: Owner

Both new households are classified as “Owner”.

Another Case Study

The data “BostonHousing.csv” from the textbook contains information on over 500 census tracts in Boston, where for each tract multiple variables are recorded. The last column (CAT.MEDV) was derived from the median value (MEDV), such that it takes the value 1 if \(MEDV > 30\) and 0 otherwise. Here, we consider the goal of predicting the CAT.MEDV of a tract, given the information in the first 12 columns. The data will be partitioned into training (60%) and validation (40%) sets. Using the k-NN algorithm, the steps are:

  1. Normalize the predictors.

  2. Use the “class” or “FNN” package to choose the optimal value of \(k\).

  3. Classify CAT.MEDV for two tracts with the following information, using the optimal \(k\).

| CRIM | ZN | INDUS | CHAS | NOX | RM | AGE | DIS | RAD | TAX | PTRATIO | LSTAT |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0.2 | 0 | 7 | 0 | 0.538 | 6 | 62 | 4.7 | 4 | 307 | 21 | 10 |
| 0.5 | 2.3 | 8.6 | 0 | 0.78 | 3.5 | 48 | 6.2 | 2 | 237 | 15.8 | 21 |

housing.df <- read.csv("/Users/home/documents/zhang/stat415.515.615/DMBA-R-datasets/BostonHousing.csv")
head(housing.df)
##      CRIM ZN INDUS CHAS   NOX    RM  AGE    DIS RAD TAX PTRATIO LSTAT MEDV
## 1 0.00632 18  2.31    0 0.538 6.575 65.2 4.0900   1 296    15.3  4.98 24.0
## 2 0.02731  0  7.07    0 0.469 6.421 78.9 4.9671   2 242    17.8  9.14 21.6
## 3 0.02729  0  7.07    0 0.469 7.185 61.1 4.9671   2 242    17.8  4.03 34.7
## 4 0.03237  0  2.18    0 0.458 6.998 45.8 6.0622   3 222    18.7  2.94 33.4
## 5 0.06905  0  2.18    0 0.458 7.147 54.2 6.0622   3 222    18.7  5.33 36.2
## 6 0.02985  0  2.18    0 0.458 6.430 58.7 6.0622   3 222    18.7  5.21 28.7
##   CAT..MEDV
## 1         0
## 2         0
## 3         1
## 4         1
## 5         1
## 6         0
set.seed(1)
n = nrow(housing.df)
shuffle.idx <- sample(1:n)
train.idx = shuffle.idx[1:round(0.6*n)]
train.df = housing.df[train.idx, ]
valid.df = housing.df[-train.idx, ]

# The following uses preProcess() from the caret package to standardize the predictors (the default).
# Non-numeric columns would be ignored; CHAS is stored as 0/1, so it is standardized like the others.
# This function creates a recipe for normalizing. 
# The same recipe is used for normalizing the predictors in validation set as well.
norm.values = caret::preProcess(train.df[, 1:12])

norm.values$mean # Means of the first 12 columns in the training data
##         CRIM           ZN        INDUS         CHAS          NOX           RM 
##   3.71234276  11.06250000  11.23911184   0.07894737   0.55115000   6.27659868 
##          AGE          DIS          RAD          TAX      PTRATIO        LSTAT 
##  67.13815789   3.87707401   9.57894737 409.98355263  18.56348684  12.59203947
norm.values$std  # Standard deviations of the first 12 columns in the training data
##        CRIM          ZN       INDUS        CHAS         NOX          RM 
##   8.5972565  22.3862591   6.9530684   0.2701012   0.1159364   0.6933260 
##         AGE         DIS         RAD         TAX     PTRATIO       LSTAT 
##  28.6754582   2.1411210   8.7555933 170.1155006   2.0946299   7.0478001
# To normalize the predictors of train and validation sets, initialize them first.
train.norm.df = train.df
valid.norm.df = valid.df
# Note the function predict() does the normalizing job. A bit strange!
train.norm.df[, 1:12] <- predict(norm.values, train.df[, 1:12])
valid.norm.df[, 1:12] <- predict(norm.values, valid.df[, 1:12])

# The following code does k-NN with the FNN package for each k between 1 and 20 and reports the accuracy on the validation set. 

accuracy = NULL

for(i in 1:20) {
  knn.pred <- FNN::knn(train = train.norm.df[, 1:12], # Try class::knn() here and compare the results
                  test = valid.norm.df[, 1:12], 
                  cl = train.norm.df[, 14],      # cl = true classifications of training set.
                  k = i,
                  prob = TRUE)   
  accuracy[i] <- caret::confusionMatrix(factor(knn.pred, levels = c(1, 0)), factor(valid.norm.df[, 14], levels = c(1, 0)))$overall[1] 
}
accuracy.df = data.frame(k=1:20, accuracy)

accuracy.df
##     k  accuracy
## 1   1 0.9207921
## 2   2 0.9108911
## 3   3 0.9306931
## 4   4 0.9207921
## 5   5 0.9207921
## 6   6 0.9158416
## 7   7 0.9158416
## 8   8 0.8910891
## 9   9 0.9059406
## 10 10 0.9059406
## 11 11 0.9108911
## 12 12 0.9108911
## 13 13 0.9158416
## 14 14 0.9059406
## 15 15 0.9158416
## 16 16 0.9108911
## 17 17 0.9158416
## 18 18 0.9108911
## 19 19 0.9108911
## 20 20 0.9009901
plot(accuracy.df, type = "l")

The code using the package FNN shows that the optimal value of \(k\) is 3. Now, let’s apply this \(k\) to classify 3 new records.

# Normalize the original whole dataset.
housing.norm.df <- housing.df # Initialize
housing.norm.df[, 1:12] <- predict(norm.values, housing.df[, 1:12]) # Apply the recipe to normalize WHOLE data

# New tract data: 3 records
new.df <- data.frame(CRIM=c(0.2, 0.5, 0.7),
                     ZN=c(0, 2.3, 4.1),
                     INDUS=c(7, 8.6, 3.9), 
                     CHAS=c(0, 0, 1),
                     NOX=c(0.538, 0.78, 0.66),
                     RM=c(6, 3.5, 3.0),
                     AGE=c(62, 48, 66),
                     DIS=c(4.7, 6.2, 5.1),
                     RAD=c(4,2, 5),
                     TAX=c(307, 237, 200),
                     PTRATIO=c(21, 15.8, 12),
                     LSTAT=c(10,21, 32))
# Apply the recipe "norm.values" obtained in previous code chunk to normalize the new data
new.norm.df <- predict(norm.values, new.df)

# Using the FNN package
knn.pred.new <- FNN::knn(train = housing.norm.df[, 1:12], 
                    test = new.norm.df, 
                    cl = housing.norm.df[, 14], 
                    k = 3,
                    prob = TRUE)

knn.pred.new
## [1] 0 0 0
## attr(,"prob")
## [1] 1 1 1
## attr(,"nn.index")
##      [,1] [,2] [,3]
## [1,]   14   16   20
## [2,]  311  497  148
## [3,]  212  210  143
## attr(,"nn.dist")
##           [,1]      [,2]      [,3]
## [1,] 0.3095234 0.4236047 0.6819967
## [2,] 3.8731306 4.1517231 4.2278079
## [3,] 5.2602378 5.2955815 5.3556020
## Levels: 0

All 3 new records are classified as 0.

Handling Categorical Predictors in k-NN

When nominal categorical predictors occur in a k-NN problem, create dummy variables for each categorical predictor and use ALL of the dummy variables. The reasons for using all dummy variables are as follows:

  1. Multicollinearity is not an issue for k-NN, since no coefficients are estimated.

  2. The Euclidean distance between two observations on a nominal variable should be the same for every pair of distinct categories. For example, suppose a political party affiliation variable has 3 categories (R, D, I). With one dummy per category, the Euclidean distance between any two individuals with different affiliations is the same, namely the square root of 2, or about 1.414.

A nominal categorical variable with \(m\) categories should correspond to \(m\) dummy variables.
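The √2 claim is easy to verify in base R. In this minimal sketch with a made-up party variable, model.matrix() with - 1 in the formula keeps one dummy per category, and every pair of distinct categories is then √2 apart; dropping a baseline dummy (the usual coding in regression) makes the distances unequal.

```r
party = factor(c("R", "D", "I"))

# All m = 3 dummies (no intercept, so no category is dropped)
X.all = model.matrix(~ party - 1)
dist(X.all)    # every pair: sqrt(2), about 1.414

# Only m - 1 = 2 dummies (baseline "D" dropped, as in regression)
X.drop = model.matrix(~ party)[, -1]
dist(X.drop)   # pairs involving "D" are 1 apart, while R and I are sqrt(2) apart
```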

library(fastDummies) # dummy_columns() dummifies character and factor columns by default
library(dplyr)       # select_if()

crime <- data.frame(city = c("SF", "SF", "NYC"),
    year = c(1990, 2000, 1990),
    crime = 1:3,
    location = c("west", "west", "east"))
crime.dummified = dummy_columns(crime)
crime.dummified
##   city year crime location city_NYC city_SF location_east location_west
## 1   SF 1990     1     west        0       1             0             1
## 2   SF 2000     2     west        0       1             0             1
## 3  NYC 1990     3     east        1       0             1             0
# Select only numeric columns
select_if(crime.dummified, is.numeric)
##   year crime city_NYC city_SF location_east location_west
## 1 1990     1        0       1             0             1
## 2 2000     2        0       1             0             1
## 3 1990     3        1       0             1             0
# Ordinal variables are not character columns, so can't be dummified. The following code won't run.
# dummy_columns(diamonds)

The textbook claims (page 184): For doing k-NN, the package “class” can handle categorical variables automatically, but the package FNN can’t.

What about handling an ordinal categorical variable (say, a variable on a 5-level Likert scale)? One solution is to score each category of the ordinal variable. For example, look at the “cut” variable in the “diamonds” data frame from the ggplot2 package. The variable has 5 categories: Fair, Good, Very Good, Premium, and Ideal. Scores of 1, 2, 3, 4, and 5 can be assigned to these categories in that order, thus making the variable numeric.
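When the levels of an ordered factor are already stored in their natural order, as they are for diamonds$cut, converting the factor to integers gives the same 1 to 5 scores as the recode() call below. A minimal sketch on a small made-up ordered factor with the same levels:

```r
cut.levels = c("Fair", "Good", "Very Good", "Premium", "Ideal")
cut.x = factor(c("Ideal", "Premium", "Good", "Fair", "Very Good"),
               levels = cut.levels, ordered = TRUE)

# The integer codes follow the level order: Fair = 1, ..., Ideal = 5
cut.score = as.integer(cut.x)
cut.score  # 5 4 2 1 3
```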

library(dplyr)    # recode()
library(ggplot2)  # diamonds data

new.cut = diamonds$cut

new.cut.val = recode(new.cut, Fair = 1, Good = 2, "Very Good" = 3, Premium = 4, Ideal = 5)

new.cut.val[1:100]
##   [1] 5 4 2 4 2 3 3 3 1 3 2 5 4 5 4 4 5 2 2 3 2 3 3 3 3 3 4 3 3 3 3 3 3 3 3 2 2
##  [38] 2 3 5 5 5 2 2 2 4 3 2 3 3 3 5 5 4 4 5 4 3 3 2 5 4 5 5 4 5 5 3 4 4 3 3 4 4
##  [75] 2 3 3 3 3 3 3 3 5 5 2 4 4 4 4 4 5 1 5 3 3 2 2 1 3 4

k-NN for a Quantitative Outcome

When using k-NN for predicting the value of a quantitative outcome, use RMSE on validation set for choosing the optimal \(k\).
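As a sketch of this selection rule, the code below computes the validation RMSE by hand for k = 1 to 10 and keeps the k with the smallest value. The built-in mtcars data (predicting mpg from wt, hp, and disp), the 20/12 split, and the helper rmse() are assumptions made for illustration. The standardization uses the training means and standard deviations only, mirroring what preProcess() does in these notes.

```r
library(FNN)  # knn.reg()

# Root mean squared error
rmse = function(observed, predicted) sqrt(mean((observed - predicted)^2))

set.seed(1)
idx   = sample(nrow(mtcars), 20)   # 20 training rows, 12 validation rows
train = mtcars[idx, ]
valid = mtcars[-idx, ]
preds = c("wt", "hp", "disp")

# Standardize both sets with the training means and standard deviations
mu = colMeans(train[, preds])
s  = apply(train[, preds], 2, sd)
train.x = scale(train[, preds], center = mu, scale = s)
valid.x = scale(valid[, preds], center = mu, scale = s)

# Validation RMSE for each k
rmse.k = sapply(1:10, function(k) {
  fit = knn.reg(train = train.x, test = valid.x, y = train$mpg, k = k)
  rmse(valid$mpg, fit$pred)
})
best.k = which.min(rmse.k)
rmse.k
best.k
```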

Let’s do exercise 7.3 on page 185 of the textbook. The data “BostonHousing.csv” contains information on over 500 census tracts in Boston, where for each tract multiple variables are recorded. The last column (CAT.MEDV) was derived from the median value (MEDV), such that it takes the value 1 if \(MEDV > 30\) and 0 otherwise. Here, we consider the goal of predicting the median value (MEDV) of a tract, given the information in the first 12 columns. The data will be partitioned into training (60%) and validation (40%) sets. Using the k-NN algorithm, the steps are:

  1. Normalize the predictors.

  2. Use the “class” or “FNN” package to choose the optimal value of \(k\).

  3. Predict the MEDV for a tract with the following information, using the optimal \(k\).

| CRIM | ZN | INDUS | CHAS | NOX | RM | AGE | DIS | RAD | TAX | PTRATIO | LSTAT |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0.2 | 0 | 7 | 0 | 0.538 | 6 | 62 | 4.7 | 4 | 307 | 21 | 10 |

housing.df <- read.csv("/Users/home/documents/zhang/stat415.515.615/DMBA-R-datasets/BostonHousing.csv")
head(housing.df)
##      CRIM ZN INDUS CHAS   NOX    RM  AGE    DIS RAD TAX PTRATIO LSTAT MEDV
## 1 0.00632 18  2.31    0 0.538 6.575 65.2 4.0900   1 296    15.3  4.98 24.0
## 2 0.02731  0  7.07    0 0.469 6.421 78.9 4.9671   2 242    17.8  9.14 21.6
## 3 0.02729  0  7.07    0 0.469 7.185 61.1 4.9671   2 242    17.8  4.03 34.7
## 4 0.03237  0  2.18    0 0.458 6.998 45.8 6.0622   3 222    18.7  2.94 33.4
## 5 0.06905  0  2.18    0 0.458 7.147 54.2 6.0622   3 222    18.7  5.33 36.2
## 6 0.02985  0  2.18    0 0.458 6.430 58.7 6.0622   3 222    18.7  5.21 28.7
##   CAT..MEDV
## 1         0
## 2         0
## 3         1
## 4         1
## 5         1
## 6         0
set.seed(111)
n = nrow(housing.df)
shuffle.idx <- sample(1:n)
train.idx = shuffle.idx[1:round(0.6*n)]
train.df = housing.df[train.idx, ]
valid.df = housing.df[-train.idx, ]

# use preProcess() from the caret package to normalize numeric variables.
# This function creates a recipe for normalizing. The same recipe is used for normalizing the 
# predictors in validation set as well.
norm.values = caret::preProcess(train.df[, 1:12])

# To normalize the predictors of train and validation sets, initialize them first.
train.norm.df = train.df
valid.norm.df = valid.df
# Note the function predict() does the normalizing job. A bit strange!
train.norm.df[, 1:12] <- predict(norm.values, train.df[, 1:12])
valid.norm.df[, 1:12] <- predict(norm.values, valid.df[, 1:12])

# The following code does k-NN for each k between 1 and 20 and records the validation RMSE. 
error = NULL
for(i in 1:20) {
  knn.pred <- FNN::knn.reg(train = train.norm.df[, 1:12], # knn.reg() from FNN does k-NN regression
                  test = valid.norm.df[, 1:12], 
                  y = train.norm.df[, 13],    # Use y not cl for regression
                  k = i)   
  error[i] <- forecast::accuracy(knn.pred$pred, valid.norm.df[, 13])[2] # Use RMSE
}
plot(error, type = "l", xlab = "k")

The code using the package FNN shows that the optimal value of \(k\) is 1. Now, let’s apply this \(k\).

# Normalize the predictors in the original whole dataset.
housing.norm.df <- housing.df # Initialize
housing.norm.df[, 1:12] <- predict(norm.values, housing.df[, 1:12]) # Apply the recipe to normalize the predictors in whole data

# New tract data: 2 records
new.df <- data.frame(CRIM=c(0.2, 0.5),
                     ZN=c(0, 2.3),
                     INDUS=c(7, 8.6), 
                     CHAS=c(0, 0),
                     NOX=c(0.538, 0.78),
                     RM=c(6, 3.5),
                     AGE=c(62, 48),
                     DIS=c(4.7, 6.2),
                     RAD=c(4,2),
                     TAX=c(307, 237),
                     PTRATIO=c(21, 15.8),
                     LSTAT=c(10,21))
# Apply the recipe "norm.values" obtained in previous code chunk to normalize the new data
new.norm.df <- predict(norm.values, new.df)

# Using the FNN package
knn.pred.new <- FNN::knn.reg(train = housing.norm.df[, 1:12], 
                    test = new.norm.df, 
                    y = housing.norm.df[, 13], 
                    k = 1)

knn.pred.new
## Prediction:
## [1] 20.4 16.1

Why is the validation data error overly optimistic compared to the error rate when applying this k-NN predictor to new data?

If the purpose is to predict MEDV for several thousands of new tracts, what would be the disadvantage of using k-NN prediction? List the operations that the algorithm goes through in order to produce each prediction.

Implementing the K-NN Method with the train() Function from the caret Package

Reference: https://www.youtube.com/watch?v=tSPg-JDAF4M

library(dplyr)  # recode() and the pipe %>%
library(caret)  # trainControl(), twoClassSummary(), train()

set.seed(123)
D = iris
# Code virginica as 0 and the other two species as 1, so that the label "virginica" goes with virginica
D$Species = recode(D$Species, 
                   setosa = 1, 
                   versicolor = 1, 
                   virginica = 0) %>% 
  factor(levels = c(0,1), 
         labels = c("virginica", "Not")) # Must have named labels if class probabilities are wanted


# Partition data into train and validation sets
n = nrow(D)
shuffle.idx <- sample(1:n)
train.idx = shuffle.idx[1:round(0.6*n)]
train.df = D[train.idx, ]
valid.df = D[-train.idx, ]

# No need to normalize here: caret::train() can do it through its preProcess argument.

# To train a model using the caret package, first set some controls: 
trC = caret::trainControl(
                   method = "repeatedcv", 
                   number = 10,        # 10-fold cross-validation
                   repeats = 3,        # Repeat cross-validation 3 times
                   classProbs = TRUE,  # Return the predicted probability for each class
                   summaryFunction = twoClassSummary
                  )

The controls above specify the following:

  • The training data are randomly divided into 10 parts (called folds) of roughly equal size, and each fold in turn is used as test data for a model trained on the other 9 folds.

  • The 10 resulting performance measures are averaged to get a cross-validation performance measure. This is called 10-fold cross-validation (CV).

  • The cross-validation is repeated 3 times, giving 3 cross-validation performance measures.

  • The 3 numbers are averaged to get an overall performance measure. Note that 3 repeats of 10-fold cross-validation are not the same as 30-fold CV: each fold still holds out about one tenth of the training data, not one thirtieth.

  • The returned class probabilities are used by twoClassSummary to compute the area under the ROC curve, sensitivity, and specificity.
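To see what this resampling scheme looks like, caret's createMultiFolds() generates the training-row indices for each fold of each repeat. In the minimal sketch below (the 30/60 class counts are made up for illustration), 3 repeats of 10-fold CV give 30 resamples, yet each one still trains on about 9/10 of the rows, whereas 30-fold CV would train on about 29/30.

```r
library(caret)  # createMultiFolds()

set.seed(123)
y = factor(rep(c("virginica", "Not"), times = c(30, 60)))  # 90 made-up labels

# Training-row indices for 10 folds x 3 repeats (folds are stratified by y)
folds = createMultiFolds(y, k = 10, times = 3)

length(folds)            # 30 resamples
summary(lengths(folds))  # about 81 of the 90 rows are used for training each time
```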

Now, we are ready to train the KNN model with the caret package:

knn = caret::train(Species ~ .,
                   data = train.df,
                   method = "knn",
                   preProcess = "range",  # Normalization occurs here! Can be c("center", "scale")
                   tuneGrid=data.frame(k=1:25), # Allows the search of the best k
                   trControl = trC,        # Controls are used here
                   metric = "ROC"          # This relies on class probabilities requested in control.
)

knn
## k-Nearest Neighbors 
## 
## 90 samples
##  4 predictor
##  2 classes: 'virginica', 'Not' 
## 
## Pre-processing: re-scaling to [0, 1] (4) 
## Resampling: Cross-Validated (10 fold, repeated 3 times) 
## Summary of sample sizes: 81, 80, 81, 81, 80, 82, ... 
## Resampling results across tuning parameters:
## 
##   k   ROC        Sens       Spec     
##    1  0.9666667  1.0000000  0.9333333
##    2  0.9811111  0.9700000  0.9222222
##    3  0.9988889  1.0000000  0.9361111
##    4  0.9977778  0.9800000  0.9444444
##    5  0.9977778  0.9888889  0.9333333
##    6  0.9977778  0.9944444  0.9472222
##    7  0.9966667  0.9822222  0.9138889
##    8  0.9988889  0.9822222  0.9333333
##    9  0.9988889  0.9822222  0.9333333
##   10  0.9988889  0.9700000  0.9555556
##   11  0.9988889  0.9700000  0.9666667
##   12  1.0000000  0.9700000  0.9555556
##   13  1.0000000  0.9700000  0.9666667
##   14  0.9988889  0.9700000  0.9666667
##   15  0.9977778  0.9633333  0.9666667
##   16  0.9977778  0.9622222  0.9666667
##   17  0.9977778  0.9633333  0.9555556
##   18  1.0000000  0.9577778  0.9444444
##   19  1.0000000  0.9511111  0.9333333
##   20  1.0000000  0.9522222  0.9666667
##   21  1.0000000  0.9455556  0.9444444
##   22  1.0000000  0.9455556  0.9333333
##   23  0.9975463  0.9511111  0.9444444
##   24  0.9984722  0.9455556  0.9444444
##   25  0.9966667  0.9511111  0.9444444
## 
## ROC was used to select the optimal model using the largest value.
## The final value used for the model was k = 22.
plot(knn)

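Several values of k tie at the largest ROC (1.0). The output reports k = 22, the largest of the tied values, which is consistent with caret preferring the simplest model among ties (for k-NN, a larger k gives a smoother, simpler fit).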

We next generate some useful results:

caret::varImp(knn) # Rank predictors by importance
## ROC curve variable importance
## 
##              Importance
## Petal.Length     100.00
## Petal.Width       99.53
## Sepal.Length      73.92
## Sepal.Width        0.00
pred = predict(knn, newdata = valid.df)
caret::confusionMatrix(factor(pred), factor(valid.df$Species), positive = "virginica")
## Confusion Matrix and Statistics
## 
##            Reference
## Prediction  virginica Not
##   virginica        41   1
##   Not               3  15
##                                          
##                Accuracy : 0.9333         
##                  95% CI : (0.838, 0.9815)
##     No Information Rate : 0.7333         
##     P-Value [Acc > NIR] : 8.636e-05      
##                                          
##                   Kappa : 0.8361         
##                                          
##  Mcnemar's Test P-Value : 0.6171         
##                                          
##             Sensitivity : 0.9318         
##             Specificity : 0.9375         
##          Pos Pred Value : 0.9762         
##          Neg Pred Value : 0.8333         
##              Prevalence : 0.7333         
##          Detection Rate : 0.6833         
##    Detection Prevalence : 0.7000         
##       Balanced Accuracy : 0.9347         
##                                          
##        'Positive' Class : virginica      
## 

The “No Information Rate” in the output following the confusion matrix is the accuracy we would get by classifying every individual in the validation data into the majority category (here 44/60 = 0.7333). “P-Value [Acc > NIR]” tests whether the model's accuracy is significantly higher than this rate.
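As a check on how these numbers arise, the sketch below recomputes accuracy, the no information rate, sensitivity, and specificity from the four counts in the confusion matrix above, treating “virginica” as the positive class. The one-sided binomial test is an assumption about how the reported P-value is obtained (caret's documentation describes a binomial test against the no information rate); it gives a value of the same size as the 8.636e-05 shown above.

```r
# Counts from the confusion matrix above (rows = prediction, columns = reference);
# matrix() fills column by column
cm <- matrix(c(41, 3, 1, 15), nrow = 2,
             dimnames = list(Prediction = c("virginica", "Not"),
                             Reference  = c("virginica", "Not")))
n <- sum(cm)
accuracy    <- sum(diag(cm)) / n                                      # (41 + 15) / 60
nir         <- max(colSums(cm)) / n                                   # majority class: 44 / 60
sensitivity <- cm["virginica", "virginica"] / sum(cm[, "virginica"])  # 41 / 44
specificity <- cm["Not", "Not"] / sum(cm[, "Not"])                    # 15 / 16

# One-sided binomial test of Accuracy > NIR
p_value <- binom.test(sum(diag(cm)), n, p = nir, alternative = "greater")$p.value

round(c(accuracy = accuracy, nir = nir, sensitivity = sensitivity,
        specificity = specificity), 4)
p_value
```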

Regression based on k-NN can also be done using the train() function from the caret package.

Classification And Regression Trees (CART, Chapter 9)

The k-NN method can classify a categorical response variable or predict the value of a quantitative variable based on a set of features. This section will introduce an alternative method, called CART, developed by Leo Breiman and his colleagues (1984). CART can be used for both classification (called classification tree) and prediction (called regression tree).

We will answer the following questions:

  • What is a decision tree?

  • How can we use a decision tree to classify or make a prediction for a new observation?

  • What is the rationale behind the construction of a decision tree?

  • What are some issues when using a decision tree?

Introduction

We focus on binary trees. A binary tree grows two branches (child nodes) out of the root, each of these may split into two new branches, and the same process may continue. A node that is not split further is called a leaf (or terminal node). Because each split turns one leaf into two, the number of terminal nodes in a binary tree always equals the total number of splits plus 1.
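The leaves-equals-splits-plus-one relation can be checked on a fitted tree. This is a small sketch using rpart on the iris data; it relies on the fitted object's frame, which has one row per node and marks terminal nodes with "<leaf>" in its var column.

```r
library(rpart)   # rpart ships with standard R installations

fit <- rpart(Species ~ ., data = iris)
is_leaf  <- fit$frame$var == "<leaf>"   # one row of fit$frame per node
n_leaves <- sum(is_leaf)
n_splits <- sum(!is_leaf)               # every internal node is one split
c(leaves = n_leaves, splits = n_splits) # 3 leaves, 2 splits for this tree
```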

The basic idea behind the binary tree method is to divide the observations in a group (called a parent node) into two groups (called child nodes in a tree graph) using a predictor selected according to a criterion, so that observations in the same group are as similar as possible on the response variable and observations in different groups are as dissimilar as possible on that variable. If the response variable is categorical (a classification tree), the split criterion can be an impurity index such as the Gini index or the entropy measure.

Suppose the proportions of the different categories in a node are \(p_1, p_2, \cdots, p_m\). The Gini index for this node is defined as

\[Gini = \sum_{i=1}^m p_i (1-p_i)\]

The Gini index is the probability that two individuals drawn at random (with replacement) from the node belong to different classes. It measures the diversity of a population.

It’s easy to show that

\[Gini = 1-(p_1^2+p_2^2+\cdots+p_m^2)\]

At a parent node, calculate the Gini index. For each possible split of this node based on a predictor, calculate the Gini index for each of the two child nodes and weight the two indices by the relative sizes of the children. The reduction in impurity is the Gini index of the parent node minus this weighted average. Find the largest reduction over all candidate splits of all predictors; the predictor (and split) giving this largest reduction is used to split the parent node.
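A minimal sketch of the Gini computation for one candidate split, using the iris species as the response. The split is passed as a logical vector (TRUE = goes to the left child), so the same helper works for a cutoff on a quantitative predictor or a grouping of categories of a categorical one. The names gini() and gini_reduction() are illustrative helpers, not functions from a package.

```r
# Gini index of a node, computed from the class labels of its observations
gini <- function(y) {
  p <- table(y) / length(y)   # class proportions (zero-count classes contribute 0)
  1 - sum(p^2)
}

# Reduction in Gini impurity when a parent node is split into two children;
# `goes_left` marks the observations sent to the left child
gini_reduction <- function(y, goes_left) {
  w_left   <- mean(goes_left)  # relative size of the left child
  children <- w_left * gini(y[goes_left]) + (1 - w_left) * gini(y[!goes_left])
  gini(y) - children
}

gini(iris$Species)                                       # parent node: 2/3
gini_reduction(iris$Species, iris$Petal.Length < 2.5)   # 2/3 - 1/3 = 1/3
```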

The entropy measure is defined as

\[-[p_1 log_2(p_1)+p_2 log_2(p_2)+\cdots+p_m log_2(p_m)],\]

with the convention \(0\cdot log_2(0)=0\).

The procedure using the entropy criterion is similar to that using the Gini index.

A nice tutorial regarding Gini index and Entropy: https://towardsdatascience.com/decision-trees-how-to-draw-them-on-paper-e2597af497f0

If the purpose is to do regression (i.e., predict the value of a quantitative variable), the criterion can be the population variance. Calculate the variance for a parent node and the variance for its two child nodes. Weight the variances for the two children with the relative sizes of the two nodes. Calculate the reduction from the parent node to the two children. Do this for each predictor and pick the one with the largest reduction.
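The variance criterion can be sketched the same way. The response values below are made up purely for illustration, and pop_var() and var_reduction() are hypothetical helper names; pop_var() divides by n, matching the population variance described above.

```r
# Population variance (denominator n), the impurity used for regression trees
pop_var <- function(y) mean((y - mean(y))^2)

# Reduction in weighted variance from splitting a parent node into two children
var_reduction <- function(y, goes_left) {
  w_left   <- mean(goes_left)
  children <- w_left * pop_var(y[goes_left]) + (1 - w_left) * pop_var(y[!goes_left])
  pop_var(y) - children
}

y <- c(1, 2, 3, 10, 11, 12)        # toy response values
var_reduction(y, goes_left = y < 5) # 125.5/6 - 2/3 = 20.25
```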

In both classification and regression, if a predictor is quantitative, each of its data values is a candidate cutoff for splitting the parent node; choose the cutoff that gives the largest reduction in impurity (Gini index, entropy, or variance) from the parent node to the two child nodes.
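A minimal sketch of the exhaustive cutoff search for one quantitative predictor, using the Gini index on iris. It tries the midpoints between consecutive distinct values, which produce the same partitions as using each data value as a cutoff and match how rpart reports cutoffs (for example, 2.45 in the rpart output later in this section). best_cutoff() is a hypothetical helper.

```r
gini <- function(y) {
  p <- table(y) / length(y)
  1 - sum(p^2)
}

# Try every cutoff between consecutive distinct values of x and keep the one
# with the largest reduction in Gini impurity
best_cutoff <- function(x, y) {
  v <- sort(unique(x))
  cutoffs <- (head(v, -1) + tail(v, -1)) / 2   # midpoints between distinct values
  reductions <- sapply(cutoffs, function(cc) {
    left <- x < cc
    w <- mean(left)
    gini(y) - (w * gini(y[left]) + (1 - w) * gini(y[!left]))
  })
  i <- which.max(reductions)
  c(cutoff = cutoffs[i], reduction = reductions[i])
}

best <- best_cutoff(iris$Petal.Length, iris$Species)
best   # cutoff 2.45, reduction 1/3
```

Repeating this search for every predictor and keeping the overall winner gives the split used at the node.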

For more details, here is a really good tutorial on classification and regression trees: http://web.cs.ucla.edu/~yzsun/classes/2017Fall_CS145/Slides/04Decision_Tree.pdf

Normalization?

No normalization is needed in CART: a split depends only on the ordering of a predictor's values, so measurement units do not matter. But if you also want to fit another model that requires normalization for comparison, it might be wise to normalize the data anyway.
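One way to see this is to fit the same tree on raw and standardized predictors and compare the predicted classes. This is a sketch on iris; standardizing is an increasing transformation, so every candidate split produces the same partition and the fitted trees classify identically.

```r
library(rpart)

iris_scaled <- iris
iris_scaled[, 1:4] <- scale(iris[, 1:4])   # center and scale the four predictors

fit_raw    <- rpart(Species ~ ., data = iris)
fit_scaled <- rpart(Species ~ ., data = iris_scaled)

pred_raw    <- predict(fit_raw, type = "class")
pred_scaled <- predict(fit_scaled, type = "class")
identical(pred_raw, pred_scaled)   # same classifications; only cutoff values change
```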

Missing values?

Missing values will be handled automatically in CART. If there are missing values on a variable to be used as a candidate for splitting a node, a surrogate variable will be chosen to do the job instead.
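A small illustration with rpart, assuming a new flower (made up for this example) whose Petal.Length is missing. The root split of the iris tree uses Petal.Length, so rpart falls back on its first surrogate (Petal.Width < 0.8 in the summary output later in this section) to decide the direction.

```r
library(rpart)

fit <- rpart(Species ~ ., data = iris)

# Hypothetical new record with a missing value on the primary split variable
new_flower <- data.frame(Sepal.Length = 5.0, Sepal.Width = 3.4,
                         Petal.Length = NA_real_, Petal.Width = 0.2)

pred_new <- predict(fit, newdata = new_flower, type = "class")
pred_new   # sent left by the Petal.Width surrogate, so predicted "setosa"
```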

Categorical predictors?

For nominal categorical predictors, packages handle them automatically without creating dummy variables. For ordinal categorical predictors, I would convert them to numeric by assigning appropriate numeric values to the categories.

Outliers?

For classification, outliers do not matter much. For regression, they do if the mean of each terminal node is calculated as the predicted value of the quantitative response.

Imbalanced data?

Imbalanced categories of the categorical response variable can have a detrimental impact on the tree’s structure. This can be mitigated by over-sampling the minority class or under-sampling the majority class, depending on the data.
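A minimal under-sampling sketch in base R, using the virginica-versus-rest version of iris (100 "Not" vs. 50 "virginica") as a stand-in for an imbalanced data set. It keeps a random subset of each class with the size of the smallest class.

```r
set.seed(123)
y <- factor(ifelse(iris$Species == "virginica", "virginica", "Not"))
table(y)                                    # 100 Not, 50 virginica

rows_by_class <- split(seq_along(y), y)     # row indices for each class
n_min <- min(lengths(rows_by_class))        # size of the smallest class
keep  <- unlist(lapply(rows_by_class, function(rows) sample(rows, n_min)))

balanced   <- iris[keep, ]
balanced$y <- y[keep]
table(balanced$y)                           # 50 and 50
```

Over-sampling works the other way around: draw sample(rows, n_max, replace = TRUE) from each class, where n_max is the size of the largest class.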

Ranking predictors/features?

How can we determine feature importance? Obviously, the feature that is chosen to split the root node is important. If a feature has been chosen many times to split nodes, it is also important. If a feature has been chosen as a surrogate variable to split a node, it is also of some importance. Features can be ranked according to importance by some packages for CART.

Here is a very nice tutorial about decision trees: https://towardsdatascience.com/decision-trees-d07e0f420175

We will mainly use packages caret or rpart to do CART. We first use the train() function from the caret package and then introduce the rpart() function from the rpart package. The word “rpart” stands for recursive partitioning for classification and regression.

Let’s use the data frame “ptitanic” from the “rpart.plot” package to create a classification tree. The data have missing values on age. Although rpart can handle missing predictor values with surrogate splits, here we impute them with the mice package, which can impute missing values in any column. Check the documentation (by typing ?mice in the console) for the available imputation methods.

library(rpart.plot) # The data frame "ptitanic" is from the "rpart.plot" package
## Loading required package: rpart
head(ptitanic, n=20)  # The data frame from rpart.plot has missing values, which we impute with the mice package
##    pclass survived    sex     age sibsp parch
## 1     1st survived female 29.0000     0     0
## 2     1st survived   male  0.9167     1     2
## 3     1st     died female  2.0000     1     2
## 4     1st     died   male 30.0000     1     2
## 5     1st     died female 25.0000     1     2
## 6     1st survived   male 48.0000     0     0
## 7     1st survived female 63.0000     1     0
## 8     1st     died   male 39.0000     0     0
## 9     1st survived female 53.0000     2     0
## 10    1st     died   male 71.0000     0     0
## 11    1st     died   male 47.0000     1     0
## 12    1st survived female 18.0000     1     0
## 13    1st survived female 24.0000     0     0
## 14    1st survived female 26.0000     0     0
## 15    1st survived   male 80.0000     0     0
## 16    1st     died   male      NA     0     0
## 17    1st     died   male 24.0000     0     1
## 18    1st survived female 50.0000     0     1
## 19    1st survived female 32.0000     0     0
## 20    1st     died   male 36.0000     0     0
summary(ptitanic)     # Missing values on the age variable
##  pclass        survived       sex           age              sibsp       
##  1st:323   died    :809   female:466   Min.   : 0.1667   Min.   :0.0000  
##  2nd:277   survived:500   male  :843   1st Qu.:21.0000   1st Qu.:0.0000  
##  3rd:709                               Median :28.0000   Median :0.0000  
##                                        Mean   :29.8811   Mean   :0.4989  
##                                        3rd Qu.:39.0000   3rd Qu.:1.0000  
##                                        Max.   :80.0000   Max.   :8.0000  
##                                        NA's   :263                       
##      parch      
##  Min.   :0.000  
##  1st Qu.:0.000  
##  Median :0.000  
##  Mean   :0.385  
##  3rd Qu.:0.000  
##  Max.   :9.000  
## 
library(mice)  # A package for imputing missing values
## 
## Attaching package: 'mice'
## The following object is masked from 'package:stats':
## 
##     filter
## The following objects are masked from 'package:base':
## 
##     cbind, rbind
D = complete(mice(ptitanic))  # Impute missing values and rename the data frame as D
## 
##  iter imp variable
##   1   1  age
##   1   2  age
##   1   3  age
##   1   4  age
##   1   5  age
##   2   1  age
##   2   2  age
##   2   3  age
##   2   4  age
##   2   5  age
##   3   1  age
##   3   2  age
##   3   3  age
##   3   4  age
##   3   5  age
##   4   1  age
##   4   2  age
##   4   3  age
##   4   4  age
##   4   5  age
##   5   1  age
##   5   2  age
##   5   3  age
##   5   4  age
##   5   5  age

Next, we partition the imputed data into training and validation sets:

n=nrow(D)
train.index = sample(n, n*0.8)  # We choose to use 80% data for training
train.data = D[train.index, ]
valid.data = D[-train.index, ]

Now, we are ready to train a decision tree model using different methods. We first use the train() function from the caret package.

# Load packages
library(rpart)
library(caret)

fit <- train(factor(survived) ~ .,  # The response variable "survived" is categorical and
                                    # must be converted to factor before creating a 
                                    # classification tree
             data = train.data, 
             method = "rpart", 
             control = rpart.control(maxdepth=3) # Set the maximum depth of any node of the final tree, 
                                                 # with the root node counted as depth 0.
)

We next plot the tree generated from the model.

# Plot the tree using rpart.plot() from package rpart.plot
rpart.plot(fit$finalModel)  # Each node shows the majority class, the proportion(s)
                            # of the classes, and the percentage of observations
                            # in this node relative to the ROOT node

Another way of plotting:

prp(fit$finalModel, 
    type = 1,       # Label all nodes.
    extra = 1,      # Display extra information at the nodes.
    split.font = 1, # Font for the split labels.
    nn = TRUE,      # Display the node numbers.
    varlen = -15    # Length of variable names in text at the splits 
)                   # The plot shows 1047 passengers in the root node
                    # (650 died and 397 survived)

Still another, better plot (my favorite!):

library(partykit)
## Loading required package: grid
## Loading required package: libcoin
## Loading required package: mvtnorm
plot(as.party(fit$finalModel))

Without using the train() function from package caret, you can use the rpart() function directly from package “rpart”. The syntax is

rpart(y ~ ., data = )

where the response variable can be categorical or quantitative.

For additional parameters of the function rpart(), check the documentation.

fit <- rpart(survived ~ ., 
             data = ptitanic,  # Full data (including the validation rows), so the validation
                               # results below are optimistic; use data = train.data for an honest check
             control = rpart.control(maxdepth = 3,  # Maximum depth of any node of the final tree, 
                                                    # with the root node counted as depth 0.
                                     minsplit = 20, # Minimum number of observations that must exist
                                                    # in a node for a split to be attempted.
                                     minbucket = 6, # Minimum number of observations in any terminal node.
                                     cp = 0.01      # Complexity parameter: any split that does not decrease
                                                    # the overall lack of fit by a factor of cp is not attempted.
                                                    # It's used for handling the overfitting issue.
                                    ) 
             )

prp(fit,            # prp = Plot Recursive Partition?
    type = 1,       # Label all nodes.
    extra = 1,      # Display extra information at the nodes.
    split.font = 1, # Font for the split labels.
    nn = TRUE,      # Display the node numbers.
    varlen = -15    # Length of variable names in text at the splits 
)

The top node (called the root) has depth 0. The depth-1 left node is the one with 682 deaths as the majority and the depth-2 second node is the one with 25 survivors as the majority.

Once a decision tree is constructed, different decision rules can be formed. For example,

\[\text{IF} ~(sex = Male)~ \text{AND} ~(age<9.5) ~\text{AND} ~(sibsp<3) ~ \text{THEN} ~Class = "survived"\]

Such a rule tells who would survive.

Let’s use the model to predict the cases in the validation data and create a confusion matrix to show the performance of the model:

# prediction using validation data
pred = predict(fit, 
               newdata = valid.data, 
               type = "class"        # "class" for labels and "prob" for probability
)

pred = as.factor(as.character(pred))
actual = valid.data$survived

# Accuracy measures
confusionMatrix(pred, actual)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction died survived
##   died      141       24
##   survived   25       72
##                                           
##                Accuracy : 0.813           
##                  95% CI : (0.7604, 0.8583)
##     No Information Rate : 0.6336          
##     P-Value [Acc > NIR] : 1.838e-10       
##                                           
##                   Kappa : 0.5981          
##                                           
##  Mcnemar's Test P-Value : 1               
##                                           
##             Sensitivity : 0.8494          
##             Specificity : 0.7500          
##          Pos Pred Value : 0.8545          
##          Neg Pred Value : 0.7423          
##              Prevalence : 0.6336          
##          Detection Rate : 0.5382          
##    Detection Prevalence : 0.6298          
##       Balanced Accuracy : 0.7997          
##                                           
##        'Positive' Class : died            
## 

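Both the accuracy and sensitivity measures look good. Note that the positive class here is “died”, so sensitivity is the proportion of deaths that are correctly predicted.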

If the response variable is quantitative, a regression tree can be constructed with the same function rpart(), which uses method = “anova” by default for a numeric response. The impurity measure is the variance (with denominator \(n\) instead of \(n-1\)). We can use RMSE to evaluate the performance of a regression tree.

The cu.summary data frame from the rpart package (car data) can serve as an example for regression trees; here are its first rows:

##                 Price   Country Reliability Mileage  Type
## Acura Integra 4 11950     Japan Much better      NA Small
## Dodge Colt 4     6851     Japan        <NA>      NA Small
## Dodge Omni 4     6995       USA  Much worse      NA Small
## Eagle Summit 4   8895       USA      better      33 Small
## Ford Escort   4  7402       USA       worse      33 Small
## Ford Festiva 4   6319     Korea      better      37 Small
## GEO Metro  3     6695     Japan        <NA>      NA Small
## GEO Prizm  4    10125 Japan/USA Much better      NA Small
## Honda Civic 4    6635 Japan/USA Much better      32 Small
## Hyundai Excel 4  5899     Korea       worse      NA Small
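A sketch of a regression tree on these data, assuming Price as the response and the other four columns as predictors (a choice made for illustration). It holds out 30% of the cars and reports the validation RMSE; missing predictor values are handled by surrogate splits.

```r
library(rpart)   # cu.summary ships with rpart

set.seed(123)
n <- nrow(cu.summary)
train_idx <- sample(n, round(0.7 * n))
train <- cu.summary[train_idx, ]
valid <- cu.summary[-train_idx, ]

# A numeric response makes rpart() default to method = "anova" (regression tree)
fit <- rpart(Price ~ Country + Reliability + Mileage + Type, data = train)

pred <- predict(fit, newdata = valid)                       # predicted prices
rmse <- sqrt(mean((valid$Price - pred)^2, na.rm = TRUE))   # validation RMSE
fit$method
rmse
```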

To classify or predict a new record based on a tree that was grown, it is dropped down the tree. When it has dropped all the way down to a terminal node, we can assign its class (or predict the value of the quantitative outcome) simply by taking a vote (or the average) of all the training records that belong to the terminal node.
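For a regression tree this means the prediction for a record is the mean response of the training records in its terminal node. The sketch below checks that with rpart on mtcars (a data set chosen for illustration), using the fitted object's where component, which gives the terminal node each training record falls into.

```r
library(rpart)

# Regression tree for mpg; minsplit lowered because mtcars has only 32 rows
fit <- rpart(mpg ~ wt + hp, data = mtcars,
             control = rpart.control(minsplit = 10))

leaf       <- fit$where                          # terminal node (row of fit$frame) per car
leaf_means <- tapply(mtcars$mpg, leaf, mean)     # average mpg of training cars in each leaf

pred <- predict(fit, newdata = mtcars)           # drop every car down the tree
all.equal(as.numeric(pred), as.numeric(leaf_means[as.character(leaf)]))   # TRUE
```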

A tree consists of a root node, some decision nodes, and some terminal nodes. The root node is the starting node. A decision node is a node that is split further. Terminal nodes are nodes that are not split further; they are also called leaves.

More Details on Classification Trees

The algorithm used here to create a tree is called CART (Classification And Regression Trees). There are other algorithms, such as ID3 and C4.5. All three are decision tree-based classifiers. Fundamental questions for tree construction are:

  • Which variable should be used to split a node in order to produce the purest child nodes?

  • How can we define the impurity of a node?

  • How will missing values be handled?

  • How can we prune a tree that is too large (overfitted)?

Several criteria can be used to choose the variable for splitting each node. CART uses the Gini index or entropy for a classification problem. The Gini index of a node (say \(A\)) based on the response variable is defined as

\[I(A) = 1-\sum _{k=1}^{m} p_k^2\]

where \(p_1, p_2, \cdots, p_m\) are the proportions of the \(m\) categories in the group represented by the node. The maximum Gini index is \(1-\frac{1}{m}\) (reached when all categories are equally frequent, meaning that the node is chaotic) and the minimum is 0 (meaning that the node is pure in terms of the variable considered).

The entropy of a node based on the same variable measures the amount of uncertainty and is defined as

\[entropy(A)=-\sum_{k=1}^m p_k \cdot log_2 (p_k)\] with the unit being bits. Both Gini and entropy are measures of impurity in a node. The range of entropy is from 0 to \(log_2 m\). The higher the measure, the more impurity (or uncertainty).

A quick quiz: If the distribution of a node is each of the following,

  1. \(p_1=0, p_2=1.0\)

  2. \(p_1=0.5, p_2=0.5\)

  3. \(p_1=0.1, p_2=0.9\)

  4. \(p_1=0, p_2=1.0, p_3=0, p_4=0\)

  5. \(p_1=0.25, p_2=0.25, p_3=0.25, p_4 = 0.25\)

  6. \(p_1=0.1, p_2=0.2, p_3=0.3, p_4 = 0.4\)

Calculate the Gini index and entropy for the node.

Solution.

  1. \(I = 1-(0^2+1^2)=0\), \(entropy = -(0\cdot log_2 (0)+1\cdot log_2 (1))=0\), using the convention \(0\cdot log_2(0)=0\)

  6. \(I = 1-(0.1^2+0.2^2+0.3^2+0.4^2)=0.7\), \(entropy = -(0.1\cdot log_2 (0.1)+0.2\cdot log_2 (0.2) + 0.3\cdot log_2 (0.3)+0.4\cdot log_2 (0.4))=?\)
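To check the quiz answers, the sketch below computes both measures for all six distributions. Here gini() and entropy() take a vector of proportions (an illustrative choice), and entropy() drops zero proportions, which implements the convention \(0\cdot log_2(0)=0\).

```r
gini <- function(p) 1 - sum(p^2)
entropy <- function(p) {
  p <- p[p > 0]                 # convention: 0 * log2(0) = 0
  -sum(p * log2(p))
}

dists <- list(c(0, 1), c(0.5, 0.5), c(0.1, 0.9),
              c(0, 1, 0, 0), rep(0.25, 4), c(0.1, 0.2, 0.3, 0.4))

quiz <- data.frame(question = 1:6,
                   gini     = sapply(dists, gini),
                   entropy  = sapply(dists, entropy))
round(quiz, 3)
```

Questions 2 and 5 are uniform distributions, so they reach the maxima \(1-\frac{1}{m}\) and \(log_2 m\) (0.5 and 1 for m = 2; 0.75 and 2 for m = 4).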

If a node is split into a few nodes (called children), it is called a parent node. The total impurity in all children of a parent node is defined as the weighted average of the impurity of all children, with weights being the percents of records from the parent node. The information gain by splitting the parent node is defined as the difference between the impurity of the parent node and the total impurity in its children.

Consider the iris data, whose root node contains 50 flowers of each species, and split the root node with “Petal.Length < 2.5”. The first child has 3 proportions of \(50/50, 0/50, \text{and} ~0/50\), or 1, 0, and 0, so the impurity of this child node using the Gini index is \(1-(1^2+0^2+0^2)=0\). The second child has 3 proportions of \(0/100, 50/100, \text{and}~50/100\), or 0, 0.5, and 0.5, so the impurity of this child node using the Gini index is \(1-(0^2+0.5^2+0.5^2)=0.5\). The total impurity in the two children is the weighted average \(\frac{50}{150}\cdot 0 + \frac{100}{150}\cdot 0.5 = \frac{1}{3}\). The root node has an impurity of \(1-((\frac{1}{3})^2+(\frac{1}{3})^2+(\frac{1}{3})^2)=\frac{2}{3}\).

Now, using the Gini index, the reduction in impurity by splitting the root node with “Petal.Length < 2.5” into two child nodes is

\(I(parent)-I(children)=\frac{2}{3} - \frac{1}{3}=\frac{1}{3}\).

The calculation procedure with entropy is similar. The reduction in impurity in this case is also called information gain. Here is the detail:

\[entropy(\text{root})=-[\frac{50}{150}log_2(\frac{50}{150})+\frac{50}{150}log_2(\frac{50}{150})+\frac{50}{150}log_2(\frac{50}{150})]=1.585\]

\[entropy(\text{depth-1 left node})=0, ~~\text{since the node is pure.}\]

\[entropy(\text{depth-1 right node})= -[\frac{0}{100}log_2(\frac{0}{100})+\frac{50}{100}log_2(\frac{50}{100})+\frac{50}{100}log_2(\frac{50}{100})]= 1\]

(taking \(0\cdot log_2(0)=0\)). The total entropy in the two children is \[\frac{50}{150}\cdot 0 + \frac{100}{150}\cdot 1 = 0.667\] The reduction in impurity (i.e., information gain) is \(1.585-0.667=0.918\).

Information gain is biased towards choosing the predictor with a large number of values, which may result in overfitting. A modification of information gain is the gain ratio. The following explains how a gain ratio is calculated.

The amount of uncertainty across all child nodes of a node (called split information or split entropy) is defined using the same formula as the impurity of a node, applied to the proportions of records sent to each child. The split information for the two children of the root node is \(-[\frac{50}{150}log_2(\frac{50}{150})+\frac{100}{150}log_2(\frac{100}{150})]=0.918\). The ratio between the reduction in impurity and the split information, \(0.918/0.918=1\), is called the gain ratio. (Here the information gain equals the split information because the split is determined by the species: setosa goes left and the other two species go right.) A tutorial is https://www.slideshare.net/marinasantini1/lecture-4-decision-trees-2-entropy-information-gain-gain-ratio-55241087. The gain ratio can be viewed as a normalization of information gain. It is implemented in the C4.5 algorithm (the successor to the ID3 algorithm).
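The entropy calculations above can be reproduced in a few lines. This sketch computes the information gain, split information, and gain ratio for the iris root split “Petal.Length < 2.5”; entropy() and gain_ratio() are hypothetical helpers, and the split information is simply the entropy of the left/right assignment itself.

```r
# Entropy (in bits) of a node from its labels; zero-count classes are dropped
entropy <- function(y) {
  p <- table(y) / length(y)
  p <- p[p > 0]
  -sum(p * log2(p))
}

gain_ratio <- function(y, goes_left) {
  w <- mean(goes_left)                               # share of records going left
  child_entropy <- w * entropy(y[goes_left]) + (1 - w) * entropy(y[!goes_left])
  info_gain  <- entropy(y) - child_entropy           # reduction in entropy
  split_info <- entropy(goes_left)                   # entropy of the partition itself
  c(info_gain = info_gain, split_info = split_info,
    gain_ratio = info_gain / split_info)
}

r <- gain_ratio(iris$Species, iris$Petal.Length < 2.5)
round(r, 3)   # info_gain 0.918, split_info 0.918, gain_ratio 1
```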

When splitting a node, the reduction in impurity or gain ratio is calculated for each of the predictors. The one with the largest reduction or gain ratio will be used to split the node with an appropriate cutoff (for a quantitative predictor) or category (for a categorical predictor).

Avoiding Overfitting

A full-grown tree often overfits data. Overfitting will lead to poor performance on new data.

Criteria for stopping the tree growth before it starts overfitting the data include

  • Number of splits,

  • Minimum number of records in a terminal node, and

  • Minimum reduction in impurity

We can control the size of the tree (i.e., the number of terminal nodes, which equals the number of splits plus 1) with the complexity parameter (CP, page 223), the maximum number of levels (maxdepth), the minimum number of observations (minsplit) that must exist in a node in order for a split to be attempted, and the minimum number of records (minbucket) in a terminal node.

fit = rpart(Species~., data = iris)

printcp(fit) # Print the CP table
## 
## Classification tree:
## rpart(formula = Species ~ ., data = iris)
## 
## Variables actually used in tree construction:
## [1] Petal.Length Petal.Width 
## 
## Root node error: 100/150 = 0.66667
## 
## n= 150 
## 
##     CP nsplit rel error xerror     xstd
## 1 0.50      0      1.00   1.14 0.052307
## 2 0.44      1      0.50   0.65 0.060690
## 3 0.01      2      0.06   0.10 0.030551
plotcp(fit) # The dashed reference line corresponds to the xerror plus one std (explained later)

summary(fit) # Pay attention to variable importance (in terms of percentages)
## Call:
## rpart(formula = Species ~ ., data = iris)
##   n= 150 
## 
##     CP nsplit rel error xerror       xstd
## 1 0.50      0      1.00   1.14 0.05230679
## 2 0.44      1      0.50   0.65 0.06069047
## 3 0.01      2      0.06   0.10 0.03055050
## 
## Variable importance
##  Petal.Width Petal.Length Sepal.Length  Sepal.Width 
##           34           31           21           14 
## 
## Node number 1: 150 observations,    complexity param=0.5
##   predicted class=setosa      expected loss=0.6666667  P(node) =1
##     class counts:    50    50    50
##    probabilities: 0.333 0.333 0.333 
##   left son=2 (50 obs) right son=3 (100 obs)
##   Primary splits:
##       Petal.Length < 2.45 to the left,  improve=50.00000, (0 missing)
##       Petal.Width  < 0.8  to the left,  improve=50.00000, (0 missing)
##       Sepal.Length < 5.45 to the left,  improve=34.16405, (0 missing)
##       Sepal.Width  < 3.35 to the right, improve=19.03851, (0 missing)
##   Surrogate splits:
##       Petal.Width  < 0.8  to the left,  agree=1.000, adj=1.00, (0 split)
##       Sepal.Length < 5.45 to the left,  agree=0.920, adj=0.76, (0 split)
##       Sepal.Width  < 3.35 to the right, agree=0.833, adj=0.50, (0 split)
## 
## Node number 2: 50 observations
##   predicted class=setosa      expected loss=0  P(node) =0.3333333
##     class counts:    50     0     0
##    probabilities: 1.000 0.000 0.000 
## 
## Node number 3: 100 observations,    complexity param=0.44
##   predicted class=versicolor  expected loss=0.5  P(node) =0.6666667
##     class counts:     0    50    50
##    probabilities: 0.000 0.500 0.500 
##   left son=6 (54 obs) right son=7 (46 obs)
##   Primary splits:
##       Petal.Width  < 1.75 to the left,  improve=38.969400, (0 missing)
##       Petal.Length < 4.75 to the left,  improve=37.353540, (0 missing)
##       Sepal.Length < 6.15 to the left,  improve=10.686870, (0 missing)
##       Sepal.Width  < 2.45 to the left,  improve= 3.555556, (0 missing)
##   Surrogate splits:
##       Petal.Length < 4.75 to the left,  agree=0.91, adj=0.804, (0 split)
##       Sepal.Length < 6.15 to the left,  agree=0.73, adj=0.413, (0 split)
##       Sepal.Width  < 2.95 to the left,  agree=0.67, adj=0.283, (0 split)
## 
## Node number 6: 54 observations
##   predicted class=versicolor  expected loss=0.09259259  P(node) =0.36
##     class counts:     0    49     5
##    probabilities: 0.000 0.907 0.093 
## 
## Node number 7: 46 observations
##   predicted class=virginica   expected loss=0.02173913  P(node) =0.3066667
##     class counts:     0     1    45
##    probabilities: 0.000 0.022 0.978
fit$variable.importance # A named numeric vector giving the importance of each variable.
##  Petal.Width Petal.Length Sepal.Length  Sepal.Width 
##     88.96940     81.34496     54.09606     36.01309
# To get the best CP value with the lowest xerror value, do
best.cp = fit$cptable[which.min(fit$cptable[, "xerror"]), "CP"]
best.cp
## [1] 0.01

Based on the CP table, the tree with the lowest cross-validation error (xerror = 0.10) can be constructed with the prune() function. This is shown as follows:

pruned.tree = prune(fit,          # Provide the fitted full-grown tree
                    cp = best.cp  # Update tree with the best cp value. cp = complexity parameter
                   )
pruned.tree
## n= 150 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
## 1) root 150 100 setosa (0.33333333 0.33333333 0.33333333)  
##   2) Petal.Length< 2.45 50   0 setosa (1.00000000 0.00000000 0.00000000) *
##   3) Petal.Length>=2.45 100  50 versicolor (0.00000000 0.50000000 0.50000000)  
##     6) Petal.Width< 1.75 54   5 versicolor (0.00000000 0.90740741 0.09259259) *
##     7) Petal.Width>=1.75 46   1 virginica (0.00000000 0.02173913 0.97826087) *
library(rpart.plot) # The package is loaded in order to use the prp() function for better plotting the pruned tree
prp(pruned.tree, 
    type =1,
    extra = 1,
    split.font = 1, 
    varlen = -10,
    box.col = ifelse(pruned.tree$frame$var == "<leaf>", 'gray', 'white') # Color leaves gray.
)

A further enhancement, in the interest of model parsimony, is to incorporate the sampling error, which might cause the above minimum xerror to vary if we had a different sample. The enhancement uses the estimated standard error (xstd) to prune the tree even further. To do this, we add one standard error (0.030551) to the minimum xerror (0.10). The result, 0.130551, gives an upper limit (R plots this limit as a dashed line in the plot of xerror vs. cp). Now, in the CP table, find a smaller tree that has an xerror no greater than 0.130551. Here, no smaller tree qualifies (their xerror values are 0.65 and 1.14), so we go with the best-pruned tree. If such a tree were found, we would prune the full tree using prune() with its CP value.
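The one-standard-error rule can be automated from the CP table. This is a sketch, not a built-in rpart function; the variable names are illustrative. Since xerror and xstd come from random 10-fold cross-validation, the values may differ slightly from the table printed above.

```r
library(rpart)

set.seed(123)                                 # xerror/xstd depend on random CV folds
fit <- rpart(Species ~ ., data = iris)
cpt <- fit$cptable

i_min <- which.min(cpt[, "xerror"])                  # row with the smallest xerror
limit <- cpt[i_min, "xerror"] + cpt[i_min, "xstd"]   # minimum xerror + 1 SE

# Rows run from the smallest tree (fewest splits) to the largest,
# so the first row within the limit is the smallest qualifying tree
i_1se  <- which(cpt[, "xerror"] <= limit)[1]
cp_1se <- cpt[i_1se, "CP"]

pruned_1se <- prune(fit, cp = cp_1se)
cpt[c(i_min, i_1se), ]                               # compare the two choices
```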

A remark: if you find the enhancement step complicated, just report the best-pruned tree, since the gain in parsimony is usually small.

Pruning Regression Trees

A best-pruned tree is obtained in the same way as in the classification case.

Toyota = read.csv(file = "/Users/Home/Documents/Zhang/Stat415.515.615/DMBA-R-datasets/ToyotaCorolla.csv")
head(Toyota)
##   Id                                         Model Price Age_08_04 Mfg_Month
## 1  1 TOYOTA Corolla 2.0 D4D HATCHB TERRA 2/3-Doors 13500        23        10
## 2  2 TOYOTA Corolla 2.0 D4D HATCHB TERRA 2/3-Doors 13750        23        10
## 3  3 TOYOTA Corolla 2.0 D4D HATCHB TERRA 2/3-Doors 13950        24         9
## 4  4 TOYOTA Corolla 2.0 D4D HATCHB TERRA 2/3-Doors 14950        26         7
## 5  5   TOYOTA Corolla 2.0 D4D HATCHB SOL 2/3-Doors 13750        30         3
## 6  6   TOYOTA Corolla 2.0 D4D HATCHB SOL 2/3-Doors 12950        32         1
##   Mfg_Year    KM Fuel_Type HP Met_Color  Color Automatic   CC Doors Cylinders
## 1     2002 46986    Diesel 90         1   Blue         0 2000     3         4
## 2     2002 72937    Diesel 90         1 Silver         0 2000     3         4
## 3     2002 41711    Diesel 90         1   Blue         0 2000     3         4
## 4     2002 48000    Diesel 90         0  Black         0 2000     3         4
## 5     2002 38500    Diesel 90         0  Black         0 2000     3         4
## 6     2002 61000    Diesel 90         0  White         0 2000     3         4
##   Gears Quarterly_Tax Weight Mfr_Guarantee BOVAG_Guarantee Guarantee_Period ABS
## 1     5           210   1165             0               1                3   1
## 2     5           210   1165             0               1                3   1
## 3     5           210   1165             1               1                3   1
## 4     5           210   1165             1               1                3   1
## 5     5           210   1170             1               1                3   1
## 6     5           210   1170             0               1                3   1
##   Airbag_1 Airbag_2 Airco Automatic_airco Boardcomputer CD_Player Central_Lock
## 1        1        1     0               0             1         0            1
## 2        1        1     1               0             1         1            1
## 3        1        1     0               0             1         0            0
## 4        1        1     0               0             1         0            0
## 5        1        1     1               0             1         0            1
## 6        1        1     1               0             1         0            1
##   Powered_Windows Power_Steering Radio Mistlamps Sport_Model Backseat_Divider
## 1               1              1     0         0           0                1
## 2               0              1     0         0           0                1
## 3               0              1     0         0           0                1
## 4               0              1     0         0           0                1
## 5               1              1     0         1           0                1
## 6               1              1     0         1           0                1
##   Metallic_Rim Radio_cassette Parking_Assistant Tow_Bar
## 1            0              0                 0       0
## 2            0              0                 0       0
## 3            0              0                 0       0
## 4            0              0                 0       0
## 5            0              0                 0       0
## 6            0              0                 0       0
reg.fit = rpart(Price~., 
                data = Toyota[, c(3, 4, 7:10, 12:14, 17, 18)] # Price plus the predictors Age_08_04, KM, Fuel_Type, HP,
                                                              # Met_Color, Automatic, CC, Doors, Quarterly_Tax, Weight
)

printcp(reg.fit) # Print the CP table
## 
## Regression tree:
## rpart(formula = Price ~ ., data = Toyota[, c(3, 4, 7:10, 12:14, 
##     17, 18)])
## 
## Variables actually used in tree construction:
## [1] Age_08_04 Weight   
## 
## Root node error: 1.8877e+10/1436 = 13145711
## 
## n= 1436 
## 
##         CP nsplit rel error  xerror     xstd
## 1 0.657673      0   1.00000 1.00169 0.063226
## 2 0.112705      1   0.34233 0.35169 0.021826
## 3 0.032403      2   0.22962 0.24107 0.020150
## 4 0.021944      3   0.19722 0.22725 0.015923
## 5 0.015921      4   0.17528 0.19662 0.013169
## 6 0.013126      5   0.15935 0.19304 0.013501
## 7 0.010000      6   0.14623 0.18557 0.013076
plotcp(reg.fit)

# To get the best CP value with the lowest xerror value, do
best.cp = reg.fit$cptable[which.min(reg.fit$cptable[, "xerror"]), "CP"]
best.cp
## [1] 0.01
pruned.tree = prune(reg.fit, cp = best.cp)
pruned.tree
## n= 1436 
## 
## node), split, n, deviance, yval
##       * denotes terminal node
## 
##  1) root 1436 18877240000 10730.820  
##    2) Age_08_04>=32.5 1250  4682009000  9596.602  
##      4) Age_08_04>=56.5 833  1353802000  8673.539  
##        8) Age_08_04>=68.5 392   413177400  7925.569 *
##        9) Age_08_04< 68.5 441   526376400  9338.401 *
##      5) Age_08_04< 56.5 417  1200649000 11440.510  
##       10) Age_08_04>=44.5 225   520442800 10728.440 *
##       11) Age_08_04< 44.5 192   432425900 12274.970 *
##    3) Age_08_04< 32.5 186  1780187000 18353.290  
##      6) Weight< 1277.5 179  1077456000 17994.680  
##       12) Weight< 1122.5 100   281310300 16842.980 *
##       13) Weight>=1122.5 79   495606400 19452.520 *
##      7) Weight>=1277.5 7    91051840 27523.570 *
prp(pruned.tree, 
    type =1,
    extra = 1,
    split.font = 1, 
    varlen = -15,
    box.col=ifelse(pruned.tree$frame$var == "<leaf>", 'lightblue', 'yellow') # Color leaves lightblue.
)
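For a regression tree, the numbers in the CP table can be reproduced by hand. The root node error is the root deviance (the total sum of squares of Price) divided by n. The rel error is the total deviance of the terminal nodes divided by the root deviance. The sketch copies the deviances printed above for the pruned tree, so it is only as precise as that printout. It recovers the root node error and the rel error of 0.14623 for the six-split tree.

```r
# Deviances copied from the printed pruned tree above
root.dev <- 18877240000
n <- 1436
leaf.dev <- c(413177400, 526376400, 520442800, 432425900,  # nodes 8, 9, 10, 11
              281310300, 495606400, 91051840)              # nodes 12, 13, 7

root.dev / n              # root node error, about 13145710 (printcp: 13145711)
sum(leaf.dev) / root.dev  # rel error of the 6-split tree, about 0.14623
```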

Conditional Inference Trees: The partykit Package

Read textbook page 222.

library(partykit)

# For classification
ct1 <- ctree(Species ~ ., data = iris)

plot(ct1) # A nice plot

# For regression
ct2 <- ctree(Sepal.Length ~ ., data = iris)

plot(ct2) 

Classification Trees with More than Two Child Nodes

The RWeka package, whose J48() function implements the C4.5 algorithm, can create more than two child nodes when it splits on a categorical predictor with more than two categories. For ease of demonstration, we use three predictors from the Titanic data (the ptitanic data frame in the rpart.plot package).

library(RWeka)
fit <- J48(survived~pclass+sex+parch, data=ptitanic) # Not necessarily a binary tree

# summarize the fit
summary(fit)
## 
## === Summary ===
## 
## Correctly Classified Instances        1039               79.3736 %
## Incorrectly Classified Instances       270               20.6264 %
## Kappa statistic                          0.5494
## Mean absolute error                      0.302 
## Root mean squared error                  0.3886
## Relative absolute error                 63.9622 %
## Root relative squared error             79.98   %
## Total Number of Instances             1309     
## 
## === Confusion Matrix ===
## 
##    a   b   <-- classified as
##  714  95 |   a = died
##  175 325 |   b = survived
# Plot the decision tree
plot(fit)
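You can check the accuracy and the kappa statistic in the J48 summary against its confusion matrix. Cohen's kappa compares the observed agreement p_o with the agreement p_e expected by chance from the row and column totals: kappa = (p_o - p_e) / (1 - p_e). Here is a base-R sketch that uses the counts printed above:

```r
# Confusion matrix from the J48 summary (rows = actual, columns = predicted)
cm <- matrix(c(714,  95,
               175, 325),
             nrow = 2, byrow = TRUE,
             dimnames = list(actual = c("died", "survived"),
                             predicted = c("died", "survived")))

n <- sum(cm)
p.o <- sum(diag(cm)) / n                     # observed agreement (accuracy)
p.e <- sum(rowSums(cm) * colSums(cm)) / n^2  # agreement expected by chance
kappa.stat <- (p.o - p.e) / (1 - p.e)

round(c(accuracy = p.o, kappa = kappa.stat), 4)  # 0.7937 and 0.5494
```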

Advantages and Disadvantages of Trees

Advantages:

  • A decision tree model is very intuitive and easy to interpret.

  • Normalization of data is not required.

  • Missing values can be handled automatically without imputation (rpart, for example, uses surrogate splits).

  • Tree models are robust to outliers in the predictors, since the choice of a split depends on the ordering of the values and not on their absolute magnitudes.

  • A decision tree model is useful for variable selection, since the most important predictors usually appear near the top of the tree.

  • The tree method is nonlinear (no linear relationship between the outcome variable and the predictors is assumed) and nonparametric (no parametric form, such as a set of regression coefficients, is estimated).

  • Multi-class classification can be performed.

Disadvantages:

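  • Splits are made on one predictor at a time, so a tree can miss relationships that involve combinations of predictors (such as a linear boundary that is not parallel to an axis), or can approximate them only with many splits.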

  • A small change in the data can cause a large change in the structure of the decision tree causing instability.

  • Training a tree model on a large data set is often time-consuming.

  • The root split is so influential that, if it changes, the whole structure below it can change, so the tree might not be stable with new data.

Bagging

Improving Predictive Performance: Random Forests and Boosted Trees

Results from multiple trees can be combined to improve predictive power. This is the ensemble approach. An ensemble combines multiple supervised models into a “super-model.” To create an ensemble, we can use bagging (short for “bootstrap aggregating”) or boosting.

Bagging comprises two steps:

  1. Generate multiple random samples by sampling with replacement from the original data. This is called bootstrap sampling.

  2. Run the algorithm on each sample and produce scores.

Bagging improves the performance stability of a model and helps avoid overfitting by separately modeling different data samples and then combining the results. It is a useful ensemble method for trees and neural networks (to be covered later).
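To make the two steps concrete, here is a minimal bagging sketch written with rpart only. It is not how randomForest or adabag implement bagging internally. It uses the built-in iris data, grows one classification tree per bootstrap sample, and combines the trees by majority vote. The number of trees (25) and the 100/50 split are arbitrary choices made for illustration.

```r
library(rpart)

set.seed(123)
idx <- sample(nrow(iris), 100)   # 100 training records, 50 for validation
train <- iris[idx, ]
valid <- iris[-idx, ]

n.trees <- 25
# Steps 1 and 2: one bootstrap sample and one tree per iteration;
# each column of `votes` holds one tree's predicted classes for the validation records
votes <- sapply(seq_len(n.trees), function(b) {
  boot.idx <- sample(nrow(train), replace = TRUE)   # sample with replacement
  tree <- rpart(Species ~ ., data = train[boot.idx, ], method = "class")
  as.character(predict(tree, newdata = valid, type = "class"))
})

# Combine the trees by majority vote for each validation record
bagged.pred <- apply(votes, 1, function(v) names(which.max(table(v))))
mean(bagged.pred == valid$Species)   # validation accuracy of the bagged trees
```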

Random forests are a special case of bagging. The steps in random forests are:

  • Draw bootstrap samples (resampling with replacement) from the original data.

  • Using a random subset of \(m\) predictors at each split, fit a classification or regression tree to each bootstrap sample (and thus obtain a “forest”).

  • Combine the results from the individual trees to obtain an improved result. Use voting for classification and averaging for prediction.
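In randomForest(), the number \(m\) of predictors sampled at each split is the argument mtry. As far as we know, its default is floor(sqrt(p)) for classification and max(floor(p/3), 1) for regression, where p is the number of predictors. Check ?randomForest for your installed version. The sketch below applies that rule to the 11 predictors of the UniversalBank example that follows, and shows one random draw of candidate predictors. The helper name `default_mtry` is our own.

```r
# Hypothetical helper reproducing the documented default mtry rule
default_mtry <- function(p, classification = TRUE) {
  if (classification) floor(sqrt(p)) else max(floor(p / 3), 1)
}

predictors <- c("Age", "Experience", "Income", "Family", "CCAvg", "Education",
                "Mortgage", "Securities.Account", "CD.Account", "Online",
                "CreditCard")
m <- default_mtry(length(predictors))   # floor(sqrt(11)) = 3

set.seed(42)
sample(predictors, m)   # candidate predictors for one split of one tree
```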

bank.df = read.csv(file = "/Users/Home/Documents/Zhang/Stat415.515.615/DMBA-R-datasets/UniversalBank.csv")
head(bank.df)
##   ID Age Experience Income ZIP.Code Family CCAvg Education Mortgage
## 1  1  25          1     49    91107      4   1.6         1        0
## 2  2  45         19     34    90089      3   1.5         1        0
## 3  3  39         15     11    94720      1   1.0         1        0
## 4  4  35          9    100    94112      1   2.7         2        0
## 5  5  35          8     45    91330      4   1.0         2        0
## 6  6  37         13     29    92121      4   0.4         2      155
##   Personal.Loan Securities.Account CD.Account Online CreditCard
## 1             0                  1          0      0          0
## 2             0                  1          0      0          0
## 3             0                  0          0      0          0
## 4             0                  0          0      0          0
## 5             0                  0          0      0          1
## 6             0                  0          0      1          0
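n = nrow(bank.df) # Number of rows in bank.df
shuffled.idx = sample(1:n) # Permute the rows (call set.seed() first for reproducible results)
train.idx = shuffled.idx[1:round(0.6*n)] # 60% of the rows for training
train.df = bank.df[train.idx, -c(1, 5)] # Drop ID (column 1) and ZIP.Code (column 5)
valid.df = bank.df[-train.idx, -c(1, 5)] # The remaining 40% for validation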

library(randomForest)
## randomForest 4.7-1.1
## Type rfNews() to see new features/changes/bug fixes.
## 
## Attaching package: 'randomForest'
## The following object is masked from 'package:dplyr':
## 
##     combine
## The following object is masked from 'package:ggplot2':
## 
##     margin
rf = randomForest(factor(Personal.Loan) ~ .,  # The categorical response must be a factor!
                  data = train.df,
                  ntree = 20,        # Number of trees to grow
                  mtry = 4,          # Number of variables randomly sampled as candidates at each split
                  nodesize = 25,     # Minimum size of terminal nodes
                  maxnodes = NULL,   # Maximum number of terminal nodes each tree in the forest can have;
                                     # if not given, trees are grown to the maximum possible size
                                     # (subject to the limit set by nodesize)
                  importance = TRUE, # Should the importance of predictors be assessed?
                  proximity = TRUE   # Should proximity among the rows be calculated? (default is FALSE)
                                     # The proximities can be used to create an MDS plot for
                                     # classification with 3+ categories using the function MDSplot()
                  )

rf  # Note the OOB (out-of-bag) error rate, an estimate of 1 minus the overall accuracy
## 
## Call:
##  randomForest(formula = factor(Personal.Loan) ~ ., data = train.df,      ntree = 20, mtry = 4, nodesize = 25, maxnodes = NULL, importance = TRUE,      proximity = TRUE) 
##                Type of random forest: classification
##                      Number of trees: 20
## No. of variables tried at each split: 4
## 
##         OOB estimate of  error rate: 1.7%
## Confusion matrix:
##      0   1 class.error
## 0 2688  18 0.006651885
## 1   33 261 0.112244898
# Out of bag = records that are not in a tree's bootstrap sample
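The OOB error rate and the class errors printed above follow directly from the OOB confusion matrix. The overall rate is the off-diagonal count divided by the number of training records. Each class error is the off-diagonal count in that row divided by the row total. Using the printed counts:

```r
# OOB confusion matrix printed by randomForest (rows = actual class)
oob.cm <- matrix(c(2688,  18,
                     33, 261),
                 nrow = 2, byrow = TRUE,
                 dimnames = list(actual = c("0", "1"), predicted = c("0", "1")))

oob.error <- 1 - sum(diag(oob.cm)) / sum(oob.cm)   # 51 / 3000 = 0.017
class.error <- 1 - diag(oob.cm) / rowSums(oob.cm)  # about 0.00665 and 0.11224
oob.error
class.error
```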

plot(rf)  # Plots of OOB errors (black) along with class errors

          # The x-axis label "trees" means the number of trees

# Variable importance plot
varImpPlot(rf, 
           sort=TRUE, # Variables are sorted in decreasing order of importance
           n.var = 11, # How many variables to show?
           main = "Variable Importance Plot")

# How many times each predictor is used for splitting, summed over all trees
varUsed(rf) 
##  [1]  59  57 135  59  99  63  43   2  32  19  10
# Extract a single tree from a forest
getTree(rf, 
        k = 5,   # Extract the fifth tree
        labelVar = TRUE # Use better labels for splitting variables & predicted class?
        )
##    left daughter right daughter  split var split point status prediction
## 1              2              3     Income      101.50      1       <NA>
## 2              4              5      CCAvg        3.05      1       <NA>
## 3              6              7  Education        1.50      1       <NA>
## 4              8              9      CCAvg        2.95      1       <NA>
## 5             10             11 CD.Account        0.50      1       <NA>
## 6             12             13   Mortgage      197.00      1       <NA>
## 7             14             15        Age       39.50      1       <NA>
## 8              0              0       <NA>        0.00     -1          0
## 9              0              0       <NA>        0.00     -1          0
## 10            16             17      CCAvg        3.55      1       <NA>
## 11             0              0       <NA>        0.00     -1          1
## 12            18             19     Income      106.50      1       <NA>
## 13            20             21 Experience       25.50      1       <NA>
## 14            22             23 Experience       13.50      1       <NA>
## 15            24             25        Age       64.50      1       <NA>
## 16             0              0       <NA>        0.00     -1          0
## 17            26             27     Income       94.00      1       <NA>
## 18             0              0       <NA>        0.00     -1          0
## 19            28             29      CCAvg        4.05      1       <NA>
## 20            30             31     Income      119.50      1       <NA>
## 21             0              0       <NA>        0.00     -1          0
## 22            32             33 Experience        7.50      1       <NA>
## 23             0              0       <NA>        0.00     -1          0
## 24            34             35      CCAvg        3.90      1       <NA>
## 25             0              0       <NA>        0.00     -1          1
## 26            36             37      CCAvg        3.65      1       <NA>
## 27             0              0       <NA>        0.00     -1          0
## 28            38             39     Family        2.50      1       <NA>
## 29            40             41      CCAvg        6.20      1       <NA>
## 30             0              0       <NA>        0.00     -1          0
## 31            42             43     Family        2.50      1       <NA>
## 32            44             45     Income      116.50      1       <NA>
## 33             0              0       <NA>        0.00     -1          1
## 34            46             47     Income      118.00      1       <NA>
## 35             0              0       <NA>        0.00     -1          1
## 36             0              0       <NA>        0.00     -1          0
## 37             0              0       <NA>        0.00     -1          0
## 38             0              0       <NA>        0.00     -1          0
## 39             0              0       <NA>        0.00     -1          1
## 40            48             49     Family        2.50      1       <NA>
## 41            50             51     Family        2.50      1       <NA>
## 42             0              0       <NA>        0.00     -1          0
## 43             0              0       <NA>        0.00     -1          1
## 44             0              0       <NA>        0.00     -1          0
## 45             0              0       <NA>        0.00     -1          1
## 46             0              0       <NA>        0.00     -1          0
## 47             0              0       <NA>        0.00     -1          1
## 48             0              0       <NA>        0.00     -1          0
## 49             0              0       <NA>        0.00     -1          1
## 50             0              0       <NA>        0.00     -1          0
## 51             0              0       <NA>        0.00     -1          1
# Confusion matrix
rf.pred = predict(rf, valid.df)

caret::confusionMatrix(factor(rf.pred), factor(valid.df$Personal.Loan))
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction    0    1
##          0 1805   25
##          1    9  161
##                                           
##                Accuracy : 0.983           
##                  95% CI : (0.9763, 0.9882)
##     No Information Rate : 0.907           
##     P-Value [Acc > NIR] : <2e-16          
##                                           
##                   Kappa : 0.8952          
##                                           
##  Mcnemar's Test P-Value : 0.0101          
##                                           
##             Sensitivity : 0.9950          
##             Specificity : 0.8656          
##          Pos Pred Value : 0.9863          
##          Neg Pred Value : 0.9471          
##              Prevalence : 0.9070          
##          Detection Rate : 0.9025          
##    Detection Prevalence : 0.9150          
##       Balanced Accuracy : 0.9303          
##                                           
##        'Positive' Class : 0               
## 
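caret treats the first factor level, "0", as the positive class here (see 'Positive' Class : 0). Sensitivity is therefore the proportion of actual non-acceptors classified correctly, and Specificity is the proportion of actual loan acceptors (class 1) classified correctly. If class 1 is the class of interest, pass positive = "1" to caret::confusionMatrix(). Here is a base-R check with the printed counts:

```r
# Validation confusion matrix printed above (rows = prediction, columns = reference)
val.cm <- matrix(c(1805,  25,
                      9, 161),
                 nrow = 2, byrow = TRUE,
                 dimnames = list(prediction = c("0", "1"), reference = c("0", "1")))

sens <- val.cm["0", "0"] / sum(val.cm[, "0"])   # 1805 / 1814, class "0" as positive
spec <- val.cm["1", "1"] / sum(val.cm[, "1"])   # 161 / 186
bal.acc <- (sens + spec) / 2
round(c(sensitivity = sens, specificity = spec, balanced = bal.acc), 4)
```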

What value of “mtry” should we use in the randomForest() function? The tuneRF() function searches for the value with the lowest OOB error:

tuned.rf = tuneRF(x = train.df[, -8],         # Predictors only (column 8 is Personal.Loan)
                  y = factor(train.df[, 8]),  # For classification, the response must be a factor
                  mtryStart = 6,    # Starting value of mtry; the default is the same as in randomForest()
                  ntreeTry = 50,    # Number of trees used at the tuning step
                  stepFactor = 1.5, # At each iteration, mtry is inflated (or deflated) by this factor;
                                    # the output shows two directions: "Searching left ..." and "Searching right ..."
                                    # Rounding is used
                  improve = 0.05,   # The (relative) improvement in OOB error must be this large for the search to continue
                  trace = TRUE,     # Whether to print the progress of the search
                  plot = TRUE,      # Whether to plot the OOB error as a function of mtry
                  doBest = TRUE     # Whether to run a forest using the optimal mtry found
                  )
## mtry = 6  OOB error = 1.47% 
## Searching left ...
## mtry = 4     OOB error = 1.37% 
## 0.06818182 0.05 
## mtry = 3     OOB error = 1.2% 
## 0.1219512 0.05 
## mtry = 2     OOB error = 2.1% 
## -0.75 0.05 
## Searching right ...
## mtry = 9     OOB error = 1.3% 
## -0.08333333 0.05
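Each step prints two numbers: the relative improvement in OOB error and the improve threshold. This is based on our reading of the tuneRF() source, which is worth checking for your version. The improvement is 1 - errorCur/errorOld. The next mtry is ceiling(mtry/stepFactor) when searching left and floor(mtry*stepFactor) when searching right. A direction stops once the improvement falls below improve. The sketch reproduces the printed values. It assumes the OOB error rates correspond to 44, 41, 36, 63 and 39 misclassified records out of the 3000 training records, which is consistent with the rounded percentages. The helper name `rel_improve` is our own.

```r
n.train <- 3000
# Assumed misclassification counts behind the printed OOB rates, named by mtry
oob <- c("6" = 44, "4" = 41, "3" = 36, "2" = 63, "9" = 39) / n.train
rel_improve <- function(err.cur, err.old) unname(1 - err.cur / err.old)

# Searching left from mtryStart = 6 with stepFactor = 1.5
c(ceiling(6 / 1.5), ceiling(4 / 1.5), ceiling(3 / 1.5))   # 4 3 2
rel_improve(oob["4"], oob["6"])   # 0.06818182 > 0.05: accept mtry = 4, continue
rel_improve(oob["3"], oob["4"])   # 0.1219512  > 0.05: accept mtry = 3, continue
rel_improve(oob["2"], oob["3"])   # -0.75      < 0.05: stop searching left

# Searching right from 6, compared with the best error so far (mtry = 3)
floor(6 * 1.5)                    # 9
rel_improve(oob["9"], oob["3"])   # -0.08333333 < 0.05: stop searching right
```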

tuned.rf
## 
## Call:
##  randomForest(x = x, y = y, mtry = res[which.min(res[, 2]), 1]) 
##                Type of random forest: classification
##                      Number of trees: 500
## No. of variables tried at each split: 3
## 
##         OOB estimate of  error rate: 1.4%
## Confusion matrix:
##      0   1 class.error
## 0 2700   6 0.002217295
## 1   36 258 0.122448980

We plot the OOB error rates of the final forest against the number of trees:

plot(tuned.rf)

The top curve shows the error rate for class “1”, the bottom curve shows the error rate for class “0”, and the middle curve shows the overall OOB error rate.

The confusion matrix is given below:

pred = predict(tuned.rf, valid.df)

caret::confusionMatrix(factor(pred), factor(valid.df$Personal.Loan))
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction    0    1
##          0 1812   25
##          1    2  161
##                                           
##                Accuracy : 0.9865          
##                  95% CI : (0.9804, 0.9911)
##     No Information Rate : 0.907           
##     P-Value [Acc > NIR] : < 2.2e-16       
##                                           
##                   Kappa : 0.9153          
##                                           
##  Mcnemar's Test P-Value : 2.297e-05       
##                                           
##             Sensitivity : 0.9989          
##             Specificity : 0.8656          
##          Pos Pred Value : 0.9864          
##          Neg Pred Value : 0.9877          
##              Prevalence : 0.9070          
##          Detection Rate : 0.9060          
##    Detection Prevalence : 0.9185          
##       Balanced Accuracy : 0.9322          
##                                           
##        'Positive' Class : 0               
## 

We can plot variable importance using a bar chart:

Heights = as.numeric(tuned.rf$importance) # Bar heights: mean decrease in the Gini index
Names = row.names(tuned.rf$importance) # Bar labels: predictor names
names(Heights) = Names # Name each value so that the bars are labeled

par(mar = c(5, 6, 3, 2)) # Widen the left margin for the long labels
barplot(sort(Heights),  # Sort bars to create a Pareto chart
        horiz = TRUE,
        main = "Importance of Predictors",
        las = 2
       )

You might be wondering how large each tree is. By default, treesize() returns the number of terminal nodes of each tree:

treesize(tuned.rf)
##   [1]  68  71  70 103  61  93  63  60  58  83  75  80 102  94 102 129  67  74
##  [19]  63  62  70 112  78  68  54  92  84  84  69 110  86  82  59  71  86 107
##  [37]  83 116  75  93  84  62  95  98  86  51  79  86  77  87  77  55 113  93
##  [55]  73  83  70  79 100  79  78  81  79 120 114 133  89 101  81  66  56  65
##  [73]  83 120  87 120 116  95  93  90  96 118  94  57  65  62  95  61 111  91
##  [91] 109  91  87  63 132  85  91  85 133  73  89  91  79 119  96  99 103  70
## [109]  97  81  87  96  58 105  68  94  83  89  83  77  69  98 122  89  91 121
## [127] 125 109  72  85  98  76  90  73 124  97 112  82  78  82  91  96 101  73
## [145] 120  94  80  94  81  95 103  87  68  86  59 133 101  73  69  63  80  76
## [163] 104  68  86  95 111  86  83 143  55  94  90  74  94  99 104  66  81 100
## [181]  59  90  71  77  65 121 142  63  47  76  76  96  95  82  74  65  72 117
## [199] 130  62  87  67  57  78  75 110  82  71  82  66  91 142  77 106 101  87
## [217] 111  92  72  84  65  97 134  59  82 118  93  92  55  74  78  67  70  85
## [235]  51  81  87  83  76  78  80  80  89  87 150 100  80 105  85 104  91 108
## [253]  86  75 106  76 151  61  63  89 104  70 103 116  96  87  50 107  81  67
## [271]  98  68 102  82 108  72  96  60  54 121  93  69  71  71  99  87  96  71
## [289]  76 102  58 115  90  86 123  90  68  97  86 118  87  70 106  76  90  71
## [307] 140  55 132 118 105 115  73  73  93 116  91  88  72  86 118  84  87  93
## [325] 100  68  55 121 140 117  67  77  97  87  80  95  84  75  89  61  66  84
## [343] 122  67  80  99  79  83 101  92  98  87  81  94  60  80  93  89  91  52
## [361]  65  62  66  80  76  88 126  88 102 104  77  86  61  51  82  94 104 136
## [379]  71  82  66  89  65  79  98 108  61 115  84 119  80  97  76  67  72  79
## [397]  68  48  61 116  73  85  61  91  86  65 108  71  77  64  65 108  86  77
## [415]  70 114  47  86  91  81  60  72  54 121  61  72 111  83  98  59 119 110
## [433]  99  93  78  48 102 128  96 134  38 104  95  88  88  94  81  96 102  83
## [451]  94  76  69 103 102  63  65  49  74 144 105  81 115 126  63  98 107 113
## [469] 102  88 103  57 111  92  61  76  75  65  75  75  79  86  82  97  94  76
## [487]  67 108  66 103  85  87  89  93  95 101  52  85 123  67
hist(treesize(tuned.rf), 
     xlab = "Size",
     main = "Sizes of Trees Created by the Random Forest Algorithm") 

Boosting

Boosting comprises the following steps:

  1. Fit a single tree and note which records it misclassifies.

  2. Draw a sample that gives higher selection probabilities to misclassified records.

  3. Fit a tree to the new sample.

  4. Repeat steps 2 and 3 multiple times.

  5. Use weighted voting to classify records, giving more weight to trees with a lower (weighted) error.
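The boosting steps above can be sketched in base R plus rpart. This is a simplified AdaBoost.M1 with resampling, written for illustration. It is not adabag's code. It uses one-split trees (stumps) on the built-in iris data reduced to two classes. The tree weight is alpha = log((1 - err)/err), the coefficient of Freund and Schapire. As far as we know, adabag's default coeflearn = "Breiman" uses half of this value. Because the weights start out uniform, the first tree is also fit to an ordinary bootstrap sample.

```r
library(rpart)

set.seed(2024)
# Two-class problem from the built-in iris data: versicolor vs virginica
d <- droplevels(subset(iris, Species != "setosa"))
n <- nrow(d)
M <- 10                     # number of boosting rounds (trees)
w <- rep(1 / n, n)          # record weights, uniform at the start
stumps <- vector("list", M)
alpha <- numeric(M)

for (m in seq_len(M)) {
  # Draw a sample in which high-weight (previously misclassified) records are more likely
  idx <- sample(n, n, replace = TRUE, prob = w)
  # Fit a one-split tree (a stump) to the sample
  stumps[[m]] <- rpart(Species ~ ., data = d[idx, ], method = "class",
                       control = rpart.control(maxdepth = 1, cp = 0,
                                               minsplit = 2, xval = 0))
  pred <- as.character(predict(stumps[[m]], newdata = d, type = "class"))
  miss <- pred != as.character(d$Species)
  # Weighted error on the full data, kept strictly between 0 and 0.5
  err <- min(max(sum(w[miss]), 1e-10), 0.5 - 1e-10)
  alpha[m] <- log((1 - err) / err)    # weight of this tree in the final vote
  # Raise the weights of misclassified records, then renormalize
  w <- w * exp(alpha[m] * miss)
  w <- w / sum(w)
}

# Weighted vote: for each class, add up alpha over the trees that predict it
classes <- levels(d$Species)
score <- sapply(classes, function(cl)
  rowSums(sapply(seq_len(M), function(m)
    alpha[m] * (as.character(predict(stumps[[m]], newdata = d, type = "class")) == cl))))
boost.pred <- classes[max.col(score, ties.method = "first")]

round(alpha, 3)                  # weights of the 10 trees
mean(boost.pred == d$Species)    # training accuracy of the ensemble
```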

We use the “adabag” package for boosting.

library(adabag)
## Loading required package: foreach
## Loading required package: doParallel
## Loading required package: iterators
## Loading required package: parallel
# Must convert the outcome variable to factor first!
train.df$Personal.Loan <- as.factor(train.df$Personal.Loan)
valid.df$Personal.Loan <- as.factor(valid.df$Personal.Loan)

# Create boosted trees. This will take a while.
boost = boosting(Personal.Loan ~ ., 
                 data = train.df,
                 mfinal = 10  # The number of iterations (or trees) to use
                 )

boost$importance # variable relative importance (sum = 100)
##                Age              CCAvg         CD.Account         CreditCard 
##          9.2509983         24.3995276          1.2429167          0.5909136 
##          Education         Experience             Family             Income 
##         12.9284963          2.5416053         10.0155023         35.2233251 
##           Mortgage             Online Securities.Account 
##          1.5330956          1.6463495          0.6272695

We can plot the predictors according to their importance:

barplot(boost$importance[order(boost$importance, decreasing = TRUE)],
        ylim = c(0, max(boost$importance)*1.1), main = "Variables Relative Importance",
        col = "lightblue"
       )

head(boost$prob)  # Proportion of (weighted) votes for each class in the final ensemble
##           [,1]      [,2]
## [1,] 1.0000000 0.0000000
## [2,] 1.0000000 0.0000000
## [3,] 1.0000000 0.0000000
## [4,] 1.0000000 0.0000000
## [5,] 1.0000000 0.0000000
## [6,] 0.1034243 0.8965757
head(boost$class) # Classes predicted by the ensemble for the first six records
## [1] "0" "0" "0" "0" "0" "1"
boost$weights # Weight of each tree in the final weighted vote
##  [1] 2.070215 1.853392 1.698767 1.573622 1.749838 1.549771 1.705153 1.555282
##  [9] 1.327552 1.403368
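As we understand adabag, boost$prob gives each class's share of the total tree weight, counted over the trees that vote for that class. The sixth record illustrates this. Its probability of 0.1034243 for class "0" equals the weight of the seventh tree divided by the sum of all ten weights, which is consistent with tree 7 alone voting "0" for that record. The vote pattern in the sketch is inferred from these numbers, not read from the fitted object.

```r
# Tree weights printed by boost$weights
tree.w <- c(2.070215, 1.853392, 1.698767, 1.573622, 1.749838,
            1.549771, 1.705153, 1.555282, 1.327552, 1.403368)
# Inferred (hypothetical) votes for record 6: only tree 7 votes "0"
votes.0 <- seq_along(tree.w) == 7
prob6 <- c("0" = sum(tree.w[votes.0]), "1" = sum(tree.w[!votes.0])) / sum(tree.w)
round(prob6, 7)   # about 0.1034243 and 0.8965757
```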

The following code computes how the error of a boosting or bagging classifier on a data frame evolves as the ensemble grows:

evol.train = errorevol(boost, newdata=train.df) # Error evolution for training data
evol.valid = errorevol(boost, newdata=valid.df) # Error evolution for validation data
plot.errorevol(evol.valid, evol.train) # Plot the errors based on train and validation data

The above plot suggests that a few iterations are enough. Because mfinal = 10, the plot cannot show how the errors behave beyond 10 trees.

The following confusion matrix gives performance measures of the trained model on the validation data.

pred = predict(boost, valid.df)

caret::confusionMatrix(factor(pred$class), factor(valid.df$Personal.Loan))
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction    0    1
##          0 1812   21
##          1    2  165
##                                           
##                Accuracy : 0.9885          
##                  95% CI : (0.9828, 0.9927)
##     No Information Rate : 0.907           
##     P-Value [Acc > NIR] : < 2.2e-16       
##                                           
##                   Kappa : 0.9286          
##                                           
##  Mcnemar's Test P-Value : 0.0001746       
##                                           
##             Sensitivity : 0.9989          
##             Specificity : 0.8871          
##          Pos Pred Value : 0.9885          
##          Neg Pred Value : 0.9880          
##              Prevalence : 0.9070          
##          Detection Rate : 0.9060          
##    Detection Prevalence : 0.9165          
##       Balanced Accuracy : 0.9430          
##                                           
##        'Positive' Class : 0               
## 

Neural Nets (Chapter 11)

For those planning to participate in data competitions, here are student projects posted by the textbook author: https://www.galitshmueli.com/student-projects

Like k-NN and decision trees, neural networks, also called artificial neural networks (ANNs), are models for classification or prediction. You can find images of many neural nets with an image search at https://www.google.com/.

11.1 The Structure of a Neural Network

A neural network has an input layer, zero or more hidden layers, and an output layer. A network without a hidden layer is called a perceptron, and a network with one or more hidden layers is called a multilayer perceptron (MLP). The input layer is composed of input neurons (also called nodes). The inputs are the values of the predictors (called features) of a single observation (such as an individual, an image, or a document). Each hidden layer has one or more neurons, and the output layer has one or more neurons. A neuron can be activated by the other neurons connected to it. Each directed line connecting two neurons is called a synapse, and the strength of the connection between the two neurons is quantified by its weight.

For the prediction of a single numerical variable, one output node is needed. For a classification problem, the number of output nodes equals the number of classes. A tutorial: https://ujjwalkarn.me/2016/08/09/quick-intro-neural-networks/

A very nice tutorial for detailed computation of outputs of a neural network: https://stevenmiller888.github.io/mind-how-to-build-a-neural-network/.

How to Compute the Output of a Neural Network

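Some computation details, using the example network in the tutorial: both inputs are 1, and the weights into the third hidden node are 0.3 and 0.5, so the third node in the hidden layer receives

\[(1)(0.3) + (1)(0.5) = 0.8\] before activation with the logistic function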

\[f(x) = \frac{1}{1+e^{-x}}\]

After activation, the output from this node is

\[f(0.8) = \frac{1}{1+e^{-0.8}} = 0.69\] The pre-activation sums and the outputs of the other two hidden nodes are obtained in the same way.

Finally, the hidden-node outputs 0.73, 0.79, and 0.69 are combined with the hidden-to-output weights to give about 1.2 before activation. After activation with the same logistic function, the network's output is 0.77.
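The whole forward pass can be written as two matrix products. The sketch assumes the weights used in the tutorial's example: 0.8, 0.4 and 0.3 from the first input and 0.2, 0.9 and 0.5 from the second input to the three hidden nodes, and 0.3, 0.5 and 0.9 from the hidden nodes to the output. Only 0.3 and 0.5 appear above, so treat the others as an assumption. They do reproduce the values 0.73, 0.79, 0.69 and 0.77 quoted in the text. As in the tutorial, there are no bias terms.

```r
logistic <- function(x) 1 / (1 + exp(-x))

x  <- matrix(c(1, 1), nrow = 1)           # both inputs are 1
W1 <- matrix(c(0.8, 0.4, 0.3,             # weights from input 1 (assumed)
               0.2, 0.9, 0.5),            # weights from input 2 (assumed)
             nrow = 2, byrow = TRUE)      # 2 inputs x 3 hidden nodes
W2 <- matrix(c(0.3, 0.5, 0.9), ncol = 1)  # 3 hidden nodes x 1 output (assumed)

z.hidden <- x %*% W1            # 1.0, 1.3, 0.8 before activation
a.hidden <- logistic(z.hidden)  # about 0.73, 0.79, 0.69
z.out <- a.hidden %*% W2        # about 1.23
logistic(z.out)                 # about 0.77
```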

Definitely check the tutorial out!

11.2 Steps in Training a Neural Network

A feed-forward neural network is an artificial neural network whose connections (called edges) do not form a cycle: information flows from the input layer through the hidden layers to the output layer, so the network is a directed acyclic graph. In a dense (fully connected) neural network, each neuron (say \(i\)) in a layer is connected to each neuron (say \(j\)) in the next layer. Each edge is associated with a weight (denoted by \(w_{ij}\)) describing the contribution of \(i\) to \(j\). Technically, a dummy node is added to the input layer and to each hidden layer. All dummy nodes take the constant value 1. The textbook uses \(\theta_k\) to denote the contribution of the \(k\)th dummy node. These \(\theta_k\)’s are called biases.

To train a neural network, take the following steps:

  • Pick a neural network architecture. What is the number of input nodes (equal to the number of features)? What is the number of hidden layers (most commonly 1 or 2)? What is the number of nodes in each hidden layer? There is no definitive rule here; some suggest that the number of nodes in each hidden layer be roughly the mean of the numbers of nodes in the input and output layers. What is the number of output nodes? It equals 1 for regression and the number of classes for classification.

  • Dummify nominal features, score ordinal features using values between 0 and 1, and normalize numerical features to [0, 1].

  • Choose initial weights: The weights are randomly chosen to be very close to 0, such as between -0.1 and 0.1.

  • Choose an activation function. Activation functions are what make a neural network model nonlinear. Examples of activation functions include the identity function, the logistic (or sigmoid) function, the hyperbolic tangent function (tanh), and ReLU (Rectified Linear Unit).

  • Choose an error or cost function for optimizing weights. The R package “neuralnet” provides two error functions through the argument “err.fct”. One is the “sum of squared error (sse)” function, defined as

    \[E_{sse}=\sum_{i=1}^{n}\sum_{j=1}^{c}\frac{1}{2}(y_{ij}-\hat{y}_{ij})^2\] This error function can be used for both regression (with quantitative response) and classification (with nominal response).

For classification problems, the error function can instead be the “cross-entropy (ce)” function, defined as

\[E_{ce}=-\sum_{i=1}^{n}\sum_{j=1}^{c}y_{ij}\log(\hat{y}_{ij}),\]

where \(n\) is the number of observations, \(c\) is the number of output nodes, \(y_{ij}\) is the observed value of the \(i\)th observation for the \(j\)th output node, and \(\hat{y}_{ij}\) is the corresponding predicted propensity, (usually) obtained with the softmax transformation. The softmax transformation maps a vector \(x=(x_1,\ldots,x_c)\) to the vector whose \(j\)th component is \(\frac{e^{x_j}}{\sum_{k=1}^{c} e^{x_k}}\). For example, the vector \((0.4, 0.9, 1.5)\) is transformed to

\[(\frac{e^{0.4}}{e^{0.4}+e^{0.9}+e^{1.5}}, \frac{e^{0.9}}{e^{0.4}+e^{0.9}+e^{1.5}}, \frac{e^{1.5}}{e^{0.4}+e^{0.9}+e^{1.5}})\]

or \((0.1769, 0.2917, 0.5314)\).
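The softmax example and the two error functions can be checked numerically. The base-R sketch below is only an illustration: the helper name `softmax` and the assumption that the record belongs to the third class (one-hot vector \((0, 0, 1)\)) are ours, not part of the neuralnet package.

```r
# Softmax: exponentiate each component, then divide by the sum of the exponentials
softmax <- function(x) exp(x) / sum(exp(x))

p <- softmax(c(0.4, 0.9, 1.5))
round(p, 4)  # 0.1769 0.2917 0.5314, matching the example above

# Assumed for illustration: this record belongs to the third class,
# so its one-hot (dummy) coding is y = (0, 0, 1)
y <- c(0, 0, 1)

sse.i <- sum(0.5 * (y - p)^2)  # this record's contribution to E_sse
ce.i  <- -sum(y * log(p))      # this record's contribution to E_ce (equals -log(p[3]))
c(sse = sse.i, ce = ce.i)
```

For a single record, \(E_{ce}\) reduces to minus the log of the propensity for the observed class, because all other terms are multiplied by \(y_{ij}=0\).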

  • Use gradient descent and the back-propagation (BP) algorithm (Rumelhart, Hinton, and Williams, 1986) to update the weights. A good tutorial is http://home.agh.edu.pl/~vlsi/AI/backp_t_en/backprop.html

  • Apply stopping rules to avoid overfitting (large variance). By default, neuralnet() uses stepmax = 1e+05 (the maximum number of training steps) and threshold = 0.01 (the stopping threshold for the partial derivatives of the error function with respect to the weights). One can try different stepmax values and plot the errors on the training and validation sets. The training-error curve should keep decreasing, while the validation-error curve typically decreases and then increases. The best stepmax value corresponds to the turning point of the validation curve.
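The preprocessing step in the list above (dummify nominal features, score ordinal features in [0, 1], rescale numeric features to [0, 1]) might look like the following base-R sketch. The toy data frame and the helper `rescale01` are hypothetical. For the nominal feature it drops one reference level, matching the "dummies except one" entry for neural-network predictors in the model table at the top of this document.

```r
# Hypothetical helper: min-max scale a numeric vector to [0, 1]
rescale01 <- function(x) (x - min(x)) / (max(x) - min(x))

# Toy data (assumed): one numeric, one nominal, and one ordinal feature
toy <- data.frame(
  income = c(20, 35, 50, 80),
  color  = factor(c("red", "blue", "green", "red")),
  size   = factor(c("S", "M", "L", "M"), levels = c("S", "M", "L"), ordered = TRUE)
)

# Numeric feature -> [0, 1]
income01 <- rescale01(toy$income)

# Ordinal feature -> equally spaced scores in [0, 1] (S = 0, M = 0.5, L = 1)
size01 <- (as.integer(toy$size) - 1) / (nlevels(toy$size) - 1)

# Nominal feature -> 0/1 dummies, dropping the first level ("blue") as the reference
color.dummies <- model.matrix(~ color, data = toy)[, -1]

prepped <- data.frame(income01, size01, color.dummies)
prepped
```

In practice, the minimum and maximum should be computed on the training data and then applied unchanged to the validation data, so that both sets are scaled the same way.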

11.3 Training Neural Networks

Once a neural network architecture is selected, a learning procedure called back-propagation is used to repeatedly adjust the weights of the connections in the network so as to minimize a measure (called an error, loss, or cost function) of the difference between the actual output vector of the net and the desired output vector. This article, https://www.nature.com/articles/323533a0 (Rumelhart, Hinton, and Williams, 1986), describes the back-propagation procedure.

For the first record in the training data, the model starts with a set of initial weights, usually chosen at random by the software. The output node (or nodes) produces an output, which is compared with the actual outcome value for that record. The error is back-propagated to each hidden node, and all weights, including the input-to-hidden weights, are then updated by gradient descent. The process is repeated for the next record, and so on. This is demonstrated in the excellent tutorial http://home.agh.edu.pl/~vlsi/AI/backp_t_en/backprop.html. This series of videos may also help: https://www.youtube.com/watch?v=5u0jaA3qAGk
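To make the record-by-record procedure concrete, here is a minimal sketch of online back-propagation with plain gradient descent in base R. It assumes a network with 2 inputs, one hidden layer of 3 logistic nodes, and 1 logistic output, trained with the SSE error on the Salt/Fat data from the textbook example later in this section (like coded as 1). The learning rate (0.5), the number of passes (2000), and start weights drawn uniformly from [-0.1, 0.1] are illustrative assumptions. Note that neuralnet() uses resilient back-propagation (algorithm = "rprop+") by default, so its fitted weights will not match this sketch.

```r
# Logistic activation
sigmoid <- function(z) 1 / (1 + exp(-z))

# Salt/Fat data from the textbook example; like = 1, dislike = 0
X <- cbind(Salt = c(0.9, 0.1, 0.4, 0.5, 0.5, 0.8),
           Fat  = c(0.2, 0.1, 0.2, 0.2, 0.4, 0.3))
y <- c(1, 0, 0, 0, 1, 1)

set.seed(1)
n.hidden <- 3
# First row of each weight matrix holds the intercept (bias) weights
W1 <- matrix(runif(3 * n.hidden, -0.1, 0.1), nrow = 3)  # (1 + 2 inputs) x hidden
W2 <- matrix(runif(n.hidden + 1, -0.1, 0.1), ncol = 1)  # (1 + hidden) x 1 output
lr <- 0.5  # learning rate (assumed)

# Total SSE over all records for the given weights
total.sse <- function(W1, W2) {
  H   <- sigmoid(cbind(1, X) %*% W1)
  out <- sigmoid(cbind(1, H) %*% W2)
  sum(0.5 * (y - out)^2)
}

err.start <- total.sse(W1, W2)
for (pass in 1:2000) {
  for (i in seq_len(nrow(X))) {
    # Forward pass for record i
    a0 <- c(1, X[i, ])                 # inputs with intercept
    h  <- sigmoid(drop(a0 %*% W1))     # hidden-node outputs
    a1 <- c(1, h)
    o  <- sigmoid(sum(a1 * W2))        # network output
    # Backward pass: error terms for SSE with logistic units
    delta.out <- (o - y[i]) * o * (1 - o)
    delta.hid <- delta.out * W2[-1, 1] * h * (1 - h)
    # Gradient-descent updates of all weights
    W2 <- W2 - lr * delta.out * a1
    W1 <- W1 - lr * outer(a0, delta.hid)
  }
}
err.end <- total.sse(W1, W2)
c(start = err.start, end = err.end)
```

Each pass through the inner loop performs the steps described above: a forward pass, the error at the output, the error back-propagated to the hidden nodes, and a gradient-descent update of every weight.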

R packages such as “nnet” and “neuralnet” can be used to fit neural networks. The book uses “neuralnet”. Please use the newest version of “neuralnet”. On the R console, issue the command:

devtools::install_github("bips-hb/neuralnet")

A neuron has inputs and an output. The output is obtained by applying an activation function to its net input (the weighted sum of its inputs). For a list of commonly used activation functions, refer to https://en.wikipedia.org/wiki/Activation_function. Activation functions commonly used with neural networks in R include

  • The “logistic” function:

    \[logistic(x)=\frac{1}{1+e^{-x}}\]

  • The “tanh” function:

    \[tanh(x)=\frac{e^x-e^{-x}}{e^x+e^{-x}}\]

  • The “ReLU” function:

    \[ReLU(x) = \max(0, x)\]

A smoothed version of ReLU is the softplus function, which is defined as \[softplus(x)=\log(1+e^x).\]

Here are the plots of those functions:
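The curves can be drawn with base R graphics. In this sketch, `logistic`, `relu`, and `softplus` are small helper functions defined here (base R already provides `tanh()`), and the plotting range [-4, 4] is an arbitrary choice.

```r
logistic <- function(x) 1 / (1 + exp(-x))
relu     <- function(x) pmax(0, x)
softplus <- function(x) log(1 + exp(x))

x <- seq(-4, 4, by = 0.01)
plot(x, logistic(x), type = "l", ylim = c(-1, 4.1),
     ylab = "f(x)", main = "Activation functions")
lines(x, tanh(x), lty = 2)
lines(x, relu(x), lty = 3)
lines(x, softplus(x), lty = 4)
legend("topleft", legend = c("logistic", "tanh", "ReLU", "softplus"), lty = 1:4)
```

The softplus curve lies strictly above ReLU everywhere and approaches it as \(|x|\) grows, which is why it is described as a smoothed ReLU.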

We will use the function neuralnet() from the “neuralnet” package to train a neural network.

Let’s follow the textbook example on pages 274-276. The data are

df = data.frame(Salt = c(0.9, 0.1, 0.4, 0.5, 0.5, 0.8),
                Fat = c(0.2, 0.1, 0.2, 0.2, 0.4, 0.3),
                Acceptance = c("like", "dislike", "dislike", "dislike", "like", "like")
               )
df
##   Salt Fat Acceptance
## 1  0.9 0.2       like
## 2  0.1 0.1    dislike
## 3  0.4 0.2    dislike
## 4  0.5 0.2    dislike
## 5  0.5 0.4       like
## 6  0.8 0.3       like

We first create a neural network using the default settings:

## Replicate the book example on pages 274-280. The book has details for computation.
library(neuralnet)
nn <- neuralnet(
         Acceptance ~ Salt + Fat,   
         data = df,
         linear.output = FALSE  # For classification
       )

nn
## $call
## neuralnet(formula = Acceptance ~ Salt + Fat, data = df, linear.output = FALSE)
## 
## $response
##   dislike  like
## 1   FALSE  TRUE
## 2    TRUE FALSE
## 3    TRUE FALSE
## 4    TRUE FALSE
## 5   FALSE  TRUE
## 6   FALSE  TRUE
## 
## $covariate
##      Salt Fat
## [1,]  0.9 0.2
## [2,]  0.1 0.1
## [3,]  0.4 0.2
## [4,]  0.5 0.2
## [5,]  0.5 0.4
## [6,]  0.8 0.3
## 
## $model.list
## $model.list$response
## [1] "dislike" "like"   
## 
## $model.list$variables
## [1] "Salt" "Fat" 
## 
## 
## $err.fct
## function (x, y) 
## {
##     1/2 * (y - x)^2
## }
## <bytecode: 0x7ff0fafce710>
## <environment: 0x7ff0fafcbe88>
## attr(,"type")
## [1] "sse"
## 
## $act.fct
## function (x) 
## {
##     1/(1 + exp(-x))
## }
## <bytecode: 0x7ff0fafd2668>
## <environment: 0x7ff0fafcf7c0>
## attr(,"type")
## [1] "logistic"
## 
## $output.act.fct
## function (x) 
## {
##     1/(1 + exp(-x))
## }
## <bytecode: 0x7ff0fafd2668>
## <environment: 0x7ff0fafcfc58>
## attr(,"type")
## [1] "logistic"
## 
## $linear.output
## [1] FALSE
## 
## $data
##   Salt Fat Acceptance
## 1  0.9 0.2       like
## 2  0.1 0.1    dislike
## 3  0.4 0.2    dislike
## 4  0.5 0.2    dislike
## 5  0.5 0.4       like
## 6  0.8 0.3       like
## 
## $exclude
## NULL
## 
## $net.result
## $net.result[[1]]
##            [,1]        [,2]
## [1,] 0.04082509 0.952164832
## [2,] 0.99433531 0.005308353
## [3,] 0.96192445 0.034955858
## [4,] 0.89814149 0.092947014
## [5,] 0.08011918 0.907507259
## [6,] 0.02163683 0.974363882
## 
## 
## $weights
## $weights[[1]]
## $weights[[1]][[1]]
##           [,1]
## [1,] -5.255788
## [2,]  5.192349
## [3,]  8.755611
## 
## $weights[[1]][[2]]
##            [,1]      [,2]
## [1,]   5.399008 -5.461546
## [2,] -11.215341 11.079981
## 
## 
## 
## $generalized.weights
## $generalized.weights[[1]]
##            [,1]       [,2]      [,3]      [,4]
## [1,] -10.534657 -17.764092 10.407513 17.549694
## [2,]  -1.175648  -1.982439  1.161459  1.958513
## [3,]  -9.086191 -15.321612  8.976528 15.136692
## [4,] -11.924125 -20.107085 11.780211 19.864409
## [5,] -12.251926 -20.659840 12.104055 20.410493
## [6,]  -8.549011 -14.415791  8.445832 14.241804
## 
## 
## $startweights
## $startweights[[1]]
## $startweights[[1]][[1]]
##            [,1]
## [1,]  1.9136645
## [2,]  0.5361828
## [3,] -0.3453527
## 
## $startweights[[1]][[2]]
##            [,1]       [,2]
## [1,]  0.5440361 -0.8438382
## [2,] -0.3534844  0.8970273
## 
## 
## 
## $result.matrix
##                                [,1]
## error                   0.020900237
## reached.threshold       0.006140556
## steps                 122.000000000
## Intercept.to.1layhid1  -5.255787680
## Salt.to.1layhid1        5.192348611
## Fat.to.1layhid1         8.755610601
## Intercept.to.dislike    5.399007517
## 1layhid1.to.dislike   -11.215340657
## Intercept.to.like      -5.461545821
## 1layhid1.to.like       11.079980704
## 
## attr(,"class")
## [1] "nn"
plot(nn) # Plot the network
nn$net.result[[1]]  # Propensity of each record for each class (columns in alphabetical order: dislike, like)
##            [,1]        [,2]
## [1,] 0.04082509 0.952164832
## [2,] 0.99433531 0.005308353
## [3,] 0.96192445 0.034955858
## [4,] 0.89814149 0.092947014
## [5,] 0.08011918 0.907507259
## [6,] 0.02163683 0.974363882
                    # 1 = Dislike, 2 = Like, as printed in previous plot

Note that the propensities in each row need not sum to 1, because each output node applies its own logistic function.
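One simple way to turn these propensities into values that sum to 1 is to divide each row by its row sum. The sketch below copies the six rows printed above into a matrix (so it runs without neuralnet) and rescales them. This row normalization is a post-processing step added for illustration, not something neuralnet() does.

```r
# Propensities copied from nn$net.result[[1]] above (columns: dislike, like)
prop <- matrix(c(0.04082509, 0.952164832,
                 0.99433531, 0.005308353,
                 0.96192445, 0.034955858,
                 0.89814149, 0.092947014,
                 0.08011918, 0.907507259,
                 0.02163683, 0.974363882),
               ncol = 2, byrow = TRUE,
               dimnames = list(NULL, c("dislike", "like")))

rowSums(prop)                       # close to, but not exactly, 1
prop.norm <- prop / rowSums(prop)   # divide each row by its own sum
rowSums(prop.norm)                  # now exactly 1 (up to rounding)
apply(prop.norm, 1, which.max)      # predicted class: 1 = dislike, 2 = like
```

Dividing by the row sum never changes which class has the largest propensity, so the predicted classes are the same before and after rescaling.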

We next make predictions. Predictions are usually made on validation data; for demonstration purposes, we predict on the training data here:

predict(nn,newdata = df)
##            [,1]        [,2]
## [1,] 0.04082509 0.952164832
## [2,] 0.99433531 0.005308353
## [3,] 0.96192445 0.034955858
## [4,] 0.89814149 0.092947014
## [5,] 0.08011918 0.907507259
## [6,] 0.02163683 0.974363882

Next, we create a neural network with more arguments specified:

# We fit a neural network with 2 input nodes, two hidden layers with 3 and 2 nodes, and 2 output nodes.
# Counting intercept (bias) weights, there are 23 weights to estimate (9 + 8 + 6). The software chooses the start weights.

nn <- neuralnet(
         Acceptance ~ Salt + Fat,   
         data = df, 
         hidden = c(3, 2),  # Two hidden layers, with 3 and 2 nodes
         linear.output = FALSE, # For classification, set it to FALSE
         err.fct = "sse", # The error function can be "sse" (default) or "ce"
         rep = 5  # Number of repetitions: the network is trained 5 times from different random start weights (default = 1)
       )
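Counting weights by hand is error prone, so a small helper can do it. The function `n.weights` below is hypothetical (not part of neuralnet); it assumes a fully connected feed-forward network in which every hidden and output node has its own intercept (bias) weight, as in the `Intercept.to...` entries of `$result.matrix`. For the default network fitted earlier (2 inputs, 1 hidden node, 2 outputs) it gives 3 + 4 = 7 weights, matching the printed `$weights`.

```r
n.weights <- function(n.input, hidden, n.output) {
  sizes <- c(n.input, hidden, n.output)
  # Each non-input node receives one weight per node in the previous layer, plus an intercept
  sum((head(sizes, -1) + 1) * tail(sizes, -1))
}

n.weights(2, 1, 2)        # 7: the default network fitted earlier
n.weights(2, 3, 2)        # 17 = 9 + 8: one hidden layer with 3 nodes
n.weights(2, c(3, 2), 2)  # 23 = 9 + 8 + 6: hidden = c(3, 2)
```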

We then plot the structure of the neural net.

plot(nn, rep = "best") # Plot the repetition with the lowest error

Predictions for the training data, based on each repetition:

prediction(nn) 
## Data Error:  0;
## $rep1
##   Salt Fat    dislike       like
## 1  0.1 0.1 0.98942551 0.01619813
## 2  0.4 0.2 0.96406841 0.05050657
## 3  0.5 0.2 0.88420379 0.14527835
## 4  0.9 0.2 0.02898657 0.96405803
## 5  0.8 0.3 0.02663629 0.96668254
## 6  0.5 0.4 0.12628322 0.86335193
## 
## $rep2
##   Salt Fat    dislike        like
## 1  0.1 0.1 0.99567490 0.006917986
## 2  0.4 0.2 0.98643852 0.020645728
## 3  0.5 0.2 0.93661114 0.087724603
## 4  0.9 0.2 0.04335673 0.956495368
## 5  0.8 0.3 0.04036149 0.959276111
## 6  0.5 0.4 0.10002390 0.904167238
## 
## $rep3
##   Salt Fat    dislike       like
## 1  0.1 0.1 0.96904515 0.03210816
## 2  0.4 0.2 0.94626347 0.05323813
## 3  0.5 0.2 0.90485128 0.08981386
## 4  0.9 0.2 0.03500901 0.93984129
## 5  0.8 0.3 0.01641813 0.96945771
## 6  0.5 0.4 0.05410597 0.91198545
## 
## $rep4
##   Salt Fat    dislike        like
## 1  0.1 0.1 0.97831000 0.003819519
## 2  0.4 0.2 0.96310854 0.018469048
## 3  0.5 0.2 0.90438753 0.078801968
## 4  0.9 0.2 0.06319421 0.960531608
## 5  0.8 0.3 0.05413200 0.967204363
## 6  0.5 0.4 0.10909254 0.911685570
## 
## $rep5
##   Salt Fat    dislike        like
## 1  0.1 0.1 0.99714743 0.005786599
## 2  0.4 0.2 0.98088795 0.031751392
## 3  0.5 0.2 0.92739691 0.102357390
## 4  0.9 0.2 0.03147947 0.960024592
## 5  0.8 0.3 0.02164993 0.971408462
## 6  0.5 0.4 0.09409031 0.894759582
## 
## $data
##   Salt Fat dislike like
## 1  0.1 0.1       1    0
## 2  0.4 0.2       1    0
## 3  0.5 0.2       1    0
## 4  0.9 0.2       0    1
## 5  0.8 0.3       0    1
## 6  0.5 0.4       0    1

Predictions for the training data, based on the best repetition:

predict(nn,newdata = df, 
        rep = which.min(nn$result.matrix[1,])) # Use the repetition with the lowest error (row 1 of result.matrix)
##            [,1]        [,2]
## [1,] 0.04335673 0.956495368
## [2,] 0.99567490 0.006917986
## [3,] 0.98643852 0.020645728
## [4,] 0.93661114 0.087724603
## [5,] 0.10002390 0.904167238
## [6,] 0.04036149 0.959276111

Next, let’s create a neural network for classifying species based on the iris data.

We first split the data into training (60%) and validation (40%) sets.

library(neuralnet)

set.seed(123)

n = nrow(iris)
train.idx = sample(1:n, n*0.6)
train = iris[train.idx, ]
validation=iris[-train.idx, ]

For classification, the response can be a single categorical variable or the set of all dummy variables of that variable. The propensities for a record may not add up to 1; to make them sum to 1, the softmax function can be applied at the output nodes.

Method 1: Using Species as a single response variable

nn <- neuralnet(formula = Species ~ Petal.Length + Petal.Width, 
                data = train, 
                hidden = c(1),           # One hidden layer with one node
                act.fct = "logistic",    # The activation function can also be defined by the user
                linear.output = FALSE,
                rep = 10,         # Number of repetitions of the training, each from different random start weights.
                                  # Use which.min(nn$result.matrix[1,]) to find the repetition with the lowest error.
                                  # Its results are in nn$result.matrix[, k], assuming that repetition is k.
                lifesign = "full" # How much the function prints during training: 'none', 'minimal' or 'full'
               )
## hidden: 1    thresh: 0.01    rep:  1/10    steps:    1000    min thresh: 0.0693234628390257
##                                                      2000    min thresh: 0.0519463377386366
##                                                      3000    min thresh: 0.0519463377386366
##                                                      4000    min thresh: 0.0519463377386366
##                                                      5000    min thresh: 0.0519463377386366
##                                                      6000    min thresh: 0.0519463377386366
##                                                      7000    min thresh: 0.0519463377386366
##                                                      8000    min thresh: 0.0519463377386366
##                                                      9000    min thresh: 0.0519463377386366
##                                                     10000    min thresh: 0.0519463377386366
##                                                     11000    min thresh: 0.0519463377386366
##                                                     12000    min thresh: 0.0519463377386366
##                                                     13000    min thresh: 0.0456772449332913
##                                                     14000    min thresh: 0.03821848042476
##                                                     15000    min thresh: 0.03821848042476
##                                                     16000    min thresh: 0.0358264244659994
##                                                     17000    min thresh: 0.033652909285183
##                                                     18000    min thresh: 0.0297125639894362
##                                                     19000    min thresh: 0.0259609214923759
##                                                     20000    min thresh: 0.0250493396090673
##                                                     21000    min thresh: 0.0181340650419817
##                                                     22000    min thresh: 0.0181340650419817
##                                                     23000    min thresh: 0.0174382299782745
##                                                     24000    min thresh: 0.014225897639359
##                                                     25000    min thresh: 0.0134697308857342
##                                                     26000    min thresh: 0.0116228906033433
##                                                     26601    error: 8.38439  time: 4.1 secs
## hidden: 1    thresh: 0.01    rep:  2/10    steps:    1000    min thresh: 0.0979214261659785
##                                                      2000    min thresh: 0.0278589902799936
##                                                      2849    error: 8.60277  time: 0.39 secs
## hidden: 1    thresh: 0.01    rep:  3/10    steps:    1000    min thresh: 0.0590937852207485
##                                                      2000    min thresh: 0.0528871603580913
##                                                      3000    min thresh: 0.0528871603580913
##                                                      4000    min thresh: 0.0528871603580913
##                                                      5000    min thresh: 0.0528871603580913
##                                                      6000    min thresh: 0.0528871603580913
##                                                      7000    min thresh: 0.0528871603580913
##                                                      8000    min thresh: 0.0528871603580913
##                                                      9000    min thresh: 0.0528871603580913
##                                                     10000    min thresh: 0.0528871603580913
##                                                     11000    min thresh: 0.0528871603580913
##                                                     12000    min thresh: 0.0528871603580913
##                                                     13000    min thresh: 0.0528871603580913
##                                                     14000    min thresh: 0.0500643442096174
##                                                     15000    min thresh: 0.0387073841158637
##                                                     16000    min thresh: 0.0374988918653775
##                                                     17000    min thresh: 0.0374988918653775
##                                                     18000    min thresh: 0.0347571431954047
##                                                     19000    min thresh: 0.0318978401856138
##                                                     20000    min thresh: 0.0304210302168549
##                                                     21000    min thresh: 0.0233715461820567
##                                                     22000    min thresh: 0.0200186040372769
##                                                     23000    min thresh: 0.0200186040372769
##                                                     24000    min thresh: 0.0181789003832081
##                                                     25000    min thresh: 0.0180512755706754
##                                                     26000    min thresh: 0.0168140008364111
##                                                     27000    min thresh: 0.0162060697836828
##                                                     28000    min thresh: 0.0142074123917758
##                                                     29000    min thresh: 0.0120522538353961
##                                                     30000    min thresh: 0.0120522538353961
##                                                     31000    min thresh: 0.0118252484775589
##                                                     32000    min thresh: 0.0115714762342684
##                                                     33000    min thresh: 0.0107429979006735
##                                                     34000    min thresh: 0.0107429979006735
##                                                     35000    min thresh: 0.010167657384092
##                                                     35612    error: 8.3333   time: 4.83 secs
## hidden: 1    thresh: 0.01    rep:  4/10    steps:    1000    min thresh: 0.0459093822974279
##                                                      2000    min thresh: 0.0233748149915351
##                                                      2891    error: 8.59983  time: 0.31 secs
## hidden: 1    thresh: 0.01    rep:  5/10    steps:    1000    min thresh: 0.0721032846847956
##                                                      2000    min thresh: 0.022415291319459
##                                                      2916    error: 8.59606  time: 0.34 secs
## hidden: 1    thresh: 0.01    rep:  6/10    steps:    1000    min thresh: 0.113709368802841
##                                                      2000    min thresh: 0.0639886911360698
##                                                      3000    min thresh: 0.0627590906746204
##                                                      4000    min thresh: 0.0627590906746204
##                                                      5000    min thresh: 0.0560536970739496
##                                                      6000    min thresh: 0.0556883622658409
##                                                      7000    min thresh: 0.0525089639570909
##                                                      8000    min thresh: 0.0525089639570909
##                                                      9000    min thresh: 0.0514294063878954
##                                                     10000    min thresh: 0.0500305447046641
##                                                     11000    min thresh: 0.0498152536700875
##                                                     12000    min thresh: 0.0498152536700875
##                                                     13000    min thresh: 0.04797567127621
##                                                     14000    min thresh: 0.0436514950490174
##                                                     15000    min thresh: 0.0332139280726573
##                                                     16000    min thresh: 0.0323967372769761
##                                                     17000    min thresh: 0.0247469205783873
##                                                     18000    min thresh: 0.0245212556974425
##                                                     19000    min thresh: 0.0245212556974425
##                                                     20000    min thresh: 0.0189544108801703
##                                                     21000    min thresh: 0.0182232209577058
##                                                     22000    min thresh: 0.0150606088062705
##                                                     23000    min thresh: 0.0150195183262021
##                                                     24000    min thresh: 0.0123257167569946
##                                                     25000    min thresh: 0.0110901290311461
##                                                     26000    min thresh: 0.0110901290311461
##                                                     27000    min thresh: 0.011061465737797
##                                                     28000    min thresh: 0.0105852363719989
##                                                     28272    error: 8.37253  time: 3.56 secs
## hidden: 1    thresh: 0.01    rep:  7/10    steps:    1000    min thresh: 0.0711648838899715
##                                                      2000    min thresh: 0.0711648838899715
##                                                      3000    min thresh: 0.0711648838899715
##                                                      4000    min thresh: 0.0711648838899715
##                                                      5000    min thresh: 0.0711648838899715
##                                                      6000    min thresh: 0.0711648838899715
##                                                      7000    min thresh: 0.0711648838899715
##                                                      8000    min thresh: 0.0711648838899715
##                                                      9000    min thresh: 0.0711648838899715
##                                                     10000    min thresh: 0.0676932547784622
##                                                     11000    min thresh: 0.0675153965340415
##                                                     12000    min thresh: 0.0642576922124582
##                                                     13000    min thresh: 0.060613521511222
##                                                     14000    min thresh: 0.0574724011995649
##                                                     15000    min thresh: 0.0495098262308411
##                                                     16000    min thresh: 0.0459763913959189
##                                                     17000    min thresh: 0.0351979262191944
##                                                     18000    min thresh: 0.0351979262191944
##                                                     19000    min thresh: 0.0324793570583117
##                                                     20000    min thresh: 0.028313919803666
##                                                     21000    min thresh: 0.0241876947066991
##                                                     22000    min thresh: 0.0222986374041807
##                                                     23000    min thresh: 0.0182002581576178
##                                                     24000    min thresh: 0.0182002581576178
##                                                     25000    min thresh: 0.0159734597524657
##                                                     26000    min thresh: 0.0154604399719795
##                                                     27000    min thresh: 0.0145422943227187
##                                                     28000    min thresh: 0.012190343957056
##                                                     29000    min thresh: 0.0115437918556836
##                                                     30000    min thresh: 0.0115437918556836
##                                                     31000    min thresh: 0.0115437918556836
##                                                     32000    min thresh: 0.0106651101999807
##                                                     33000    min thresh: 0.0106651101999807
##                                                     34000    min thresh: 0.0106651101999807
##                                                     34380    error: 8.33962  time: 4.22 secs
## hidden: 1    thresh: 0.01    rep:  8/10    steps:      26    error: 29.82225 time: 0 secs
## hidden: 1    thresh: 0.01    rep:  9/10    steps:    1000    min thresh: 0.0339684291135093
##                                                      2000    min thresh: 0.0192388157247502
##                                                      2802    error: 8.61046  time: 0.33 secs
## hidden: 1    thresh: 0.01    rep: 10/10    steps:    1000    min thresh: 0.0615492007580969
##                                                      2000    min thresh: 0.0229368275168006
##                                                      2797    error: 8.60804  time: 0.32 secs

Plot the network. With rep = "best", the repetition with the smallest error is plotted; if rep is not specified, all repetitions are plotted, each in a separate window.

plot(nn, rep = "best") # For help, use ?plot.nn

We then make predictions based on the validation data:

pred = predict(nn, 
               newdata = validation, 
               rep = which.min(nn$result.matrix[1,]) # Predict with the repetition
                                                     # that has the lowest error
              )
head(pred) # Display propensity scores
##         [,1]      [,2]         [,3]
## 1  0.9703730 0.4362782 9.471741e-38
## 2  0.9703730 0.4362782 9.471741e-38
## 3  0.9704065 0.4362785 9.468487e-38
## 5  0.9703730 0.4362782 9.471741e-38
## 11 0.9703355 0.4362779 9.475381e-38
## 15 0.9704364 0.4362788 9.465578e-38
# Get predicted labels
predicted.class = apply(pred, 1, which.max) %>% as.numeric()

predicted.class # Note: the columns of pred follow the alphabetical order of the values of the response variable
##  [1] 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 3 2 2 2 2 2 2 2 2 3 2 2 2 2 2
## [39] 3 2 2 2 2 2 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
# Convert predicted.class to factor in order to create confusion matrix
predicted = factor(predicted.class, levels = 1:3, labels = c("setosa", "versicolor", "virginica"))

# Convert the actual species in the validation data to a factor in order to create the confusion matrix
actual = validation$Species %>% factor(levels = c("setosa", "versicolor", "virginica"))

caret::confusionMatrix(predicted, actual)
## Confusion Matrix and Statistics
## 
##             Reference
## Prediction   setosa versicolor virginica
##   setosa         20          0         0
##   versicolor      0         21         0
##   virginica       0          3        16
## 
## Overall Statistics
##                                           
##                Accuracy : 0.95            
##                  95% CI : (0.8608, 0.9896)
##     No Information Rate : 0.4             
##     P-Value [Acc > NIR] : < 2.2e-16       
##                                           
##                   Kappa : 0.9247          
##                                           
##  Mcnemar's Test P-Value : NA              
## 
## Statistics by Class:
## 
##                      Class: setosa Class: versicolor Class: virginica
## Sensitivity                 1.0000            0.8750           1.0000
## Specificity                 1.0000            1.0000           0.9318
## Pos Pred Value              1.0000            1.0000           0.8421
## Neg Pred Value              1.0000            0.9231           1.0000
## Prevalence                  0.3333            0.4000           0.2667
## Detection Rate              0.3333            0.3500           0.2667
## Detection Prevalence        0.3333            0.3500           0.3167
## Balanced Accuracy           1.0000            0.9375           0.9659

Method 2: Using all dummy variables for Species as joint responses.

nn2 <- neuralnet(
         (Species == "setosa") + (Species == "versicolor") + (Species == "virginica") ~ Petal.Length + Petal.Width,   
         data = train, 
         hidden = c(1), 
         linear.output = FALSE,
         rep = 10
       )

plot(nn2, rep = "best")

Next, we use the book data “accidentsnn.csv” to create a neural network. Each row of the data table corresponds to a US automobile accident classified by its level of severity as no injury, injury, or fatality. The purpose is to develop a system for quickly classifying the severity of an accident, based on initial reports and associated data in the system (some of which rely on GPS-assisted reporting). Such a system could be used to assign emergency response team priorities. Here is a description of the variables:

  • ALCHL_I: Presence (1) or absence (2) of alcohol

  • PROFIL_I_R: Profile of the roadway (level = 1, other = 0)

  • SUR_COND: Surface condition of the road (dry = 1, wet = 2, snow/slush = 3, ice = 4, unknown = 9)

  • VEH_INVL: Number of vehicles involved

  • MAX_SEV_IR: Presence of injuries/fatalities (no injury = 0, injury = 1, fatality = 2)

Data pre-processing:

accidents.df <- read.csv("/Users/home/Documents/Zhang/Stat415.515.615/DMBA-R-datasets/accidentsnn.csv")

# Dummify categorical variables
accidents.df$SUR_COND_1 = (accidents.df$SUR_COND==1)
accidents.df$SUR_COND_2 = (accidents.df$SUR_COND==2)
accidents.df$SUR_COND_3 = (accidents.df$SUR_COND==3)
accidents.df$SUR_COND_4 = (accidents.df$SUR_COND==4)
accidents.df$SUR_COND_9 = (accidents.df$SUR_COND==9)

accidents.df$MAX_SEV_IR_0 = (accidents.df$MAX_SEV_IR==0)
accidents.df$MAX_SEV_IR_1 = (accidents.df$MAX_SEV_IR==1)
accidents.df$MAX_SEV_IR_2 = (accidents.df$MAX_SEV_IR==2)

accidents.df$ALCHL_I_1 = (accidents.df$ALCHL_I==1)
accidents.df$ALCHL_I_2 = (accidents.df$ALCHL_I==2)
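The repeated assignments above can be generated with a loop. The helper `add.dummies` below is hypothetical; it adds one logical column per observed value of a variable, named <variable>_<value> as in the code above. It is shown on a small made-up data frame with the same codings; on the real data one would call, for example, accidents.df <- add.dummies(accidents.df, "SUR_COND").

```r
# Hypothetical helper: add one logical dummy column per observed value of `var`
add.dummies <- function(df, var) {
  for (v in sort(unique(df[[var]]))) {
    df[[paste0(var, "_", v)]] <- df[[var]] == v
  }
  df
}

# Toy data frame (assumed) using the same codings as the accidents data
toy <- data.frame(SUR_COND   = c(1, 2, 9, 4, 3, 1),
                  MAX_SEV_IR = c(0, 1, 2, 0, 1, 0),
                  ALCHL_I    = c(1, 2, 2, 2, 1, 2))
for (var in c("SUR_COND", "MAX_SEV_IR", "ALCHL_I")) toy <- add.dummies(toy, var)
names(toy)
```

Unlike the manual code, the loop only creates columns for values that actually occur, so a code missing from the data gets no column.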

Data partition:

set.seed(2)
n = nrow(accidents.df)
train.idx = sample(1:n, n*0.6)
train = accidents.df[train.idx, ]
validation=accidents.df[-train.idx, ]

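We will fit a neural network with one hidden layer of 2 nodes; the hidden argument takes a vector of integers giving the number of nodes in each hidden layer. The dummy variable “SUR_COND_9” is not used in the model, since it is completely determined by the other surface-condition dummies (SUR_COND_9 = 1 - SUR_COND_1 - SUR_COND_2 - SUR_COND_3 - SUR_COND_4). For the same reason, ALCHL_I_2 is left out and only ALCHL_I_1 is used.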
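The redundancy can be checked numerically. In this sketch the vector of surface-condition codes is made up (assumed), with all five codes present, and `qr()` reports the rank of the input matrix. Because every hidden node also has an intercept weight, including all five dummies alongside that intercept gives a linearly dependent set of inputs, just as in regression.

```r
# Toy surface-condition codes (assumed), same coding as SUR_COND
sc <- c(1, 2, 9, 4, 3, 1, 9, 2)
codes <- c(1, 2, 3, 4, 9)
D <- sapply(codes, function(v) as.numeric(sc == v))
colnames(D) <- paste0("SUR_COND_", codes)

# SUR_COND_9 is fully determined by the other four dummies
all(D[, "SUR_COND_9"] == 1 - rowSums(D[, 1:4]))

qr(cbind(1, D))$rank        # 5 for 6 columns: intercept + all 5 dummies are dependent
qr(cbind(1, D[, -5]))$rank  # 5 for 5 columns: dropping SUR_COND_9 removes the redundancy
```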

nn <- neuralnet(
  formula = MAX_SEV_IR_0 + MAX_SEV_IR_1 + MAX_SEV_IR_2 ~ 
                  ALCHL_I_1 + PROFIL_I_R + VEH_INVL + SUR_COND_1 + SUR_COND_2 
                + SUR_COND_3 + SUR_COND_4, data = train, hidden = c(2))

Now, we plot the network:

plot(nn, rep = "best")

Next, we create two confusion matrices, one for the training data and one for the validation data, and compare them. If the results are similar, there is no serious overfitting; if the results on the validation data are much worse, there is a serious overfitting issue.

Prediction based on the training data:

train.prediction <- predict(nn, train)
train.prediction.class <- apply(train.prediction,1,which.max) - 1
caret::confusionMatrix(factor(train.prediction.class), factor(train$MAX_SEV_IR))
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction   0   1   2
##          0 334   0  35
##          1   0 164  34
##          2   1   7  24
## 
## Overall Statistics
##                                          
##                Accuracy : 0.8715         
##                  95% CI : (0.842, 0.8972)
##     No Information Rate : 0.5593         
##     P-Value [Acc > NIR] : < 2.2e-16      
##                                          
##                   Kappa : 0.7675         
##                                          
##  Mcnemar's Test P-Value : NA             
## 
## Statistics by Class:
## 
##                      Class: 0 Class: 1 Class: 2
## Sensitivity            0.9970   0.9591  0.25806
## Specificity            0.8674   0.9206  0.98419
## Pos Pred Value         0.9051   0.8283  0.75000
## Neg Pred Value         0.9957   0.9825  0.87831
## Prevalence             0.5593   0.2855  0.15526
## Detection Rate         0.5576   0.2738  0.04007
## Detection Prevalence   0.6160   0.3306  0.05342
## Balanced Accuracy      0.9322   0.9398  0.62113

Prediction based on the validation data:

validation.prediction <- predict(nn, validation)
validation.prediction.class <- apply(validation.prediction, 1, which.max) - 1
caret::confusionMatrix(factor(validation.prediction.class), factor(validation$MAX_SEV_IR))
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction   0   1   2
##          0 216   0  20
##          1   0 123  23
##          2   0   5  13
## 
## Overall Statistics
##                                           
##                Accuracy : 0.88            
##                  95% CI : (0.8441, 0.9102)
##     No Information Rate : 0.54            
##     P-Value [Acc > NIR] : < 2.2e-16       
##                                           
##                   Kappa : 0.7851          
##                                           
##  Mcnemar's Test P-Value : NA              
## 
## Statistics by Class:
## 
##                      Class: 0 Class: 1 Class: 2
## Sensitivity            1.0000   0.9609   0.2321
## Specificity            0.8913   0.9154   0.9855
## Pos Pred Value         0.9153   0.8425   0.7222
## Neg Pred Value         1.0000   0.9803   0.8874
## Prevalence             0.5400   0.3200   0.1400
## Detection Rate         0.5400   0.3075   0.0325
## Detection Prevalence   0.5900   0.3650   0.0450
## Balanced Accuracy      0.9457   0.9382   0.6088
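To make the training-versus-validation comparison concrete, the sketch below recomputes overall accuracy and per-class sensitivity from the counts in the two caret outputs above. Rows are predictions and columns are the reference classes. The two accuracies are close (about 0.87 and 0.88), which suggests no serious overfitting. Sensitivity for class 2 is low in both sets (about 0.26 and 0.23). The helper names accuracy() and sensitivity() are our own.

```r
# Counts copied from the caret::confusionMatrix() outputs above (rows = prediction, cols = reference)
cm_train <- matrix(c(334,   0, 35,
                       0, 164, 34,
                       1,   7, 24), nrow = 3, byrow = TRUE)
cm_valid <- matrix(c(216,   0, 20,
                       0, 123, 23,
                       0,   5, 13), nrow = 3, byrow = TRUE)

accuracy    <- function(cm) sum(diag(cm)) / sum(cm)   # correctly classified / total
sensitivity <- function(cm) diag(cm) / colSums(cm)    # per class: correct / actual count

round(accuracy(cm_train), 4)      # 0.8715
round(accuracy(cm_valid), 4)      # 0.88
round(sensitivity(cm_train), 4)
round(sensitivity(cm_valid), 4)
```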

Alternatively, we can use a single response variable converted to a categorical variable with factor(). The rest of the code stays the same.

nn <- neuralnet(
  formula = factor(MAX_SEV_IR) ~ 
                  ALCHL_I_1 + PROFIL_I_R + VEH_INVL + SUR_COND_1 + SUR_COND_2 
                + SUR_COND_3 + SUR_COND_4, 
  data = train, 
  hidden = 2)

plot(nn, rep="best")
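Conceptually, a factor response corresponds to one 0/1 indicator column per level, like MAX_SEV_IR_0 + MAX_SEV_IR_1 + MAX_SEV_IR_2 in the first model. The iris example below also shows one output node per level of Species. The sketch uses illustrative severity codes, not the accident data. It builds such indicators with model.matrix() and shows how `which.max(...) - 1` maps a row of outputs back to a class code.

```r
# Illustrative severity codes (0, 1, 2), not the accident data
y <- c(0, 2, 1, 0, 1)

# One 0/1 indicator column per factor level
Y <- model.matrix(~ factor(y) - 1)
colnames(Y)   # "factor(y)0" "factor(y)1" "factor(y)2"

# Back to class codes: the column with the largest value, shifted from 1-based to 0-based
apply(Y, 1, which.max) - 1
```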

The two confusion matrices are:

train.prediction <- predict(nn, train)
train.prediction.class <- apply(train.prediction,1,which.max)-1
caret::confusionMatrix(factor(train.prediction.class), factor(train$MAX_SEV_IR))
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction   0   1   2
##          0 334   0  35
##          1   1 164  37
##          2   0   7  21
## 
## Overall Statistics
##                                           
##                Accuracy : 0.8664          
##                  95% CI : (0.8366, 0.8927)
##     No Information Rate : 0.5593          
##     P-Value [Acc > NIR] : < 2.2e-16       
##                                           
##                   Kappa : 0.758           
##                                           
##  Mcnemar's Test P-Value : 3.36e-12        
## 
## Statistics by Class:
## 
##                      Class: 0 Class: 1 Class: 2
## Sensitivity            0.9970   0.9591  0.22581
## Specificity            0.8674   0.9112  0.98617
## Pos Pred Value         0.9051   0.8119  0.75000
## Neg Pred Value         0.9957   0.9824  0.87391
## Prevalence             0.5593   0.2855  0.15526
## Detection Rate         0.5576   0.2738  0.03506
## Detection Prevalence   0.6160   0.3372  0.04674
## Balanced Accuracy      0.9322   0.9351  0.60599
validation.prediction <- predict(nn, validation)
validation.prediction.class <-apply(validation.prediction,1,which.max)-1
caret::confusionMatrix(factor(validation.prediction.class), factor(validation$MAX_SEV_IR))
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction   0   1   2
##          0 216   0  20
##          1   0 124  27
##          2   0   4   9
## 
## Overall Statistics
##                                           
##                Accuracy : 0.8725          
##                  95% CI : (0.8358, 0.9036)
##     No Information Rate : 0.54            
##     P-Value [Acc > NIR] : < 2.2e-16       
##                                           
##                   Kappa : 0.7707          
##                                           
##  Mcnemar's Test P-Value : NA              
## 
## Statistics by Class:
## 
##                      Class: 0 Class: 1 Class: 2
## Sensitivity            1.0000   0.9688   0.1607
## Specificity            0.8913   0.9007   0.9884
## Pos Pred Value         0.9153   0.8212   0.6923
## Neg Pred Value         1.0000   0.9839   0.8786
## Prevalence             0.5400   0.3200   0.1400
## Detection Rate         0.5400   0.3100   0.0225
## Detection Prevalence   0.5900   0.3775   0.0325
## Balanced Accuracy      0.9457   0.9347   0.5745

11.4 Preprocessing the Data

For numerical features (predictors) and quantitative responses, use z-scores (also called standardized scores), or rescale the values to the range [0, 1] with the transformation \(x' = \frac{x-\min}{\max-\min}\), where the minimum and maximum are computed from the training data. For each nominal categorical feature with \(m\) categories, use \(m-1\) dummy variables. For each ordinal categorical feature, map its categories to appropriate numerical values between 0 and 1.
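Below is a minimal base-R sketch of min-max scaling and z-scores on illustrative numbers. The scaling parameters (min, max, mean, sd) come from the training values only and are then applied to the validation values. This mirrors what caret::preProcess(method = "range") followed by predict() does in the examples below. The helper names minmax_fit(), minmax_apply(), and minmax_invert() are hypothetical. The inverse transformation is what converts normalized predictions of a quantitative response back to the original scale.

```r
# Illustrative numbers: scaling parameters are estimated from the training values only
train_x <- c(2, 4, 6, 10)
valid_x <- c(3, 12)

# Hypothetical helpers for min-max scaling: x' = (x - min) / (max - min)
minmax_fit    <- function(x) c(min = min(x), max = max(x))
minmax_apply  <- function(x, p) (x - p[["min"]]) / (p[["max"]] - p[["min"]])
minmax_invert <- function(z, p) z * (p[["max"]] - p[["min"]]) + p[["min"]]

p <- minmax_fit(train_x)             # min = 2, max = 10
train_z <- minmax_apply(train_x, p)  # 0.00 0.25 0.50 1.00
valid_z <- minmax_apply(valid_x, p)  # 0.125 1.250 (validation values may fall outside [0, 1])
back    <- minmax_invert(valid_z, p) # 3 12, the original scale

# z-scores, also using the training mean and standard deviation
zs <- (valid_x - mean(train_x)) / sd(train_x)
```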

The model.matrix() function (from the stats package that ships with base R) is very useful for creating dummy variables. We demonstrate it with the diamonds data from the ggplot2 package:

library(ggplot2)  # provides the diamonds data
head(diamonds)
## # A tibble: 6 × 10
##   carat cut       color clarity depth table price     x     y     z
##   <dbl> <ord>     <ord> <ord>   <dbl> <dbl> <int> <dbl> <dbl> <dbl>
## 1  0.23 Ideal     E     SI2      61.5    55   326  3.95  3.98  2.43
## 2  0.21 Premium   E     SI1      59.8    61   326  3.89  3.84  2.31
## 3  0.23 Good      E     VS1      56.9    65   327  4.05  4.07  2.31
## 4  0.29 Premium   I     VS2      62.4    58   334  4.2   4.23  2.63
## 5  0.31 Good      J     SI2      63.3    58   335  4.34  4.35  2.75
## 6  0.24 Very Good J     VVS2     62.8    57   336  3.94  3.96  2.48

We dummify all non-numeric variables:

mm = model.matrix(~., data = diamonds)

mm = mm[, -1] # This will remove the first column which is the constant "1" column.

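Note that model.matrix() does not create 0/1 dummy variables for the ordered factors cut, color, and clarity. By default, it codes each ordered factor with orthogonal polynomial contrasts instead, producing columns such as cut.L, cut.Q, cut.C, and cut^4. To get 0/1 dummy variables, convert the ordered factors to unordered factors first, for example with factor(cut, ordered = FALSE). To follow the advice above for ordinal features, map their categories to numeric values instead.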
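The sketch below uses a small hypothetical data frame. It shows how model.matrix() codes an ordered factor by default, and two alternatives: 0/1 dummies after dropping the ordering, and equally spaced scores in [0, 1] as suggested at the start of this section. The column names in the comments assume R's default contrasts setting, options(contrasts = c("contr.treatment", "contr.poly")).

```r
# Small hypothetical data frame with an ordered factor
df <- data.frame(
  quality = factor(c("Fair", "Good", "Ideal", "Good"),
                   levels = c("Fair", "Good", "Ideal"), ordered = TRUE),
  price = c(300, 450, 700, 500)
)

# Default coding of an ordered factor: polynomial contrasts
colnames(model.matrix(~ quality + price, data = df))
# "(Intercept)" "quality.L" "quality.Q" "price"

# Option 1: 0/1 dummies after dropping the ordering
df1 <- transform(df, quality = factor(quality, ordered = FALSE))
colnames(model.matrix(~ ., data = df1))
# "(Intercept)" "qualityGood" "qualityIdeal" "price"

# Option 2: equally spaced scores in [0, 1] that keep the order of the categories
df$quality_score <- (as.integer(df$quality) - 1) / (nlevels(df$quality) - 1)
df$quality_score
# 0.0 0.5 1.0 0.5
```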

Next, we fit a neural network to the iris data:

# Use preProcess() from the caret package with method = "range" to rescale the numeric variables to [0, 1].
# preProcess() creates a recipe for normalizing that is estimated from the training data; the same
# recipe is then used to normalize the numeric variables in the validation set.

set.seed(314)

n = nrow(iris)

train_idx <- sample(nrow(iris), n * 0.9)
train <- iris[train_idx, ]
validation <- iris[-train_idx, ]

# Normalize the training set to get a recipe
norm.values = caret::preProcess(train, method = "range")

# Normalize both training and validation data. 
train.norm <- predict(norm.values, train)
validation.norm <- predict(norm.values, validation)

# Train a neural network.
nn <- neuralnet(
  formula = Species ~ Petal.Length + Petal.Width, 
  data = train.norm, 
  hidden = c(2,1),  # Two hidden layers with 2 and 1 neurons, respectively
  rep = 5,  # train the network 5 times, each time from different random starting weights.
            # These are repetitions of the whole training, not epochs (passes over the data).
            # plot(nn, rep = "best") uses the repetition with the smallest error; to find it, use
            # which.min(nn$result.matrix["error", ]). Its weights are in nn$result.matrix[, k] for that k.
  linear.output = FALSE
)

nn$result.matrix # Which repetition has the smallest error? The 2nd (error = 13.38).
##                                   [,1]          [,2]          [,3]
## error                     1.416554e+01  1.337887e+01  1.414219e+01
## reached.threshold         9.663103e-03  9.813818e-03  9.998812e-03
## steps                     7.530000e+03  1.227200e+04  1.390300e+04
## Intercept.to.1layhid1     1.215201e+01 -9.000757e+00  8.955731e-01
## Petal.Length.to.1layhid1 -9.675856e+00  2.010636e+01 -5.453811e-01
## Petal.Width.to.1layhid1  -7.135273e+00 -1.195876e+01 -9.995267e-01
## Intercept.to.1layhid2    -5.856720e+00 -5.784187e+00 -1.177202e+01
## Petal.Length.to.1layhid2  2.346568e+01 -2.685730e+00  8.339842e+00
## Petal.Width.to.1layhid2   1.552286e+01  7.619169e+00  5.771140e+00
## Intercept.to.2layhid1    -2.797774e-01 -5.599576e+00 -1.263752e+00
## 1layhid1.to.2layhid1     -1.093552e+01  2.206730e+01 -6.106965e+00
## 1layhid2.to.2layhid1      6.428924e+00  4.353128e+01  2.829484e+01
## Intercept.to.setosa       3.687887e+00  9.041792e+00  9.683349e+00
## 2layhid1.to.setosa       -7.511737e+02 -1.228262e+03 -1.390764e+03
## Intercept.to.versicolor  -1.786082e-01 -1.432967e-01 -1.655331e-01
## 2layhid1.to.versicolor   -3.317453e+00 -2.846573e+00 -3.242086e+00
## Intercept.to.virginica   -3.320662e+00 -4.619196e+01 -3.433703e+00
## 2layhid1.to.virginica     2.338577e+01  2.688089e+02  2.359626e+01
##                                   [,4]          [,5]
## error                     1.415099e+01  1.418315e+01
## reached.threshold         9.826511e-03  9.914601e-03
## steps                     1.261000e+04  7.098000e+03
## Intercept.to.1layhid1    -4.602304e-01  1.168683e+01
## Petal.Length.to.1layhid1  5.938058e-01 -9.228992e+00
## Petal.Width.to.1layhid1   1.042218e+00 -6.816663e+00
## Intercept.to.1layhid2    -1.196228e+01  3.116412e+01
## Petal.Length.to.1layhid2  8.475290e+00 -1.495888e+02
## Petal.Width.to.1layhid2   5.932438e+00 -9.271434e+01
## Intercept.to.2layhid1    -7.690033e+00  6.542596e+00
## 1layhid1.to.2layhid1      5.526161e+00 -1.129602e+01
## 1layhid2.to.2layhid1      2.796603e+01 -1.991543e+02
## Intercept.to.setosa       9.439256e+00  3.645002e+00
## 2layhid1.to.setosa       -1.260703e+03 -7.098994e+02
## Intercept.to.versicolor  -1.648692e-01 -1.794058e-01
## 2layhid1.to.versicolor   -3.194180e+00 -3.263397e+00
## Intercept.to.virginica   -3.445343e+00 -3.342743e+00
## 2layhid1.to.virginica     2.330633e+01  2.306248e+01
plot(nn, rep = "best") # For help, use ?plot.nn

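# Predictions from repetition 4. (Repetition 2 has the smallest training error;
# rep = which.min(nn$result.matrix["error", ]) would select it.)
pred.valid = predict(nn, newdata = validation.norm, rep = 4) # For documentation, type ?predict.nn in the console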
pred.valid
##             [,1]       [,2]       [,3]
## 9   9.831210e-01 0.45549645 0.03402693
## 23  9.872371e-01 0.45567474 0.03385496
## 26  9.804653e-01 0.45540293 0.03411748
## 48  9.831210e-01 0.45549645 0.03402693
## 62  7.339478e-09 0.44121309 0.05095158
## 69  2.812182e-13 0.43487052 0.06084905
## 77  8.769188e-16 0.43128093 0.06723856
## 88  7.090752e-07 0.44407012 0.04701729
## 93  6.949949e-04 0.44838282 0.04162949
## 96  1.469980e-04 0.44740917 0.04279090
## 98  3.435879e-06 0.44505740 0.04572727
## 101 0.000000e+00 0.03360009 1.00000000
## 112 0.000000e+00 0.05375363 0.99999991
## 126 0.000000e+00 0.03385530 1.00000000
## 129 0.000000e+00 0.03368239 1.00000000
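The sketch below first copies the five "error" values from the result matrix above and picks the repetition with the smallest training error. It then shows one way to turn network outputs into class labels, using three rows of pred.valid above (rows 9, 62, and 101, rounded). The columns follow the level order of Species (setosa, versicolor, virginica). The outputs are not normalized to sum to 1, so we simply take the largest value in each row.

```r
# Training errors of the 5 repetitions, copied from nn$result.matrix["error", ] above
errors <- c(14.16554, 13.37887, 14.14219, 14.15099, 14.18315)
best_rep <- which.min(errors)   # 2; with the fitted object: which.min(nn$result.matrix["error", ])

# Three rows of pred.valid above (observations 9, 62, 101), rounded
probs <- rbind(c(0.98, 0.46, 0.03),
               c(0.00, 0.44, 0.05),
               c(0.00, 0.03, 1.00))
species_levels <- c("setosa", "versicolor", "virginica")   # levels(train$Species)

# Predicted class = output node with the largest value
pred_class <- species_levels[apply(probs, 1, which.max)]
pred_class   # "setosa" "versicolor" "virginica"

# With the real objects, a confusion table would be:
# table(predicted = species_levels[apply(pred.valid, 1, which.max)], actual = validation$Species)
```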

When numerical features are highly skewed, as is common in business data, it is advisable to apply a log transformation before normalizing.
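Here is a small illustration with hypothetical right-skewed values, such as customer spending. With min-max scaling of the raw values, one large value stretches the range, so six of the seven scaled values fall below 0.2. Applying log1p(), that is log(1 + x), which also handles zeros, before scaling spreads the values more evenly, leaving only two below 0.2.

```r
spend <- c(12, 30, 45, 80, 150, 400, 2500)             # hypothetical right-skewed values
to01  <- function(x) (x - min(x)) / (max(x) - min(x))  # min-max scaling to [0, 1]

raw01 <- to01(spend)          # the single large value squeezes the rest toward 0
log01 <- to01(log1p(spend))   # log(1 + x) first, then scale

round(raw01, 3)
round(log01, 3)
```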

11.5 Neural Networks for Regression

When using neural networks for regression, the range of the output activation function should match the range of the quantitative response variable. Alternatively, use no output activation function by setting linear.output = TRUE. As before, pre-process the variables with caret::preProcess(), estimating the transformation from the training data only, so that each numeric variable ranges from 0 to 1, ranges from -1 to 1, or is centered around 0.

When predicting a quantitative response variable, convert the predicted values back to the original scale.

The following code uses a neural network to predict Sepal.Length from Petal.Length, Petal.Width, and Species in the iris data. Since neuralnet() accepts only numeric predictors, Species is converted into three dummy variables, of which only two are used.

iris1 = iris  # Create a copy for iris data

# Create 3 dummy variables to replace the Species variable. Only two of these will be used.
iris1$Species_setosa = (iris1$Species=="setosa")*1
iris1$Species_versicolor = (iris1$Species=="versicolor")*1
iris1$Species_virginica = (iris1$Species=="virginica")*1

head(iris1) # Dummy variables are successfully created
##   Sepal.Length Sepal.Width Petal.Length Petal.Width Species Species_setosa
## 1          5.1         3.5          1.4         0.2  setosa              1
## 2          4.9         3.0          1.4         0.2  setosa              1
## 3          4.7         3.2          1.3         0.2  setosa              1
## 4          4.6         3.1          1.5         0.2  setosa              1
## 5          5.0         3.6          1.4         0.2  setosa              1
## 6          5.4         3.9          1.7         0.4  setosa              1
##   Species_versicolor Species_virginica
## 1                  0                 0
## 2                  0                 0
## 3                  0                 0
## 4                  0                 0
## 5                  0                 0
## 6                  0                 0
# Partition data into training and validation sets
set.seed(123)
n = nrow(iris1)
train_idx <- sample(nrow(iris1), n * 0.9)
train <- iris1[train_idx, ]
valid <- iris1[-train_idx, ]

# Normalization of numerical data using preProcess() from the caret package.
# This function creates a recipe used for normalizing ALL numerical data (including response) 
# in both training and validation sets
norm.values = caret::preProcess(train, method = "range")  # Normalize to [0,1]

# Note the function predict() does the normalizing job. 
train.norm <- predict(norm.values, train)
valid.norm <- predict(norm.values, valid)

# Fit a neural network
nn <- neuralnet(
  formula = Sepal.Length ~ Petal.Length + Petal.Width + Species_setosa + Species_versicolor, # Use only 2 dummies
  data = train.norm, 
  hidden = c(1, 2),  # We use 2 hidden layers with layer 1 having 1 node and layer 2 having 2 nodes
  linear.output = TRUE # For regression, set "linear.output" to TRUE; set it to FALSE for classification
)

plot(nn, rep = "best") 

# Predictions
pred.train = predict(nn, newdata = train.norm) # based on normalized training data
head(pred.train)
##          [,1]
## 14  0.1537110
## 50  0.1866937
## 118 0.9284185
## 43  0.1750122
## 150 0.5091412
## 148 0.5350192
pred.valid = predict(nn, newdata = valid.norm) # based on normalized validation data
head(pred.valid) 
##         [,1]
## 1  0.1866937
## 18 0.1865793
## 28 0.1990731
## 33 0.1991943
## 48 0.1866937
## 55 0.5237451
# Convert predicted values to the original scale, since the response variable has been normalized.
Min = min(train$Sepal.Length) # Keep in mind that we normalized data 
                              # based on the minimums and maximums of the train data
Max = max(train$Sepal.Length)

pred.ori.valid = as.numeric(pred.valid)*(Max-Min) + Min

head(pred.ori.valid)
## [1] 4.972097 4.971686 5.016663 5.017100 4.972097 6.185482
plot(x = valid$Sepal.Length, y = pred.ori.valid)

# Display accuracy measures
forecast::accuracy(pred.ori.valid, valid$Sepal.Length)
##                 ME      RMSE       MAE      MPE     MAPE
## Test set 0.1508507 0.2667496 0.2400849 2.207179 4.094318
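The accuracy measures above can be computed by hand. The sketch below assumes forecast::accuracy() defines the error as actual minus predicted, which is our understanding of its convention. Under that convention, the positive ME above (about 0.15) means the network under-predicts Sepal.Length on average in the validation set. The numbers in the sketch are illustrative, not the iris results.

```r
# Illustrative actual and predicted values (not the iris results)
actual    <- c(5.0, 6.0, 7.0, 5.5)
predicted <- c(4.8, 6.3, 6.6, 5.5)

e    <- actual - predicted          # errors: actual minus predicted
ME   <- mean(e)                     # mean error (bias)
RMSE <- sqrt(mean(e^2))             # root mean squared error
MAE  <- mean(abs(e))                # mean absolute error
MPE  <- mean(100 * e / actual)      # mean percentage error
MAPE <- mean(100 * abs(e / actual)) # mean absolute percentage error

round(c(ME = ME, RMSE = RMSE, MAE = MAE, MPE = MPE, MAPE = MAPE), 4)
```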

To set a particular category of a nominal variable as the reference category, use the relevel() function. It does not work for ordered factors (ordinal variables):

D= iris
D$Species=relevel(D$Species, ref = "versicolor")

# or, equivalently, reorder the levels so that "versicolor" comes first (the reference category)
D$Species = factor(D$Species, levels = c("versicolor", "setosa", "virginica"))
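The reference category determines which dummy model.matrix() leaves out. The sketch below uses the built-in iris data. It shows the default coding (setosa as reference), then the coding after relevel(), and how to keep only the \(m-1 = 2\) dummy columns. This offers an alternative to creating the Species dummies by hand.

```r
D <- iris
colnames(model.matrix(~ Species, data = D))
# "(Intercept)" "Speciesversicolor" "Speciesvirginica"   (setosa is the reference)

D$Species <- relevel(D$Species, ref = "versicolor")
colnames(model.matrix(~ Species, data = D))
# "(Intercept)" "Speciessetosa" "Speciesvirginica"       (versicolor is now the reference)

# Drop the intercept column to keep only the 2 dummy variables
dummies <- model.matrix(~ Species, data = D)[, -1]
head(dummies)
```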

11.6 Advantages and Weaknesses of Neural Networks

Neural networks have high tolerance to noisy data and can capture highly complicated relationships between the predictors and an outcome variable.

Neural networks do not have a built-in variable selection mechanism. The predictors therefore often need to be pre-selected or reduced with other methods, such as decision trees (variable selection) or PCA (dimension reduction).

Training may converge to weights at a local optimum, which then do not provide the best fit to the training data.

Neural networks are relatively expensive in computation time.

Neural networks are less interpretable than other supervised learning methods such as decision trees. Thus, the neural network is called a black-box model while the decision tree is called a white-box model.

Uplift Modeling (Chapter 13)

Ensembles and uplift modeling are two useful approaches for improving predictive power. An ensemble combines multiple supervised models (classification or prediction) into a “super-model.” Ensemble methods include bagging and boosting.

In uplift modeling, we combine supervised modeling with A-B testing, a simple randomized experiment with two groups. The idea of A-B testing is to test one treatment (or intervention) against another, or a treatment against a control. In a medical trial, a treatment can be a drug, device, or therapy; in marketing, it can be an offer to a consumer. An important element of A-B testing is random allocation: individuals are randomly assigned to treatments. The traditional two-proportion \(z\)-test or two-sample \(t\)-test, based on data from a randomized experiment, tells you which treatment does better on average, but says nothing about which treatment does better for which individual.

Uplift modeling aims at predicting the causal effect of an action such as a medical treatment or a marketing campaign on each individual. It allows us to predict how much positive impact the treatment or campaign has on each subject (such as consumer, voter, and patient), holding other variables constant.

13.2 Uplift (Persuasion) Modeling

Uplift modeling is used mainly in marketing. Direct marketing affords the marketer the ability to invite and monitor direct responses from consumers. With predictive modeling, consumers that are most likely to respond can be targeted.

Here are some introductions to uplift modeling: https://zyabkina.com/challenges-of-uplift-modeling-in-marketing/, https://ambiata.com/blog/2020-07-07-uplift-modeling/#:~:text=A%20simple%20way%20to%20evaluate,uplift%20value%20for%20each%20customer, and https://towardsdatascience.com/why-every-marketer-should-consider-uplift-modeling-1090235572ec (with some technical details).

Uplift modeling has been used in political campaigns to determine whether to send someone a persuasion message, or just leave them alone. The following is an example.

A campaign director conducts a survey of 10000 voters to determine their inclination to vote Democratic. First, the 10000 voters are split into two groups of 5000 each. A message promoting Smith is mailed to each individual in the first group (the treatment group, indicated by 1). No message is mailed to individuals in the second group (the control group, indicated by 0). The goal is to measure the change in opinion after the message is sent out, relative to the no-message control group.

A post-message survey of the same 10000 voters is then conducted to measure whether each voter’s opinion of Smith has shifted in a positive direction. The variable MOVED_AD indicates whether opinion has moved in a Democratic direction (Y) or not (N). MOVED_A is the same indicator coded as 1 or 0.

Here is a link to the data dictionary: https://drive.google.com/file/d/1_W9_jj45FH2PTVDey8BmMC-NKatNifU5/view?pli=1

The variables that will be used from the data are:

  • AGE: Voter age in years

  • NH_WHITE: Neighborhood average of % non-Hispanic white in household

  • COMM_PT: Neighborhood % of workers who take public transit

  • H_F1: Single female household (1 = yes)

  • REG_DAYS: Days since voter registered at current address

  • PR_PELIG: Voted in what % of non-presidential primaries

  • E_PELIG: Voted in what % of any primaries

  • POLITICALC: Is there a political contributor in the home (1 = yes)

Here is a glimpse of the data:

voter.df <- read.csv("/Users/Home/Documents/Zhang/Stat415.515.615/DMBA-R-datasets/Voter-Persuasion.csv")
dim(voter.df) # 10000 rows and 79 columns
## [1] 10000    79
names(voter.df)
##  [1] "VOTER_ID"      "SET_NO"        "OPP_SEX"       "AGE"          
##  [5] "HH_ND"         "HH_NR"         "HH_NI"         "MED_AGE"      
##  [9] "NH_WHITE"      "NH_AA"         "NH_ASIAN"      "NH_MULT"      
## [13] "HISP"          "COMM_LT10"     "COMM_609P"     "MED_HH_INC"   
## [17] "COMM_CAR"      "COMM_CP"       "COMM_PT"       "COMM_WALK"    
## [21] "KIDS"          "M_MAR"         "F_MAR"         "ED_4COL"      
## [25] "GENDER_F"      "GENDER_M"      "H_AFDLN3P"     "H_F1"         
## [29] "H_M1"          "H_MFDLN3P"     "PARTY_D"       "PARTY_I"      
## [33] "PARTY_R"       "VPP_08"        "VPP_12"        "VPR_08"       
## [37] "VPR_10"        "VPR_12"        "VG_04"         "VG_06"        
## [41] "VG_08"         "VG_10"         "VG_12"         "PP_PELIG"     
## [45] "PR_PELIG"      "AP_PELIG"      "G_PELIG"       "E_PELIG"      
## [49] "NL5G"          "NL3PR"         "NL5AP"         "NL2PP"        
## [53] "REG_DAYS"      "UPSCALEBUY"    "UPSCALEMAL"    "UPSCALEFEM"   
## [57] "BOOKBUYERI"    "FAMILYMAGA"    "FEMALEORIE"    "RELIGIOUSM"   
## [61] "GARDENINGM"    "CULINARYIN"    "HEALTHFITN"    "DOITYOURSE"   
## [65] "FINANCIALM"    "RELIGIOUSC"    "POLITICALC"    "MEDIANEDUC"   
## [69] "CAND1S"        "CAND2S"        "MESSAGE_A"     "MESSAGE_A_REV"
## [73] "I3"            "CAND1_UND"     "CAND2_UND"     "MOVED_AD"     
## [77] "MOVED_A"       "opposite"      "Partition"
head(voter.df[,1:5])
##   VOTER_ID SET_NO OPP_SEX AGE HH_ND
## 1   193801      2       0  28     1
## 2   627701      1       0  53     2
## 3   306924      2       0  68     2
## 4   547609      1       0  66     0
## 5   141105      3       0  23     0
## 6   334787      1       0  49     2
# Apply the mean function to find the proportion of voters who 
# moved in a Democratic direction in each of the two groups
aggregate(MOVED_A~MESSAGE_A, voter.df, mean) 
##   MESSAGE_A MOVED_A
## 1         0  0.3444
## 2         1  0.4024

Of the 5000 voters who received a message promoting Smith, 2012 (40.2%) moved in a Democratic direction and 2988 did not. Among voters who did not receive the message, the proportion was 34.4%. Overall, the lift from the message is \(40.2\% - 34.4\% = 5.8\) percentage points.
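The overall lift can be checked with the two-proportion test mentioned earlier. The sketch below takes the counts implied by the aggregate() output and the 5000-per-group design. prop.test() carries out the two-proportion \(z\)-test in its equivalent chi-square form, with Yates' continuity correction by default. As noted above, this whole-sample comparison says nothing about which individual voters are persuaded.

```r
# Counts implied by the aggregate() output, with 5000 voters per group
n_trt <- 5000; n_ctl <- 5000
moved_trt <- round(0.4024 * n_trt)   # 2012 voters moved in the treatment group
moved_ctl <- round(0.3444 * n_ctl)   # 1722 voters moved in the control group

lift <- moved_trt / n_trt - moved_ctl / n_ctl   # 0.058, i.e. 5.8 percentage points

# Two-proportion test of whether the lift differs from 0
test <- prop.test(x = c(moved_trt, moved_ctl), n = c(n_trt, n_ctl))
test$p.value
```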

How do we compute the uplift for each individual after running an A-B test? The steps are:

  • Fit a predictive model (multiple regression/logistic regression/tree/k-NN/nn, etc.) based on the training data with a treatment variable and other features as predictors.

  • Compute the predicted values based on the validation data.

  • Flip the value of the treatment variable in the validation data (1 to 0 and 0 to 1) and recompute the predicted values for the validation data.

  • Estimate the uplift for each individual by subtracting the predictions (treatment minus control).

  • Compute the Qini coefficient and plot the incremental gains from the fitted model.

  • Score new individuals, if any. To decide whether the treatment should be offered to a new individual, set the treatment variable to 1 for that individual and predict the outcome variable (or the propensity, for classification). Then set it to 0 and predict again. The uplift estimate is the difference between the two predictions (1 minus 0).
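Below is a minimal sketch of the first four steps on simulated A-B test data. The data-generating model, the variable names, and the choice of logistic regression with a treatment-by-covariate interaction are all assumptions for illustration. Instead of flipping the treatment, we predict every validation record twice, once with trt = 1 and once with trt = 0, which yields the same two sets of predictions. The interaction term x:trt lets the estimated treatment effect vary with x.

```r
set.seed(1)
n <- 2000
# Simulated A-B test (hypothetical): covariate x, randomly assigned treatment trt
sim <- data.frame(x = runif(n), trt = rbinom(n, 1, 0.5))
# Assumed true model: the treatment helps more when x is large
p_true <- plogis(-1 + 0.5 * sim$x + sim$trt * (-0.5 + 2 * sim$x))
sim$y <- rbinom(n, 1, p_true)

train <- sim[1:1500, ]
valid <- sim[1501:n, ]

# Step 1: one model with the treatment (and its interaction with x) as predictors
fit <- glm(y ~ x * trt, data = train, family = binomial)

# Steps 2-3: predict each validation record with trt set to 1 and to 0
p_treat   <- predict(fit, newdata = transform(valid, trt = 1), type = "response")
p_control <- predict(fit, newdata = transform(valid, trt = 0), type = "response")

# Step 4: estimated uplift = treatment minus control
uplift <- p_treat - p_control
summary(uplift)
cor(valid$x, uplift)   # positive: larger x, larger estimated uplift
```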

The above procedure involves one model based on training data, two sets of predictions based on validation data, and one plot for the incremental gains from the fitted model against no model (random assignment) to evaluate the performance of the uplift model.

The uplift model can be fitted with the R package “uplift”, which has been archived on CRAN but can be installed with the following code:

install.packages(c("devtools", "coin", "RItools", "penalized"))
devtools::install_version("uplift", version = "0.3.5", repos = "http://cran.us.r-project.org")

There are two functions for fitting an uplift model: upliftRF() and upliftKNN(). The former applies the random forest ensemble method to uplift classification trees. When splitting a node, one possible criterion is the absolute difference in uplift (average treatment effect) between the two child nodes; the code below instead uses split_method = "KL", a Kullback-Leibler divergence criterion. The latter, upliftKNN(), is based on the k-nearest neighbors method. A nice reference is https://proceedings.mlr.press/v67/gutierrez17a/gutierrez17a.pdf.
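The sketch below implements the splitting criterion described above: the absolute difference in uplift between the two children of a candidate split. It runs on hypothetical simulated data in which the treatment works only for x > 0.5. It is not the upliftRF() implementation, and it differs from the "KL" criterion used in the code that follows. It only illustrates why a split at the true boundary scores highest.

```r
# Uplift in a node: response rate among treated minus response rate among controls
uplift_in <- function(y, trt) mean(y[trt == 1]) - mean(y[trt == 0])

# Criterion described above: |uplift(left child) - uplift(right child)|
split_criterion <- function(y, trt, left) {
  abs(uplift_in(y[left], trt[left]) - uplift_in(y[!left], trt[!left]))
}

# Hypothetical data: the treatment raises the response rate from 0.3 to 0.6 only when x > 0.5
set.seed(42)
n <- 20000
x <- runif(n)
trt <- rbinom(n, 1, 0.5)
y <- rbinom(n, 1, ifelse(x > 0.5 & trt == 1, 0.6, 0.3))

# Evaluate candidate cut points on x and keep the best one
cuts <- seq(0.1, 0.9, by = 0.1)
crit <- sapply(cuts, function(cp) split_criterion(y, trt, x <= cp))
round(crit, 3)
cuts[which.max(crit)]   # expected to be at or near 0.5
```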

Since the data have been partitioned (look at the “Partition” column) into Training and Validation parts, we can get the training and validation data by the following code:

train.df = subset(voter.df, Partition == "T")
valid.df = subset(voter.df, Partition == "V")

Now, we train the model:

library(uplift)  # provides upliftRF(), upliftKNN(), performance(), and qini()
# Use upliftRF() to fit a random forest (alternatively, use upliftKNN() to apply kNN).
up.fit <- upliftRF(formula = MOVED_A ~ AGE + NH_WHITE + COMM_PT + H_F1 + REG_DAYS + PR_PELIG + E_PELIG + POLITICALC + trt(MESSAGE_A),   ## Note the use of trt() function 
                   data = train.df,
                   split_method = "KL") 

Next, we make predictions. In the result of prediction, the first column is \(p(y = 1 | x, treatment)\), and the second column is \(p(y = 1 | x, control)\).

pred <- predict(up.fit, newdata = valid.df) # See ?predict.upliftRF for help
head(pred, n= 20)
##       pr.y1_ct1 pr.y1_ct0
##  [1,]  0.401797  0.273982
##  [2,]  0.455316  0.302958
##  [3,]  0.367315  0.299466
##  [4,]  0.476186  0.659259
##  [5,]  0.281830  0.303141
##  [6,]  0.374843  0.321925
##  [7,]  0.437247  0.262783
##  [8,]  0.332277  0.365575
##  [9,]  0.430332  0.368656
## [10,]  0.402263  0.308286
## [11,]  0.401680  0.323465
## [12,]  0.298480  0.203154
## [13,]  0.402365  0.281436
## [14,]  0.437345  0.294346
## [15,]  0.407067  0.328698
## [16,]  0.335208  0.356568
## [17,]  0.562418  0.452207
## [18,]  0.407563  0.440729
## [19,]  0.419304  0.265948
## [20,]  0.463325  0.422585

We assess performance by sorting the validation data in descending order of predicted uplift and then splitting it into groups of equal size:

perf <- performance(pr.y1_ct1 = pred[, 1], # the predicted probability Prob(y=1|treated, x)
                    pr.y1_ct0 = pred[, 2], # the predicted probability Prob(y=1|control, x).
                    y = valid.df$MOVED_A, # Actual observed response
                    ct = valid.df$MESSAGE_A, # Control = 0; treatment = 1
                    groups = 10  # Groups can be 5, 10, or 20 with 10 the default
                   )

perf
##       group n.ct1 n.ct0 n.y1_ct1 n.y1_ct0 r.y1_ct1 r.y1_ct0    uplift
##  [1,]     1   196   210       79       72 0.403061 0.342857  0.060204
##  [2,]     2   197   208       71       65 0.360406 0.312500  0.047906
##  [3,]     3   207   198       86       63 0.415459 0.318182  0.097277
##  [4,]     4   203   202       74       71 0.364532 0.351485  0.013047
##  [5,]     5   183   222       64       76 0.349727 0.342342  0.007384
##  [6,]     6   207   198       86       76 0.415459 0.383838  0.031621
##  [7,]     7   213   192       67       67 0.314554 0.348958 -0.034404
##  [8,]     8   189   216       81       68 0.428571 0.314815  0.113757
##  [9,]     9   186   219       81       82 0.435484 0.374429  0.061055
## [10,]    10   217   189       84       68 0.387097 0.359788  0.027308
## attr(,"class")
## [1] "performance"
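Below is a rough sketch of the grouping behind the table above, on simulated data where all inputs are hypothetical. Records are sorted by predicted uplift and cut into groups of equal size. Within each group, we compute the response rate of treated records (r.y1_ct1) and of control records (r.y1_ct0); their difference is the uplift column. The exact rules of uplift::performance(), such as how ties are handled, may differ.

```r
# Hypothetical inputs: predicted uplift, 0/1 treatment indicator, and 0/1 outcome
set.seed(7)
n <- 1000
pred_uplift <- runif(n, -0.1, 0.3)
ct <- rbinom(n, 1, 0.5)
y  <- rbinom(n, 1, 0.35 + ct * pmax(pred_uplift, 0))

groups <- 10
# Sort by predicted uplift (descending) and assign equal-sized groups 1..10
ord <- order(pred_uplift, decreasing = TRUE)
grp <- integer(n)
grp[ord] <- ceiling(seq_len(n) * groups / n)

# Per group: response rate among treated, among controls, and their difference
r_trt <- tapply(y[ct == 1], grp[ct == 1], mean)
r_ctl <- tapply(y[ct == 0], grp[ct == 0], mean)
perf_sketch <- data.frame(group = 1:groups, r.y1_ct1 = r_trt, r.y1_ct0 = r_ctl,
                          uplift = r_trt - r_ctl)
perf_sketch
```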

Next, we compute the Qini coefficient, a natural generalization of the Gini coefficient. It is the area between the incremental gains curve of the fitted model and the diagonal line corresponding to random targeting. The larger the Qini coefficient, the better the model performs.

Q <- qini(perf) # Plot cumulative incremental gains against proportion of population targeted.
Q # The area between two curves
## $Qini
## [1] 4.629508e-05
## 
## $inc.gains
##  [1] 0.004485985 0.008375951 0.020747135 0.023217473 0.018248531 0.024290600
##  [7] 0.025204854 0.032639260 0.033257698 0.042193605
## 
## $random.inc.gains
##  [1] 0.004219361 0.008438721 0.012658082 0.016877442 0.021096803 0.025316163
##  [7] 0.029535524 0.033754884 0.037974245 0.042193605
title("Qini Plot")
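The printed Qini value can be reproduced from the two gains vectors. Compute the area under each curve with the trapezoidal rule at the group points 1/10, 2/10, ..., 1, then subtract. That this matches qini()'s internal calculation is our inference from the matching number, not something taken from the package documentation. The tiny value indicates that, taken over all groups, the model's incremental gains curve is barely above the random-targeting line.

```r
# Incremental gains printed by qini(perf) above (10 groups)
inc_gains <- c(0.004485985, 0.008375951, 0.020747135, 0.023217473, 0.018248531,
               0.024290600, 0.025204854, 0.032639260, 0.033257698, 0.042193605)
random_inc_gains <- c(0.004219361, 0.008438721, 0.012658082, 0.016877442, 0.021096803,
                      0.025316163, 0.029535524, 0.033754884, 0.037974245, 0.042193605)

# Trapezoidal area under a curve given at points x with heights y
trapezoid <- function(x, y) sum(diff(x) * (head(y, -1) + tail(y, -1)) / 2)

groups <- length(inc_gains)
x <- seq(1 / groups, 1, by = 1 / groups)   # proportion of population targeted
qini_coef <- trapezoid(x, inc_gains) - trapezoid(x, random_inc_gains)
qini_coef   # about 4.63e-05, matching $Qini above
```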

To find the voters who should receive the treatment, we first combine voter IDs and uplift scores and then sort the voters by uplift in descending order (hence the “-”):

pred.augmented = data.frame(voter_ID = valid.df$VOTER_ID, 
                            MOVED_AD = valid.df$MOVED_AD,
                            pred, 
                            uplift = pred[ ,1] - pred[ ,2])


library(dplyr)
pred.sorted = arrange(pred.augmented, -uplift) 

head(pred.sorted, n = 50)
##    voter_ID MOVED_AD pr.y1_ct1 pr.y1_ct0   uplift
## 1    294606        Y  0.608362  0.237335 0.371027
## 2    294866        Y  0.591603  0.256524 0.335079
## 3    322599        N  0.576830  0.244361 0.332469
## 4    234208        N  0.645318  0.320476 0.324842
## 5    602686        N  0.639224  0.331289 0.307935
## 6    611492        N  0.622661  0.315395 0.307266
## 7    519653        Y  0.634286  0.338682 0.295604
## 8    574963        N  0.643551  0.350528 0.293023
## 9    527155        Y  0.616125  0.324182 0.291943
## 10   251673        Y  0.611066  0.321017 0.290049
## 11   553959        Y  0.482379  0.196421 0.285958
## 12   161243        N  0.635281  0.350136 0.285145
## 13   600874        N  0.633697  0.348978 0.284719
## 14    79121        N  0.464667  0.180677 0.283990
## 15   346055        N  0.582102  0.298588 0.283514
## 16    93662        Y  0.441987  0.164477 0.277510
## 17    23891        N  0.571802  0.294669 0.277133
## 18    21112        N  0.439726  0.164105 0.275621
## 19    97704        N  0.580691  0.305461 0.275230
## 20   173701        Y  0.439113  0.165567 0.273546
## 21   304977        N  0.547195  0.273705 0.273490
## 22   281886        Y  0.604442  0.331112 0.273330
## 23   319485        N  0.607264  0.339232 0.268032
## 24   554732        N  0.596058  0.328106 0.267952
## 25   176893        Y  0.530974  0.263906 0.267068
## 26   627087        N  0.439404  0.173929 0.265475
## 27   442617        N  0.497365  0.232529 0.264836
## 28   170289        Y  0.609934  0.345911 0.264023
## 29   243943        N  0.562229  0.298388 0.263841
## 30   306920        N  0.480737  0.218575 0.262162
## 31   206072        Y  0.526988  0.264856 0.262132
## 32   381086        N  0.472326  0.210249 0.262077
## 33    37965        N  0.424119  0.164545 0.259574
## 34   543619        N  0.452084  0.193140 0.258944
## 35   103046        N  0.468586  0.210604 0.257982
## 36   172561        Y  0.523392  0.265464 0.257928
## 37   615831        N  0.470630  0.214499 0.256131
## 38   581693        Y  0.507044  0.251743 0.255301
## 39   189406        N  0.408987  0.153941 0.255046
## 40    59202        N  0.482670  0.228134 0.254536
## 41   458334        N  0.501448  0.247512 0.253936
## 42   465717        N  0.430889  0.178603 0.252286
## 43   482606        N  0.411770  0.160451 0.251319
## 44   433115        N  0.468559  0.217262 0.251297
## 45   530395        Y  0.602411  0.351468 0.250943
## 46   113609        N  0.537674  0.287280 0.250394
## 47    48844        N  0.448278  0.197974 0.250304
## 48   247004        Y  0.535612  0.285976 0.249636
## 49   594801        N  0.432155  0.183210 0.248945
## 50    51510        Y  0.706282  0.458327 0.247955

How can we use the upliftKNN() function? See pages 139-141 of the following reference: http://diposit.ub.edu/dspace/bitstream/2445/65123/1/Leo%20Guelman_PhD_THESIS.pdf

upliftKNN() works directly with the data used for the distance calculations, so we first keep only the variables we need:

vars = c("VOTER_ID","MOVED_AD","MOVED_A", "AGE", "NH_WHITE", "COMM_PT", "H_F1", "REG_DAYS", "PR_PELIG", "E_PELIG", "POLITICALC", "MESSAGE_A")

train.df = train.df[, vars]
valid.df = valid.df[, vars]

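Train the model. Note that train.df[, -c(1:2)] drops only VOTER_ID and MOVED_AD, so the outcome MOVED_A and the treatment MESSAGE_A remain in the data used to compute distances. Because the actual validation outcome then influences which neighbors are found, the results below are likely too optimistic. Using only the predictors, for example train.df[, 4:11] and valid.df[, 4:11] (AGE through POLITICALC), avoids this leakage.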

up.fit <- upliftKNN(train.df[, -c(1:2)], valid.df[, -c(1:2)], train.df$MOVED_A, train.df$MESSAGE_A, k = 10)
up.fit[1:10, ] # upliftKNN() returns predicted probabilities for the validation data directly (column "0" = control, "1" = treatment), so predict() is not needed.
##         0   1
##  [1,] 0.2 0.3
##  [2,] 0.4 0.4
##  [3,] 0.3 0.8
##  [4,] 0.8 0.4
##  [5,] 0.4 0.5
##  [6,] 0.5 0.3
##  [7,] 0.5 0.3
##  [8,] 0.6 0.5
##  [9,] 0.6 0.4
## [10,] 0.3 0.3
pred = up.fit

perf <- performance(pr.y1_ct1 = pred[, 2], # the predicted probability Prob(y=1|treated, x)
                    pr.y1_ct0 = pred[, 1], # the predicted probability Prob(y=1|control, x).
                    y = valid.df$MOVED_A, # Actual observed response
                    ct = valid.df$MESSAGE_A # Control = 0; treatment = 1
                   )
## Warning in performance(pr.y1_ct1 = pred[, 2], pr.y1_ct0 = pred[, 1], y =
## valid.df$MOVED_A, : uplift: due to ties in uplift predictions, the number of
## groups is less than 10
Q <- qini(perf) # Plot cumulative incremental gains against proportion of population targeted.
Q # The Qini coefficient here is about 0.0016.
## $Qini
## [1] 0.001561766
## 
## $inc.gains
## [1] 0.01304908 0.02015112 0.01689323 0.01642879 0.02188849 0.02626044 0.03254813
## [8] 0.03979150 0.04219361
## 
## $random.inc.gains
## [1] 0.004688178 0.009376357 0.014064535 0.018752714 0.023440892 0.028129070
## [7] 0.032817249 0.037505427 0.042193605
title("Qini Plot")

pred.augmented = data.frame(voter_ID = valid.df$VOTER_ID, 
                            MOVED_AD = valid.df$MOVED_AD,
                            pred, 
                            uplift = pred[ ,2] - pred[ ,1])

pred.sorted = arrange(pred.augmented, -uplift) 


head(pred.sorted, n = 50)
##    voter_ID MOVED_AD        X0  X1    uplift
## 1    367732        N 0.1000000 0.8 0.7000000
## 2     71384        N 0.1000000 0.8 0.7000000
## 3    113972        N 0.0000000 0.7 0.7000000
## 4    378537        N 0.2000000 0.9 0.7000000
## 5    153130        N 0.0000000 0.7 0.7000000
## 6    224983        Y 0.2000000 0.9 0.7000000
## 7    466809        Y 0.2000000 0.9 0.7000000
## 8    245590        Y 0.2000000 0.9 0.7000000
## 9    314741        N 0.0000000 0.7 0.7000000
## 10    75773        Y 0.0000000 0.7 0.7000000
## 11   614725        N 0.3000000 1.0 0.7000000
## 12   529161        Y 0.2000000 0.9 0.7000000
## 13    38375        N 0.2000000 0.9 0.7000000
## 14   374355        N 0.0000000 0.7 0.7000000
## 15   301309        N 0.2000000 0.9 0.7000000
## 16   150696        Y 0.2000000 0.8 0.6000000
## 17   197979        N 0.3000000 0.9 0.6000000
## 18   204132        N 0.2000000 0.8 0.6000000
## 19   154621        Y 0.2000000 0.8 0.6000000
## 20   355659        Y 0.2000000 0.8 0.6000000
## 21   110534        N 0.2000000 0.8 0.6000000
## 22   113609        N 0.3000000 0.9 0.6000000
## 23   527856        N 0.2000000 0.8 0.6000000
## 24    94466        N 0.2000000 0.8 0.6000000
## 25   191956        N 0.2000000 0.8 0.6000000
## 26   232181        Y 0.2000000 0.8 0.6000000
## 27    21393        N 0.3000000 0.9 0.6000000
## 28   207671        N 0.2000000 0.8 0.6000000
## 29   463337        N 0.1000000 0.7 0.6000000
## 30    75499        N 0.1000000 0.7 0.6000000
## 31   508637        N 0.1000000 0.7 0.6000000
## 32   185030        Y 0.0000000 0.6 0.6000000
## 33   548388        N 0.1000000 0.7 0.6000000
## 34   593365        N 0.1000000 0.7 0.6000000
## 35   338865        N 0.1000000 0.7 0.6000000
## 36   251719        N 0.0000000 0.6 0.6000000
## 37   268402        N 0.1000000 0.7 0.6000000
## 38   560714        N 0.0000000 0.6 0.6000000
## 39    91996        Y 0.0000000 0.6 0.6000000
## 40   339566        Y 0.1000000 0.7 0.6000000
## 41   604063        N 0.1000000 0.7 0.6000000
## 42   417655        N 0.1000000 0.7 0.6000000
## 43   405641        Y 0.1000000 0.7 0.6000000
## 44   518512        Y 0.1000000 0.7 0.6000000
## 45    44893        N 0.1000000 0.7 0.6000000
## 46    88086        N 0.1000000 0.7 0.6000000
## 47   285354        Y 0.0000000 0.6 0.6000000
## 48   326667        N 0.2727273 0.8 0.5272727
## 49     9057        N 0.1818182 0.7 0.5181818
## 50   316168        N 0.3000000 0.8 0.5000000
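The idea behind upliftKNN() can be sketched in a few lines. The sketch assumes the following procedure: for each validation record, find its k nearest neighbors among the treated training records and, separately, among the control training records, then average the responses in each group. It is written from scratch under that assumption and is not the package's source. The function knn_uplift_sketch() and the simulated data are hypothetical. The sketch uses plain Euclidean distance, so predictors on large scales would dominate unless they are standardized first. With k = 10 every prediction is a multiple of 0.1, which is consistent with most of the probabilities printed above. Values such as 0.2727273 (= 3/11) suggest the package keeps tied neighbors, which this sketch does not.

```r
# Hypothetical sketch of kNN uplift prediction: separate neighbor searches in the
# treated and control training records (assumed behavior, not upliftKNN()'s source).
knn_uplift_sketch <- function(train_x, test_x, y, ct, k = 10) {
  train_x <- as.matrix(train_x)
  test_x <- as.matrix(test_x)
  predict_group <- function(g) {
    xg <- train_x[ct == g, , drop = FALSE]  # training records in group g only
    yg <- y[ct == g]
    apply(test_x, 1, function(rec) {
      d <- sqrt(colSums((t(xg) - rec)^2))   # Euclidean distance to every record in group g
      mean(yg[order(d)[seq_len(k)]])        # average response of the k nearest (ties not kept)
    })
  }
  # Column "0" estimates P(y = 1 | control, x); column "1" estimates P(y = 1 | treated, x)
  cbind(`0` = predict_group(0), `1` = predict_group(1))
}

# Hypothetical simulated data: treatment raises the response only when x1 > 0
set.seed(2)
n <- 400
train_x <- matrix(rnorm(n * 2), ncol = 2)
ct <- rbinom(n, 1, 0.5)
y <- rbinom(n, 1, plogis(-0.5 + 1.5 * ct * (train_x[, 1] > 0)))
test_x <- matrix(rnorm(20 * 2), ncol = 2)

pred <- knn_uplift_sketch(train_x, test_x, y, ct, k = 10)
uplift <- pred[, "1"] - pred[, "0"]
head(cbind(pred, uplift))
```

As in the upliftKNN() call above, the output already contains predictions for the test records, so there is no separate predict() step.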

How can we use the logistic regression model for uplift modeling?

We first train a logistic model with interaction terms between the treatment variable “MESSAGE_A” and each of the other predictors:

logistic.model <- glm(formula = MOVED_A ~ (AGE + NH_WHITE + COMM_PT + H_F1 + REG_DAYS + PR_PELIG + E_PELIG + POLITICALC )* MESSAGE_A, 
              data = train.df, 
              family = binomial) 

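We then predict, for each validation record, the probability of moving under treatment and under control. To do this, we set MESSAGE_A to 1 for every record, then to 0: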

D1=valid.df
D1$MESSAGE_A = 1
predictedProbCT1 = predict(logistic.model, newdata = D1, type = "response")

D2=valid.df
D2$MESSAGE_A = 0
predictedProbCT0 = predict(logistic.model, newdata = D2, type = "response")
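The predict-twice approach above can be checked on simulated data where the true uplift is known. The approach sets the treatment to 1, then to 0, and takes the difference in predicted probabilities. Everything in this sketch is hypothetical: the variables x and treat stand in for a predictor and MESSAGE_A, and the data-generating model is invented for illustration. Because the fitted model here is correctly specified, the estimated uplift should track the true uplift closely.

```r
set.seed(3)
n <- 5000
sim <- data.frame(x = rnorm(n), treat = rbinom(n, 1, 0.5))
# True model: the treatment effect on the logit scale grows with x
sim$y <- rbinom(n, 1, plogis(-1 + 0.3 * sim$x + sim$treat * (0.2 + 0.8 * sim$x)))

fit <- glm(y ~ x * treat, data = sim, family = binomial)  # interaction with the treatment

new1 <- transform(sim, treat = 1)  # everyone treated
new0 <- transform(sim, treat = 0)  # everyone in the control group
uplift_hat <- predict(fit, newdata = new1, type = "response") -
  predict(fit, newdata = new0, type = "response")

true_uplift <- plogis(-1 + 0.3 * sim$x + 0.2 + 0.8 * sim$x) - plogis(-1 + 0.3 * sim$x)
cor(uplift_hat, true_uplift)  # close to 1 when the model is correctly specified
```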

We assess the performance of the model:

perf <- performance(pr.y1_ct1 = predictedProbCT1, # the predicted probability Prob(y=1|treated, x)
                    pr.y1_ct0 = predictedProbCT0, # the predicted probability Prob(y=1|control, x).
                    y = valid.df$MOVED_A, # Actual observed response
                    ct = valid.df$MESSAGE_A # Control = 0; treatment = 1
                   )

Q <- qini(perf) # Plot cumulative incremental gains against proportion of population targeted.
Q # The Qini coefficient here is about 0.0089.
## $Qini
## [1] 0.008945335
## 
## $inc.gains
##  [1] 0.01269429 0.02152103 0.02086556 0.03174889 0.03519293 0.03211683
##  [7] 0.04005174 0.04563424 0.04373653 0.04219361
## 
## $random.inc.gains
##  [1] 0.004219361 0.008438721 0.012658082 0.016877442 0.021096803 0.025316163
##  [7] 0.029535524 0.033754884 0.037974245 0.042193605
title("Qini Plot")

Display the sorted data:

pred.augmented = data.frame(voter_ID = valid.df$VOTER_ID, 
                            MOVED_AD = valid.df$MOVED_AD,
                            predictedProbCT1, 
                            predictedProbCT0,
                            uplift = predictedProbCT1 - predictedProbCT0)

pred.sorted = arrange(pred.augmented, -uplift) 
head(pred.sorted, n = 50)
##      voter_ID MOVED_AD predictedProbCT1 predictedProbCT0    uplift
## 1479   481831        Y        0.3983618        0.2387114 0.1596505
## 9564   601389        Y        0.4458234        0.2923603 0.1534631
## 4004   425448        N        0.7203102        0.5722875 0.1480227
## 8770    55742        Y        0.5235629        0.3903698 0.1331930
## 2138   346055        N        0.5197845        0.3868374 0.1329470
## 5570    97704        N        0.5193104        0.3865640 0.1327465
## 5079    14045        N        0.5190660        0.3864230 0.1326430
## 6385   581643        N        0.5172887        0.3847269 0.1325618
## 8979   203395        N        0.5220046        0.3894676 0.1325370
## 679    182856        N        0.5202615        0.3877859 0.1324756
## 3751   448535        N        0.5133808        0.3811414 0.1322394
## 8642   243943        N        0.5065284        0.3745602 0.1319682
## 8561   587535        N        0.5184949        0.3867662 0.1317287
## 2838   445966        N        0.5167545        0.3850909 0.1316636
## 3222   247004        Y        0.5005927        0.3692047 0.1313880
## 5045    28094        N        0.5113087        0.3799543 0.1313544
## 4407   374392        N        0.5134238        0.3825051 0.1309187
## 9054   152639        N        0.5114748        0.3807179 0.1307569
## 5451   306603        N        0.4956686        0.3651076 0.1305611
## 8658   249492        Y        0.4921507        0.3618179 0.1303328
## 1994   253184        N        0.5069230        0.3767808 0.1301423
## 6828    26225        N        0.5081423        0.3781434 0.1299990
## 1970   319123        N        0.4847314        0.3557124 0.1290190
## 7235    59731        N        0.4052628        0.2765993 0.1286635
## 7416    37931        N        0.4865644        0.3580371 0.1285273
## 5769   540833        N        0.4756623        0.3474497 0.1282126
## 1097   144441        N        0.5165308        0.3884160 0.1281148
## 264     22704        N        0.4977376        0.3702325 0.1275051
## 8371    85549        N        0.5747271        0.4473578 0.1273693
## 6792   130602        Y        0.4954128        0.3682567 0.1271561
## 9582   532931        Y        0.4716720        0.3446085 0.1270635
## 6830   578762        N        0.4479213        0.3208889 0.1270325
## 9249   185991        N        0.4871812        0.3603407 0.1268405
## 631    423317        N        0.4781651        0.3514138 0.1267513
## 6661   237377        N        0.4891371        0.3627474 0.1263897
## 1538   200951        Y        0.4685756        0.3422670 0.1263087
## 4958   368345        N        0.4887738        0.3625433 0.1262304
## 5479    37678        N        0.4805993        0.3547101 0.1258892
## 7656    47060        Y        0.4833316        0.3575346 0.1257970
## 2975   201204        N        0.4155955        0.2898717 0.1257238
## 2618   313618        Y        0.4152964        0.2897178 0.1255787
## 8483   321519        N        0.4152964        0.2897178 0.1255787
## 3867   164356        Y        0.4152784        0.2897085 0.1255699
## 7725   501738        N        0.4149362        0.2895324 0.1254038
## 5086   115088        N        0.4865200        0.3612784 0.1252416
## 8598   485603        N        0.4145077        0.2893119 0.1251957
## 533    585371        Y        0.3997867        0.2747703 0.1250164
## 65      72593        N        0.3983464        0.2734803 0.1248662
## 3126   394374        Y        0.4871146        0.3623552 0.1247594
## 3597   396594        N        0.3749962        0.2502562 0.1247400
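The Qini outputs in this section are easier to read with a sketch of how incremental gains and a Qini-style area can be computed. The sketch assumes the usual definition. Records are ranked by predicted uplift. At each cut-off, the treated-minus-control response rate among the targeted records is multiplied by the fraction targeted. The Qini value is then the area between that curve and the straight line for random targeting. uplift::qini() may differ in details such as tie handling or the exact area formula, so the numbers are illustrative. The function qini_sketch() and the simulated data are hypothetical.

```r
# Hypothetical Qini-style evaluation (assumed definition, not uplift::qini()'s source)
qini_sketch <- function(uplift, y, ct, groups = 10) {
  ord <- order(uplift, decreasing = TRUE)   # target the highest predicted uplift first
  y <- y[ord]
  ct <- ct[ord]
  n <- length(y)
  cut_points <- round(seq_len(groups) * n / groups)  # top 10%, 20%, ..., 100%
  inc.gains <- sapply(cut_points, function(m) {
    top_y <- y[seq_len(m)]
    top_ct <- ct[seq_len(m)]
    rate_t <- if (any(top_ct == 1)) mean(top_y[top_ct == 1]) else 0
    rate_c <- if (any(top_ct == 0)) mean(top_y[top_ct == 0]) else 0
    (rate_t - rate_c) * m / n               # incremental gain scaled by the fraction targeted
  })
  random.inc.gains <- inc.gains[groups] * cut_points / n  # straight line: random targeting
  frac <- c(0, cut_points / n)
  area <- function(g) {                      # trapezoid rule; both curves start at (0, 0)
    g <- c(0, g)
    sum(diff(frac) * (head(g, -1) + tail(g, -1)) / 2)
  }
  list(Qini = area(inc.gains) - area(random.inc.gains),
       inc.gains = inc.gains, random.inc.gains = random.inc.gains)
}

# Hypothetical simulated data: treatment helps only when x > 0
set.seed(1)
n <- 2000
x <- rnorm(n)
ct <- rbinom(n, 1, 0.5)
y <- rbinom(n, 1, 0.3 + 0.2 * ct * (x > 0))
res <- qini_sketch(uplift = x, y = y, ct = ct)  # x itself is an ideal ranking score here
res$Qini
```

At 100% targeted, the incremental gain equals the overall treated-minus-control difference. This is why random.inc.gains in the outputs above ends at the same value (0.04219361) as inc.gains.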

Cluster Analysis (Chapter 15)

The learning methods listed here (multiple regression, Naive Bayes, logistic regression, k-NN, decision trees, discriminant analysis, neural nets, SVM) are supervised learning methods: each has a response variable. Cluster analysis is an unsupervised learning method: no variable plays the role of a response. It is useful in marketing for market segmentation. Customers are segmented based on demographic and transaction-history information, and a different marketing strategy is tailored to each cluster.

The goal of cluster analysis is to segment the data into a set of homogeneous clusters of records for the purpose of generating insight. Once clusters are found, they can be handled/modeled separately.

There are two popular clustering approaches.

  • k-means clustering: used when we want a specified number of clusters. This method involves calculating distances between records (with nominal categorical variables dummified).

  • hierarchical clustering: used when we want to build a hierarchy of clusters (displayed as a tree called a dendrogram). This method involves calculating distances between records and between clusters.

A nice application of clustering appears in this article: https://arxiv.org/pdf/1805.02501.pdf. The accompanying R package is “basketballAnalyzeR”.

Hierarchical (Agglomerative) Clustering

This method involves calculating distances between records and clusters.

Measuring Distance between Two Records

Records with nominal categorical variables dummified can be viewed as points in a high-dimensional space. The distance (or dissimilarity) between two points in a \(p\)-dimensional space can be defined using different metrics, including the Euclidean distance and the statistical distance (also called the Mahalanobis distance). These two metrics are designed for numerical data. The Gower distance, by contrast, can handle mixed data (for example, continuous and binary variables). A tutorial on Gower distance: https://jamesmccaffrey.wordpress.com/2020/04/21/example-of-calculating-the-gower-distance/
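For reference, here are the standard forms of these distances between records \(x = (x_1, \ldots, x_p)\) and \(y = (y_1, \ldots, y_p)\). \(S\) denotes the sample covariance matrix. For Gower's distance, \(R_f\) is the range of numerical variable \(f\), and a variable with zero range contributes 0. These are the textbook definitions; software may add options such as variable weights.

\[d_E(x, y) = \sqrt{\sum_{f=1}^{p} (x_f - y_f)^2}\]

\[d_M(x, y) = \sqrt{(x - y)^\top S^{-1} (x - y)}\]

\[d_G(x, y) = \frac{1}{p} \sum_{f=1}^{p} \delta_f(x, y), \qquad \delta_f(x, y) = \begin{cases} |x_f - y_f| / R_f & \text{if } f \text{ is numerical} \\ \mathbb{1}(x_f \neq y_f) & \text{if } f \text{ is nominal} \end{cases}\]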

We will use the package “StatMatch” to calculate the Gower distance.

An example:

# Distances between any two of the first 5 records in the "mtcars" data frame
d = dist(mtcars[1:5, ], method = "euclidean") 
d
##                     Mazda RX4 Mazda RX4 Wag  Datsun 710 Hornet 4 Drive
## Mazda RX4 Wag       0.6153251                                         
## Datsun 710         54.9086059    54.8915169                           
## Hornet 4 Drive     98.1125212    98.0958939 150.9935191               
## Hornet Sportabout 210.3374396   210.3358546 265.0831615    121.0297564
# Distances between any two of the first 5 records in the "iris" data frame
StatMatch::gower.dist(iris[1:5, ])  # A nominal variable contributes 0 if two records share the same level and 1 otherwise
##            [,1]      [,2]      [,3]      [,4]       [,5]
## [1,] 0.00000000 0.2466667 0.3600000 0.4333333 0.07333333
## [2,] 0.24666667 0.0000000 0.2466667 0.2533333 0.24000000
## [3,] 0.36000000 0.2466667 0.0000000 0.2733333 0.35333333
## [4,] 0.43333333 0.2533333 0.2733333 0.0000000 0.42666667
## [5,] 0.07333333 0.2400000 0.3533333 0.4266667 0.00000000
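As a check on the two outputs above, both distances can be computed by hand. The helper gower_sketch() is hypothetical. It assumes equal weights and ranges computed over the rows supplied. With those assumptions it reproduces the StatMatch numbers shown for these five iris records. A zero-range variable (Petal.Width here) contributes 0 but still counts in the average.

```r
# Euclidean distance between the first two mtcars records, computed by hand
x1 <- unlist(mtcars["Mazda RX4", ])
x2 <- unlist(mtcars["Mazda RX4 Wag", ])
euclid <- sqrt(sum((x1 - x2)^2))
euclid  # matches the dist() entry above (about 0.615)

# Hypothetical Gower distance sketch: equal weights, ranges over the rows supplied
gower_sketch <- function(df) {
  n <- nrow(df)
  D <- matrix(0, n, n)
  for (col in names(df)) {
    x <- df[[col]]
    if (is.numeric(x)) {
      r <- diff(range(x))
      # range-scaled absolute difference; a zero-range variable contributes 0
      contrib <- if (r > 0) abs(outer(x, x, "-")) / r else matrix(0, n, n)
    } else {
      # nominal variable: 0 if the same level, 1 otherwise
      contrib <- 1 * outer(as.character(x), as.character(x), "!=")
    }
    D <- D + contrib
  }
  D / ncol(df)  # average over all variables
}

G <- gower_sketch(iris[1:5, ])
round(G, 4)
```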

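It is usually advisable to normalize numerical variables before calculating distances. Otherwise, variables measured on large scales (such as disp and hp in mtcars) dominate the distance.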
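To see why normalization matters, the sketch below splits the squared Euclidean distance between Mazda RX4 and Hornet Sportabout into each variable's share. It does this before and after standardizing mtcars with scale(). The helper share_of_distance() is hypothetical.

```r
# Hypothetical helper: each variable's share of the squared Euclidean distance between rows a and b
share_of_distance <- function(X, a, b) {
  sq <- (X[a, ] - X[b, ])^2  # squared difference, variable by variable
  sq / sum(sq)
}

raw_share <- share_of_distance(as.matrix(mtcars), "Mazda RX4", "Hornet Sportabout")
std_share <- share_of_distance(scale(mtcars), "Mazda RX4", "Hornet Sportabout")
round(rbind(raw = raw_share, standardized = std_share), 3)
```

On the raw scale, disp alone accounts for roughly 90% of the squared distance. After standardization, no single variable dominates to that degree.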

Measuring Distance between Two Clusters

A cluster is a set of similar records. The distance between two clusters can be defined in different ways.

  • Minimum distance: The distance between the two clusters is the smallest distance between a record in the first cluster and a record in the second cluster. Clustering using this distance is called single-linkage clustering. This is analogous to measuring the distance between two countries as the shortest distance between their borders.

  • Maximum distance: The distance between the two clusters is the largest distance between a record in the first cluster and a record in the second cluster. Clustering using this distance is called complete-linkage clustering.

  • Average distance: The distance between the two clusters is the average distance between records in the first cluster and records in the second cluster. Clustering using this distance is called average-linkage clustering.

  • Centroid distance: The distance between the two clusters is the distance between the center of the first cluster and the center of the second cluster. The centroid of a cluster is the average of all records in that cluster. Clustering using this distance is called centroid-linkage clustering.

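Different distances have advantages and disadvantages. Single linkage tends to chain records into long, straggly clusters. Complete linkage tends to produce compact clusters of similar diameter, but it is sensitive to outliers. Average linkage is a compromise between the two. Centroid linkage can produce inversions, where a later merge occurs at a smaller height than an earlier one.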
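The four cluster-to-cluster distances can be computed directly from the pairwise record distances. The sketch below uses a hypothetical helper, linkage_distances(), on two tiny made-up clusters so the numbers are easy to verify by hand.

```r
# Hypothetical helper: single, complete, average, and centroid distances between clusters A and B
linkage_distances <- function(A, B) {
  D <- as.matrix(dist(rbind(A, B)))                      # all pairwise Euclidean distances
  nA <- nrow(A)
  cross <- D[seq_len(nA), nA + seq_len(nrow(B)), drop = FALSE]  # A-to-B pairs only
  c(single   = min(cross),
    complete = max(cross),
    average  = mean(cross),
    centroid = sqrt(sum((colMeans(A) - colMeans(B))^2)))
}

A <- rbind(c(0, 0), c(0, 1))  # toy cluster A
B <- rbind(c(3, 0), c(4, 0))  # toy cluster B
linkage_distances(A, B)
# single = 3, complete = sqrt(17) (about 4.123),
# average = (3 + 4 + sqrt(10) + sqrt(17)) / 4 (about 3.571), centroid = sqrt(12.5) (about 3.536)
```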

How Hierarchical (Agglomerative) Clustering Is Done

Suppose there are \(n\) records in the data.

The steps are:

  1. Start with \(n\) clusters, each record being a cluster.

  2. Merge the two closest records (clusters) into one cluster.

  3. At every subsequent step, merge the two clusters with the smallest distance. Continue until all records belong to a single cluster.
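The steps above translate almost directly into code. This is a deliberately naive sketch that rescans every pair of clusters at each step, so it is only suitable for small data. naive_agglomerative() is a hypothetical helper. Its linkage argument takes min for single linkage or max for complete linkage. Its merge heights should agree with those reported by hclust().

```r
# Hypothetical naive agglomerative clustering; returns the merge heights in merge order
naive_agglomerative <- function(X, linkage = min) {
  D <- as.matrix(dist(X))
  clusters <- as.list(seq_len(nrow(X)))  # step 1: every record is its own cluster
  heights <- numeric(0)
  while (length(clusters) > 1) {
    best_d <- Inf
    best <- c(NA, NA)
    for (i in 1:(length(clusters) - 1)) {
      for (j in (i + 1):length(clusters)) {
        d <- linkage(D[clusters[[i]], clusters[[j]]])  # cluster-to-cluster distance
        if (d < best_d) {
          best_d <- d
          best <- c(i, j)
        }
      }
    }
    # steps 2 and 3: merge the closest pair of clusters
    clusters[[best[1]]] <- c(clusters[[best[1]]], clusters[[best[2]]])
    clusters[[best[2]]] <- NULL
    heights <- c(heights, best_d)
  }
  heights
}

X <- scale(USArrests)[1:10, ]
h_single <- naive_agglomerative(X, linkage = min)
h_complete <- naive_agglomerative(X, linkage = max)
all.equal(sort(h_single), sort(hclust(dist(X), method = "single")$height))
all.equal(sort(h_complete), sort(hclust(dist(X), method = "complete")$height))
```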

A dendrogram is a tree-like diagram that summarizes the process of clustering.

Besides the linkage distances above, Ward’s method (the minimum-variance method) can be used to decide which records and clusters to join when constructing a dendrogram.

The total variation of a cluster is called ESS (Error Sum of Squares) or WSS (Within Sum of Squares). It is the sum of the squared distances between each observation and the cluster centroid: \[\text{ESS} = \sum_{x} [d(x, c)]^2,\] where \(x\) ranges over the observations in the cluster, \(c\) is the centroid of the cluster, and \(d(x,c)\) is the Euclidean distance between \(x\) and \(c\). The total ESS of a clustering is the sum of the ESS values of all its clusters. Ward’s method also starts with each record as its own cluster, each with an ESS of zero. At each step, it merges the two clusters whose merger gives the smallest total ESS (equivalently, the smallest increase in total ESS). It continues until only one cluster remains.
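Here is a small numerical check of the ESS idea. Merging two clusters A and B increases the total ESS by \(\text{ESS}(A \cup B) - \text{ESS}(A) - \text{ESS}(B)\). A standard identity gives the same increase as \(\frac{n_A n_B}{n_A + n_B}\lVert c_A - c_B\rVert^2\), where \(n_A, n_B\) are the cluster sizes and \(c_A, c_B\) the centroids. The helper ess() and the two toy clusters are hypothetical.

```r
# Hypothetical helper: sum of squared Euclidean distances to the cluster centroid
ess <- function(X) sum(sweep(X, 2, colMeans(X))^2)

A <- rbind(c(0, 0), c(0, 1))
B <- rbind(c(3, 0), c(4, 0))

increase <- ess(rbind(A, B)) - ess(A) - ess(B)   # 13.5 - 0.5 - 0.5
formula_increase <- nrow(A) * nrow(B) / (nrow(A) + nrow(B)) *
  sum((colMeans(A) - colMeans(B))^2)             # (2 * 2 / 4) * 12.5
c(increase = increase, formula = formula_increase)  # both equal 12.5
```

At each step, Ward's method picks the pair of clusters with the smallest such increase.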

USArrests.norm = scale(USArrests)


hc1 <- hclust(dist(USArrests.norm, method = "euclidean"), 
             method = "average")
plot(hc1)

hc2 <- hclust(dist(USArrests.norm, method = "euclidean"), 
             method = "single")
plot(hc2)

hc3 <- hclust(dist(USArrests.norm, method = "euclidean"), 
              method = "complete"    # The default method
             )
plot(hc3)

hc4 <- hclust(dist(USArrests.norm, method = "euclidean")^2, # ?hclust: "centroid" is meant for squared Euclidean distances
             method = "centroid")
plot(hc4)

Cluster = cutree(hc4, k = 4) # Generate 4 clusters using a dendrogram based on the centroid method.

# We can add a cluster column to the original data if we want to analyze the data further.
D = data.frame(USArrests, Cluster)
D

set.seed(1234)
hc <- hclust(dist(iris[sample(150, 3),1:4], method = "euclidean"), 
              method = "ward.D")
plot(hc)

Different cutoffs on the distance (height) axis of a dendrogram give different clusters. For example, cutting the dendrogram obtained with the complete-linkage method at a height of 3.5 gives 4 clusters, because a horizontal line drawn at 3.5 crosses 4 branches of the tree.
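To cut at a height rather than at a number of clusters, use cutree()'s h argument. The sketch below rebuilds the complete-linkage tree on the standardized USArrests data, cuts it at 3.5, and draws the cut line. rect.hclust() then outlines the resulting clusters on the dendrogram. The number of clusters depends on the tree, so it is computed rather than hard-coded.

```r
USArrests.norm <- scale(USArrests)
hc3 <- hclust(dist(USArrests.norm, method = "euclidean"), method = "complete")

clusters.h <- cutree(hc3, h = 3.5)      # cut the tree at height 3.5
table(clusters.h)                        # cluster sizes
k.at.h <- length(unique(clusters.h))     # number of branches the line crosses
clusters.k <- cutree(hc3, k = k.at.h)    # the same partition, specified by k instead

plot(hc3)
abline(h = 3.5, lty = 2, col = "red")    # the cutoff line
rect.hclust(hc3, h = 3.5, border = "blue")
```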

Ward’s method is demonstrated below:

set.seed(1234)
D = iris[sample(1:150, 6),1:4] # For clarity, cluster only 6 randomly selected flowers (records)

hc5 <- hclust(dist(D, method = "euclidean"), 
              method = "ward.D") # Per ?hclust, "ward.D2" implements Ward's ESS criterion on unsquared distances; "ward.D" does not
plot(hc5)

The dendrogram shows that the procedure starts with 6 records (flowers): 18, 91, 92, 93, 126, and 149. At the beginning, we have 6 clusters: {18}, {91}, {92}, {93}, {126}, and {149}. Which 2 clusters should be merged to produce 5 clusters with the smallest total ESS? It turns out that merging clusters {91} and {93} does. After this merge, we have 5 clusters: {18}, {92}, {91, 93}, {126}, and {149}. Next, which two clusters should be merged to produce the smallest total ESS? It turns out that merging clusters {92} and {91, 93} does. The remaining merges are interpreted the same way until a single cluster remains.

Non-Hierarchical Clustering: the k-Means Algorithm

If we want to divide observations into a predetermined number of non-overlapping clusters so that clusters are as homogeneous as possible according to a criterion, the k-means algorithm can be used. A commonly used criterion is the sum of squared distances of observations from their cluster centroid. A nice reference: https://uc-r.github.io/kmeans_clustering and an app: https://shiny.rstudio.com/gallery/kmeans-example.html

Here are the steps of the algorithm:

Step 1: Choose the number of clusters k.

Step 2: Select k observations randomly from the data as centers (called centroids) of the clusters.

Step 3: Allocate each observation to the cluster with the nearest centroid.

Step 4: Re-calculate the centroids of newly formed k clusters.

Step 5: Repeat Steps 3 and 4 until the cluster assignments no longer change (or a maximum number of iterations is reached).
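Steps 2 to 5 are known as Lloyd's algorithm. The sketch below implements them directly; simple_kmeans() is a hypothetical name. It then compares the result with stats::kmeans() run with algorithm = "Lloyd" from the same starting centroids. R's kmeans() uses the Hartigan-Wong algorithm by default, which optimizes the same criterion differently, so default results can differ slightly.

```r
# Hypothetical sketch of Lloyd's algorithm (Steps 2-5 above)
simple_kmeans <- function(X, centers, max_iter = 100) {
  X <- as.matrix(X)
  centers <- as.matrix(centers)  # Step 2: initial centroids supplied by the caller
  k <- nrow(centers)
  cluster <- integer(nrow(X))
  for (iter in seq_len(max_iter)) {
    # Step 3: squared distance from every record to every centroid; pick the nearest
    d2 <- sapply(seq_len(k), function(j) rowSums(sweep(X, 2, centers[j, ])^2))
    new_cluster <- max.col(-d2, ties.method = "first")
    if (all(new_cluster == cluster)) break  # Step 5: stop when assignments are stable
    cluster <- new_cluster
    # Step 4: recompute each centroid (keep the old one if a cluster became empty)
    for (j in seq_len(k)) {
      if (any(cluster == j)) centers[j, ] <- colMeans(X[cluster == j, , drop = FALSE])
    }
  }
  withinss <- sapply(seq_len(k), function(j)
    sum(sweep(X[cluster == j, , drop = FALSE], 2, centers[j, ])^2))
  list(cluster = cluster, centers = centers, tot.withinss = sum(withinss))
}

X <- scale(USArrests)
set.seed(42)
init <- X[sample(nrow(X), 3), ]  # Step 2: 3 randomly chosen records as starting centroids
fit <- simple_kmeans(X, init)
ref <- kmeans(X, centers = init, algorithm = "Lloyd", iter.max = 100)
c(sketch = fit$tot.withinss, stats = ref$tot.withinss)
```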

To perform a cluster analysis in R, generally, the data should be prepared as follows:

  • Any missing value in the data must be removed or estimated.

  • The data must be normalized (usually standardized) to make variables comparable. Recall that standardization transforms each variable to have mean zero and standard deviation one.
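The two preparation steps can be sketched as follows. USArrests has no missing values, so two are inserted artificially (rows 3 and 7 of Rape) purely for illustration. na.omit() drops the incomplete records. The hand-computed z-scores then confirm what scale() does.

```r
df <- USArrests
df[c(3, 7), "Rape"] <- NA    # hypothetical missing values, for illustration only
df <- na.omit(df)            # remove incomplete records (estimating them is the alternative)

df.norm <- scale(df)         # standardize every column
manual <- sweep(sweep(as.matrix(df), 2, colMeans(df)), 2, apply(df, 2, sd), "/")  # (x - mean) / sd

nrow(df.norm)                # 48 records remain
round(colMeans(df.norm), 10) # all 0
apply(df.norm, 2, sd)        # all 1
```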

USArrests.norm = scale(USArrests)


# k-means with the USA arrests data
k2 <- kmeans(USArrests.norm, centers = 2)

D = USArrests
D$Cluster = k2$cluster

k3 <- kmeans(USArrests.norm, centers = 3)
k4 <- kmeans(USArrests.norm, centers = 4)
k5 <- kmeans(USArrests.norm, centers = 5)

# plots for different k. Need package "factoextra"
library(factoextra) # The package allows a nice visualization of clusters
## Welcome! Want to learn more? See two factoextra-related books at https://goo.gl/ve3WBa
p1 <- fviz_cluster(k2, geom = "point", data = USArrests.norm) + ggtitle("k = 2")
p2 <- fviz_cluster(k3, geom = "point",  data = USArrests.norm) + ggtitle("k = 3")
p3 <- fviz_cluster(k4, geom = "point",  data = USArrests.norm) + ggtitle("k = 4")
p4 <- fviz_cluster(k5, geom = "point",  data = USArrests.norm) + ggtitle("k = 5")

gridExtra::grid.arrange(p1, p2, p3, p4, nrow = 2) # need package gridExtra

Let’s focus on the 3-cluster case:

k = 3
km <- kmeans(USArrests.norm, centers = k)

To see the components stored in km, run:

km$cluster
km$centers
km$totss          # The total sum of squares
km$withinss       # A vector of within-cluster sum of squares, one per cluster
km$tot.withinss   # Total within-cluster sum of squares
km$betweenss      # The between-cluster sum of squares, i.e. totss minus tot.withinss
km$size           # The number of points in each cluster
km$iter           # The number of (outer) iterations

We plot some results:

Centers = km$centers # Centroids
Centers
##       Murder    Assault   UrbanPop       Rape
## 1 -0.9615407 -1.1066010 -0.9301069 -0.9667633
## 2  1.0049340  1.0138274  0.1975853  0.8469650
## 3 -0.4469795 -0.3465138  0.4788049 -0.2571398
# plot an empty scatter plot; more content added later
plot(0, # Only one point on the plot
     type = "l", # draw a line. Since there is only one point, there is no line shown.
     xaxt = 'n', # Remove x-axis
     xlab = "", # No x-label
     ylab = "", # No y-label
     ylim = c(min(km$centers), max(km$centers)), # Set y-limits
     xlim = c(0, ncol(USArrests.norm)) # Set x-limits
    ) 

# label x-axes
axis(side = 1, # side = 1 (bottom), 2 (left), 3 (top), or 4 (right), with "1" = x-axis
     at = 1:ncol(USArrests), # place labels at x = 1, 2, 3, ...
     labels = names(USArrests), # Use variable names as labels at these 4 locations
     las = 2
    ) 

# plot k centroids, one for each cluster
for (i in 1:k) {
  lines(Centers[i, ], # Each cluster corresponds to a set of connected points
        lty = i, # Different clusters correspond to different line types
        lwd = 2, # line width set to 2; default is 1.
        col = i) # Color clusters differently
}
  
# Name clusters
text(x = 0.5, y = km$centers[, 1], labels = paste("Cluster", 1:k), col = 1:k)

# Display observations along with clusters
D = data.frame(USArrests, cluster = factor(km$cluster), 
          state = row.names(USArrests))

ggplot(D, aes(UrbanPop, Murder, color = cluster, label = state)) +
  geom_text()

How to determine the optimal number of clusters? There are three popular methods:

  • Elbow method

  • Silhouette method

  • Gap statistic

A nice video: https://www.youtube.com/watch?v=QXOkPvFM6NU

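The elbow method looks at how the total intra-cluster variation (known as the total within-cluster variation or total within-cluster sum of squares) decreases as \(k\) increases. This quantity generally keeps decreasing as clusters are added, so the method does not simply minimize it. Instead, it chooses \(k\) at the point beyond which additional clusters reduce it only slightly. The steps of this method are: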

  1. Run k-means clustering for different values of k, for instance, varying k from 1 to 10 clusters.

  2. For each k, calculate the total within-cluster sum of squares (WSS).

  3. Plot the WSS curve against the number of clusters k.

  4. The location of a bend (knee or elbow) in the plot is generally considered an indicator of the appropriate number of clusters.

USArrests.norm = scale(USArrests)

# function to compute the mean within-cluster sum of squares (total WSS divided by k)
wss <- function(k) {
  km = kmeans(USArrests.norm, k )
  return(km$tot.withinss/k) # Return average within-cluster sum of squared distances 
}

# Compute and plot wss for k = 1 to k = 15
k.values <- 1:15

# extract wss for 1-15 clusters
wss_values <- sapply(k.values, wss)

plot(k.values, wss_values,
       type="b", pch = 19, frame = FALSE, 
       xlab="Number of clusters K",
       ylab="Mean within-cluster sum of squared distances")

The plot shows that moving from 1 to 2 clusters tightens the clusters considerably, as reflected by the large reduction in within-cluster dispersion. So does moving from 2 to 3, and even to 4. Adding clusters beyond 4 brings less improvement in cluster homogeneity.

The silhouette method chooses an optimal \(k\) by maximizing the average silhouette width over a range of possible values for k. The silhouette width of observation \(i\) is defined as \[s_i = \frac{b_i - a_i}{\max(a_i, b_i)},\] where \(a_i\) is the average dissimilarity between \(i\) and all other points of the cluster to which \(i\) belongs. \(b_i\) is the average dissimilarity between \(i\) and the points of its nearest “neighbor” cluster, that is, the other cluster with the smallest such average.

The method determines how well each object lies within its cluster. A high average silhouette width indicates a good clustering.
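The silhouette formula can be coded directly from the definitions of \(a_i\) and \(b_i\). The result can then be compared with silhouette() from the cluster package, which ships with R. manual_silhouette() is a hypothetical helper. Following the usual convention, an observation in a singleton cluster gets \(s_i = 0\).

```r
library(cluster)

# Hypothetical helper: silhouette widths from a cluster vector and a dist object
manual_silhouette <- function(cluster_id, d) {
  D <- as.matrix(d)
  sapply(seq_along(cluster_id), function(i) {
    same <- cluster_id == cluster_id[i]
    same[i] <- FALSE                      # exclude i itself from a_i
    if (!any(same)) return(0)             # singleton-cluster convention
    a <- mean(D[i, same])                 # average dissimilarity within its own cluster
    others <- setdiff(unique(cluster_id), cluster_id[i])
    b <- min(sapply(others, function(g) mean(D[i, cluster_id == g])))  # nearest other cluster
    (b - a) / max(a, b)
  })
}

USArrests.norm <- scale(USArrests)
d <- dist(USArrests.norm)
set.seed(123)
km <- kmeans(USArrests.norm, centers = 3, nstart = 25)

s_manual <- manual_silhouette(km$cluster, d)
sil <- silhouette(as.integer(km$cluster), d)
mean(s_manual)            # average silhouette width for k = 3
summary(sil)$avg.width    # the same quantity from the cluster package
```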

The gap statistic, proposed by R. Tibshirani, G. Walther, and T. Hastie (Stanford University, 2001), can be applied to any clustering method (e.g., k-means clustering or hierarchical clustering). It compares the total intra-cluster variation for different values of k with its expected value under a null reference distribution of the data, i.e., a distribution with no obvious clustering. The optimal value of k is the one that maximizes the gap statistic.
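The functions clusGap() and maxSE() from the cluster package compute the gap statistic directly. In this sketch, B = 50 reference data sets and K.max = 6 are arbitrary choices made to keep it fast. method = "Tibs2001SEmax" applies the one-standard-error rule from the original paper: choose the smallest k with \(\text{Gap}(k) \ge \text{Gap}(k+1) - s_{k+1}\). That rule can pick a smaller k than the raw maximum.

```r
library(cluster)

USArrests.norm <- scale(USArrests)
set.seed(123)
gap <- clusGap(USArrests.norm, FUNcluster = kmeans, K.max = 6, B = 50,
               nstart = 10,      # passed on to kmeans()
               verbose = FALSE)
gap$Tab                          # columns: logW, E.logW, gap, SE.sim

k_best <- maxSE(gap$Tab[, "gap"], gap$Tab[, "SE.sim"], method = "Tibs2001SEmax")
k_best
```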

The three criteria can be computed and plotted with the fviz_nbclust() function from the factoextra package:

USArrests.norm = scale(USArrests)

set.seed(123)

fviz_nbclust(USArrests.norm, kmeans, method = "wss")

fviz_nbclust(USArrests.norm, kmeans, method = "silhouette")

fviz_nbclust(USArrests.norm, kmeans, method = "gap_stat")

To evaluate cluster validity, we can use the ratio of the between-cluster sum of squares for a given k to the total sum of squares of the data. A large ratio indicates well-separated clusters. This ratio plays the role of \(R^2\) in one-way analysis of variance, so we can simply call it \(R^2\). Like \(R^2\), this ratio increases as the number of clusters increases. We may therefore need an adjusted ratio similar to the adjusted \(R^2\) in one-way analysis of variance: \(R^2_{adj}=1-(1-R^2)\frac{n-1}{n-k}\), where \(n\) is the number of records and \(k\) is the number of clusters.

USArrests.norm = scale(USArrests)

km3 = kmeans(USArrests.norm, centers = 3, nstart = 10 )

km3$betweenss/km3$totss # Something like R-square
## [1] 0.6003915
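The adjusted ratio is easy to compute from kmeans() output. In this sketch, adj_r2() is a hypothetical helper that applies the formula above for k = 2, ..., 6. Unlike the raw ratio, the adjusted ratio charges a penalty for each extra cluster.

```r
USArrests.norm <- scale(USArrests)
n <- nrow(USArrests.norm)

# Hypothetical helper: adjusted R^2 = 1 - (1 - R^2) (n - 1) / (n - k)
adj_r2 <- function(km, n) {
  k <- length(km$size)
  r2 <- km$betweenss / km$totss
  1 - (1 - r2) * (n - 1) / (n - k)
}

set.seed(123)
fits <- lapply(2:6, function(k) kmeans(USArrests.norm, centers = k, nstart = 10))
tab <- data.frame(k = 2:6,
                  R2 = sapply(fits, function(km) km$betweenss / km$totss),
                  R2.adj = sapply(fits, adj_r2, n = n))
tab
```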

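When the data include categorical variables, the function VarSelCluster() from the VarSelLCM package can be used. Note that it performs model-based clustering (it fits a mixture model), not k-means. The estimated partition is stored in the “partitions” slot of the result.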

library(VarSelLCM) # Load only where needed, not in the setup chunk: its predict() masks stats::predict()
## 
## Attaching package: 'VarSelLCM'
## The following object is masked from 'package:penalized':
## 
##     predict
## The following object is masked from 'package:stats':
## 
##     predict
K=VarSelCluster(iris, 3, vbleSelec = FALSE) # Clustering into 3 clusters without variable selection

fitted(K) # Estimated Clusters
##   [1] 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
##  [38] 3 3 3 3 3 3 3 3 3 3 3 3 3 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
##  [75] 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 1 1 1 1 1 1 2 1 1 1 1
## [112] 1 1 1 1 1 1 1 1 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
## [149] 1 1
fitted(K, type="probability") # Estimated probabilities of classification
##              class-1       class-2 class-3
##   [1,]  0.000000e+00 2.283305e-197       1
##   [2,]  0.000000e+00 2.650690e-196       1
##   [3,]  0.000000e+00 1.969478e-197       1
##   [4,]  0.000000e+00 2.661758e-196       1
##   [5,]  0.000000e+00 7.788761e-198       1
##   [6,] 4.446591e-323 2.012277e-193       1
##   [7,]  0.000000e+00 1.860033e-196       1
##   [8,]  0.000000e+00 1.094488e-196       1
##   [9,]  0.000000e+00 1.846846e-196       1
##  [10,]  0.000000e+00 6.452739e-197       1
##  [11,]  0.000000e+00 7.037607e-197       1
##  [12,]  0.000000e+00 2.830919e-196       1
##  [13,]  0.000000e+00 2.506704e-197       1
##  [14,]  0.000000e+00 1.800316e-198       1
##  [15,]  0.000000e+00 4.122939e-198       1
##  [16,]  0.000000e+00 4.510032e-196       1
##  [17,]  0.000000e+00 7.834492e-196       1
##  [18,]  0.000000e+00 3.598007e-196       1
##  [19,] 2.470328e-323 8.478344e-194       1
##  [20,]  0.000000e+00 1.324287e-196       1
##  [21,]  0.000000e+00 1.737875e-194       1
##  [22,]  0.000000e+00 8.853291e-195       1
##  [23,]  0.000000e+00 3.806028e-199       1
##  [24,] 1.166489e-320 2.524103e-190       1
##  [25,] 9.881313e-324 1.635343e-193       1
##  [26,]  0.000000e+00 5.816290e-195       1
##  [27,] 9.881313e-324 2.456855e-193       1
##  [28,]  0.000000e+00 1.182867e-196       1
##  [29,]  0.000000e+00 6.792283e-197       1
##  [30,]  0.000000e+00 8.141870e-196       1
##  [31,]  0.000000e+00 1.846216e-195       1
##  [32,] 1.482197e-323 2.922725e-193       1
##  [33,]  0.000000e+00 1.261634e-199       1
##  [34,]  0.000000e+00 5.288309e-199       1
##  [35,]  0.000000e+00 5.213358e-196       1
##  [36,]  0.000000e+00 2.294673e-197       1
##  [37,]  0.000000e+00 5.844534e-197       1
##  [38,]  0.000000e+00 7.041194e-199       1
##  [39,]  0.000000e+00 4.198300e-197       1
##  [40,]  0.000000e+00 1.567566e-196       1
##  [41,]  0.000000e+00 9.738964e-197       1
##  [42,]  0.000000e+00 1.601570e-194       1
##  [43,]  0.000000e+00 1.317647e-197       1
##  [44,] 5.339862e-320 8.672173e-190       1
##  [45,] 1.971322e-321 1.089767e-191       1
##  [46,]  0.000000e+00 3.191351e-195       1
##  [47,]  0.000000e+00 3.895214e-197       1
##  [48,]  0.000000e+00 4.247520e-197       1
##  [49,]  0.000000e+00 4.292515e-197       1
##  [50,]  0.000000e+00 6.096512e-197       1
##  [51,]  3.162386e-10  1.000000e+00       0
##  [52,]  7.408459e-11  1.000000e+00       0
##  [53,]  1.630184e-09  1.000000e+00       0
##  [54,]  1.455870e-14  1.000000e+00       0
##  [55,]  5.938248e-11  1.000000e+00       0
##  [56,]  7.534428e-13  1.000000e+00       0
##  [57,]  6.759279e-10  1.000000e+00       0
##  [58,]  4.721137e-17  1.000000e+00       0
##  [59,]  1.123876e-11  1.000000e+00       0
##  [60,]  5.032714e-14  1.000000e+00       0
##  [61,]  3.483621e-17  1.000000e+00       0
##  [62,]  3.864752e-12  1.000000e+00       0
##  [63,]  2.368964e-15  1.000000e+00       0
##  [64,]  1.582454e-11  1.000000e+00       0
##  [65,]  1.620774e-14  1.000000e+00       0
##  [66,]  2.522253e-11  1.000000e+00       0
##  [67,]  9.005954e-12  1.000000e+00       0
##  [68,]  9.651077e-15  1.000000e+00       0
##  [69,]  3.567083e-12  1.000000e+00       0
##  [70,]  3.429342e-15  1.000000e+00       0
##  [71,]  6.318792e-09  1.000000e+00       0
##  [72,]  1.699889e-13  1.000000e+00       0
##  [73,]  7.664549e-11  1.000000e+00       0
##  [74,]  1.753666e-12  1.000000e+00       0
##  [75,]  1.637689e-12  1.000000e+00       0
##  [76,]  1.532429e-11  1.000000e+00       0
##  [77,]  1.142953e-10  1.000000e+00       0
##  [78,]  1.680173e-08  1.000000e+00       0
##  [79,]  1.498851e-11  1.000000e+00       0
##  [80,]  5.315372e-16  1.000000e+00       0
##  [81,]  1.462282e-15  1.000000e+00       0
##  [82,]  4.806410e-16  1.000000e+00       0
##  [83,]  1.876045e-14  1.000000e+00       0
##  [84,]  6.755505e-10  1.000000e+00       0
##  [85,]  6.527532e-12  1.000000e+00       0
##  [86,]  1.631208e-10  1.000000e+00       0
##  [87,]  3.373558e-10  1.000000e+00       0
##  [88,]  4.243118e-13  1.000000e+00       0
##  [89,]  1.637940e-13  1.000000e+00       0
##  [90,]  2.527156e-14  1.000000e+00       0
##  [91,]  7.929540e-14  1.000000e+00       0
##  [92,]  1.226918e-11  1.000000e+00       0
##  [93,]  2.211483e-14  1.000000e+00       0
##  [94,]  3.995226e-17  1.000000e+00       0
##  [95,]  1.216335e-13  1.000000e+00       0
##  [96,]  1.209214e-13  1.000000e+00       0
##  [97,]  2.390229e-13  1.000000e+00       0
##  [98,]  1.029326e-12  1.000000e+00       0
##  [99,]  5.675228e-17  1.000000e+00       0
## [100,]  1.196058e-13  1.000000e+00       0
## [101,]  1.000000e+00  5.884039e-12       0
## [102,]  9.984307e-01  1.569303e-03       0
## [103,]  1.000000e+00  9.250663e-09       0
## [104,]  9.999063e-01  9.370802e-05       0
## [105,]  1.000000e+00  1.368236e-08       0
## [106,]  1.000000e+00  1.809940e-11       0
## [107,]  1.842352e-01  8.157648e-01       0
## [108,]  9.999999e-01  7.268871e-08       0
## [109,]  9.999711e-01  2.887801e-05       0
## [110,]  1.000000e+00  1.428963e-13       0
## [111,]  9.999809e-01  1.910027e-05       0
## [112,]  9.998559e-01  1.440862e-04       0
## [113,]  9.999998e-01  2.457994e-07       0
## [114,]  9.989750e-01  1.025004e-03       0
## [115,]  9.999999e-01  9.760746e-08       0
## [116,]  1.000000e+00  2.822157e-08       0
## [117,]  9.999182e-01  8.177171e-05       0
## [118,]  1.000000e+00  1.745195e-13       0
## [119,]  1.000000e+00  8.529565e-14       0
## [120,]  3.386894e-01  6.613106e-01       0
## [121,]  1.000000e+00  7.371224e-10       0
## [122,]  9.990534e-01  9.465720e-04       0
## [123,]  1.000000e+00  6.111102e-11       0
## [124,]  9.924407e-01  7.559317e-03       0
## [125,]  1.000000e+00  4.971917e-08       0
## [126,]  9.999997e-01  3.446610e-07       0
## [127,]  9.877198e-01  1.228019e-02       0
## [128,]  9.943660e-01  5.633994e-03       0
## [129,]  9.999994e-01  6.274905e-07       0
## [130,]  9.999684e-01  3.159519e-05       0
## [131,]  9.999999e-01  5.188730e-08       0
## [132,]  1.000000e+00  2.592101e-11       0
## [133,]  9.999999e-01  9.511921e-08       0
## [134,]  8.947754e-01  1.052246e-01       0
## [135,]  9.467149e-01  5.328512e-02       0
## [136,]  1.000000e+00  7.428608e-12       0
## [137,]  1.000000e+00  4.998296e-10       0
## [138,]  9.999174e-01  8.264487e-05       0
## [139,]  9.884261e-01  1.157392e-02       0
## [140,]  9.999997e-01  2.612249e-07       0
## [141,]  1.000000e+00  3.582950e-10       0
## [142,]  1.000000e+00  2.864690e-08       0
## [143,]  9.984307e-01  1.569303e-03       0
## [144,]  1.000000e+00  2.868158e-10       0
## [145,]  1.000000e+00  1.380605e-11       0
## [146,]  1.000000e+00  3.648650e-08       0
## [147,]  9.984100e-01  1.589959e-03       0
## [148,]  9.999824e-01  1.758054e-05       0
## [149,]  1.000000e+00  1.643127e-08       0
## [150,]  9.970555e-01  2.944548e-03       0
plot(K, type="probs-class") # Histogram of the probabilities of misclassification per cluster
## Warning: Use of `tmp$probs` is discouraged.
## ℹ Use `probs` instead.

plot(K) # Shows the discriminative power of each variable