We will work with the library neuralnet here. This is an old library that is just enough to get a bit familiar with deep learning. If you are interested in the subject, you should take a dedicated course or even several courses (e.g., the Deep Learning specialization on Coursera has 5 courses).
There are two modern frameworks for deep learning available in R. One is keras (a high-level interface to TensorFlow, developed by Google). keras is definitely worth getting familiar with. The only problem is that keras for R is not really written in R; rather, it allows you to call Python functions with R syntax. This means that in order for it to work, you will first need to install Python + Anaconda + TensorFlow + Keras for Python. Only after installing the Python bundle will you be able to work in keras, and chances are that there will be some issues with installation.
Another popular framework for deep learning, called PyTorch (PyTorch sits on top of Torch, written in C), was released for R last year as a native R implementation, i.e., it works in R without having to install Python. It is easier to install than keras for R, but there may be problems too. Anyway, this new library is worth exploring, and you can do it in your project:
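If you want to try it, here is a minimal sketch of getting started with the torch package (our illustration, not part of this lab; kept commented out like the keras line below):
# install.packages("torch")                 # one-time installation from CRAN
# library(torch)                            # on first load, offers to download the C++ backend
# x <- torch_tensor(matrix(1:4, nrow = 2))  # tensors are torch's basic data structure
# x$shape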
The aim of this lab is to acquire basic familiarity with deep learning.
By the end of this lab session, students should be able to
Plot greyscale images in R
Arrange several plots on a grid
Train artificial neural networks in R
Set training controls and hyperparameters for neural networks
Please run the R chunks one by one, look at the output, and make sure that you understand how it is produced. There will be questions that either require a short answer (type your answer right in this document) or require modifying R code (modify the code here). In either case, you can discuss your work with the lab instructor.
The original raw dataset can be found here:
It consists of 60,000 training and 10,000 test grayscale images, each labelled with one of the following classes:
0 T-shirt/top
1 Trouser
2 Pullover
3 Dress
4 Coat
5 Sandal
6 Shirt
7 Sneaker
8 Bag
9 Ankle boot
First we will load the data into R. Below are the data dimensions and variables.
library(tidyverse) # for manipulation with data
library(caret) # for machine learning, including KNN
library(neuralnet)
library(fastDummies)
# library(keras) # for artificial neural networks
library(gridExtra) # for arranging several plots
fashion_mnist_train <- read_csv("fashion-mnist_train.csv")
fashion_mnist_test <- read_csv("fashion-mnist_test.csv")
# c(train_images, train_labels) %<-% fashion_mnist$train
# c(test_images, test_labels) %<-% fashion_mnist$test
# Labels need to be stored separately as they are not a part of the data
class_names = c('T_shirt_top',
                'Trouser',
                'Pullover',
                'Dress',
                'Coat', 
                'Sandal',
                'Shirt',
                'Sneaker',
                'Bag',
                'Ankle_boot')
cat("Training data dim =", dim(fashion_mnist_train), "\n")## Training data dim = 60000 785cat("Test data dim =", dim(fashion_mnist_test), "\n")## Test data dim = 10000 785Here is a sample of how the first image represented in the raw data:
matrix(fashion_mnist_train[1, -1], ncol = 28)[10:15, 10:15]
##      [,1] [,2] [,3] [,4] [,5] [,6]
## [1,] 208  217  193  158  184  230 
## [2,] 203  230  139  90   156  229 
## [3,] 215  222  136  103  150  221 
## [4,] 210  215  195  186  193  213 
## [5,] 209  224  147  138  170  224 
## [6,] 209  233  156  100  164  233
And this is what the labels look like:
head(fashion_mnist_train$label)
## [1] 2 9 6 0 3 4
The dataset is perfectly balanced:
table(fashion_mnist_train$label)
## 
##    0    1    2    3    4    5    6    7    8    9 
## 6000 6000 6000 6000 6000 6000 6000 6000 6000 6000
fashion_mnist_train %>% select(-label) %>% slice(1) %>% as.numeric %>% matrix(ncol = 28) %>% as.data.frame()
Below is an actual image:
plot_fashion_mnist_image <- function(i, dataset = fashion_mnist_train) {
  # reshape row i into a 28 x 28 data frame of pixel intensities (0-255)
  image <- dataset %>% select(-label) %>% slice(i) %>% 
    as.numeric %>% matrix(ncol = 28) %>% t %>% as.data.frame()
  colnames(image) <- seq_len(ncol(image))
  image$y <- seq_len(nrow(image))
  # long format: one row per pixel with its (x, y) position and value
  image <- pivot_longer(image, -y, names_to = "x", values_to = "value")
  image$x <- as.integer(image$x)
  
  ggplot(image, aes(x = x, y = y, fill = value)) +
    geom_tile() +
    scale_fill_gradient(low = "white", high = "black", na.value = NA) +
    scale_y_reverse() +
    theme_minimal() +
    theme(panel.grid = element_blank(), axis.text.y=element_blank(),
          axis.text.x=element_blank(), aspect.ratio = 1, 
          legend.position = "none") +
    xlab("") +
    ylab("")
}
plot_fashion_mnist_image(1)
And here is how we can plot several images on the same plot:
1:25 %>%
  lapply(plot_fashion_mnist_image) %>%
  marrangeGrob(nrow = 5, ncol = 5)
Change the function for plotting images so that it prints the image title with the class label. Plot the first 16 images with their class labels.
plot_fashion_mnist_image <- function(i, dataset = fashion_mnist_train) {
  image <- dataset %>% select(-label) %>% slice(i) %>% 
    as.numeric %>% matrix(ncol = 28) %>% t %>% as.data.frame()
  colnames(image) <- seq_len(ncol(image))
  image$y <- seq_len(nrow(image))
  image <- pivot_longer(image, -y, names_to = "x", values_to = "value")
  image$x <- as.integer(image$x)
  ggplot(image, aes(x = x, y = y, fill = value)) +
    geom_tile() +
    scale_fill_gradient(low = "white", high = "black", na.value = NA) +
    scale_y_reverse() +
    theme_minimal() +
    theme(panel.grid = element_blank(), axis.text.y=element_blank(),
          axis.text.x=element_blank(), aspect.ratio = 1, 
          legend.position = "none") +
    xlab("") +
    ylab("") + ggtitle(class_names[dataset$label[i]+1])
}
1:16 %>%
  lapply(plot_fashion_mnist_image) %>%
  marrangeGrob(nrow = 4, ncol = 4)
Here, we will train a random forest to get some idea of how a familiar model can predict the labels.
First, we convert our numeric labels to categorical ones with the function recode or, in this case, recode_factor. Its argument is a named vector whose names are the old labels and whose entries are the new labels.
names(class_names) <- 0:9
train_labels_char <- fashion_mnist_train$label %>%
  as.character %>%
  recode_factor(!!!class_names) 
test_labels_char <- fashion_mnist_test$label %>%
  as.character %>%
  recode_factor(!!!class_names) 
head(train_labels_char)
## [1] Pullover    Ankle_boot  Shirt       T_shirt_top Dress       Coat       
## 10 Levels: T_shirt_top Trouser Pullover Dress Coat Sandal Shirt Sneaker ... Ankle_boot
The !!! is the splice operator from rlang's tidy evaluation (used by dplyr): inside functions that support it, such as recode_factor, it unpacks a named vector so that each element is passed as a separate named argument of the form old_label = new_label, exactly as if we had typed all ten pairs by hand.
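Here is a minimal illustration of the splicing (a toy example with two labels, not part of the lab pipeline):
# the following two calls are equivalent:
recode_factor(c("0", "2"), `0` = "T_shirt_top", `2` = "Pullover")
lookup <- c(`0` = "T_shirt_top", `2` = "Pullover")
recode_factor(c("0", "2"), !!!lookup)  # !!! splices the name-value pairs from lookup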
Now we will convert our data to data frame format. Below is a sample of the training data:
train_df <- fashion_mnist_train %>%
  mutate(label = train_labels_char)
test_df <- fashion_mnist_test %>%
  mutate(label = test_labels_char)
train_df %>%
  select(pixel305:pixel310, label) %>%
  sample_n(10)
Now we are ready to train familiar models. We will use a random forest since it is natively designed for multi-class classification. Since the dataset is quite large, we will skip proper hyperparameter tuning, but you can experiment with hyperparameters at home to see how well you can predict the test labels.
set.seed(199)
rfGrid <- expand.grid(mtry = c(30), 
                      min.node.size = c(40),
                      splitrule = "gini")
mod_rf <- train(label ~ . , data = train_df, method = "ranger",
                num.trees = 100,
                importance = 'impurity',
                tuneGrid = rfGrid,
                trControl = trainControl("oob"))
mod_rf
## Random Forest 
## 
## 60000 samples
##   784 predictor
##    10 classes: 'T_shirt_top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle_boot' 
## 
## No pre-processing
## Resampling results:
## 
##   Accuracy   Kappa    
##   0.8730167  0.8589074
## 
## Tuning parameter 'mtry' was held constant at a value of 30
## Tuning
##  parameter 'splitrule' was held constant at a value of gini
## Tuning
##  parameter 'min.node.size' was held constant at a value of 40
Which classes do you think are the hardest for the model to predict? Which ones will it confuse the most? Think about it and then print the test confusion matrix.
Answer: The model should have difficulty distinguishing items that look similar. It will probably have a hard time distinguishing a shirt, a t-shirt, and a pullover. Maybe also ankle boots vs. sandals with high heels.
Looking at the confusion matrix, we see that, indeed, “shirt” is the hardest label to predict, but distinguishing sandals from high-heeled boots doesn’t seem to be hard; probably, the dataset does not have many shoes with high heels.
mod_rf %>%
  predict(test_df, type = "raw") %>%
  confusionMatrix(test_df$label)
## Confusion Matrix and Statistics
## 
##              Reference
## Prediction    T_shirt_top Trouser Pullover Dress Coat Sandal Shirt Sneaker Bag
##   T_shirt_top         842       2        7    19    0      0   171       0   1
##   Trouser               0     966        1     9    1      0     1       0   1
##   Pullover             11       6      802    14   59      0    94       0   7
##   Dress                37      17       11   928   32      0    31       0   0
##   Coat                  2       1      115    19  864      0    76       0   3
##   Sandal                2       1        0     0    0    946     0      20   2
##   Shirt                94       6       53    10   41      0   609       0  10
##   Sneaker               0       0        0     0    0     37     0     925   2
##   Bag                  12       1       11     1    3      5    18       0 973
##   Ankle_boot            0       0        0     0    0     12     0      55   1
##              Reference
## Prediction    Ankle_boot
##   T_shirt_top          0
##   Trouser              0
##   Pullover             0
##   Dress                0
##   Coat                 0
##   Sandal               7
##   Shirt                1
##   Sneaker             42
##   Bag                  3
##   Ankle_boot         947
## 
## Overall Statistics
##                                           
##                Accuracy : 0.8802          
##                  95% CI : (0.8737, 0.8865)
##     No Information Rate : 0.1             
##     P-Value [Acc > NIR] : < 2.2e-16       
##                                           
##                   Kappa : 0.8669          
##                                           
##  Mcnemar's Test P-Value : NA              
## 
## Statistics by Class:
## 
##                      Class: T_shirt_top Class: Trouser Class: Pullover
## Sensitivity                      0.8420         0.9660          0.8020
## Specificity                      0.9778         0.9986          0.9788
## Pos Pred Value                   0.8081         0.9867          0.8077
## Neg Pred Value                   0.9824         0.9962          0.9780
## Prevalence                       0.1000         0.1000          0.1000
## Detection Rate                   0.0842         0.0966          0.0802
## Detection Prevalence             0.1042         0.0979          0.0993
## Balanced Accuracy                0.9099         0.9823          0.8904
##                      Class: Dress Class: Coat Class: Sandal Class: Shirt
## Sensitivity                0.9280      0.8640        0.9460       0.6090
## Specificity                0.9858      0.9760        0.9964       0.9761
## Pos Pred Value             0.8788      0.8000        0.9673       0.7391
## Neg Pred Value             0.9919      0.9848        0.9940       0.9574
## Prevalence                 0.1000      0.1000        0.1000       0.1000
## Detection Rate             0.0928      0.0864        0.0946       0.0609
## Detection Prevalence       0.1056      0.1080        0.0978       0.0824
## Balanced Accuracy          0.9569      0.9200        0.9712       0.7926
##                      Class: Sneaker Class: Bag Class: Ankle_boot
## Sensitivity                  0.9250     0.9730            0.9470
## Specificity                  0.9910     0.9940            0.9924
## Pos Pred Value               0.9195     0.9474            0.9330
## Neg Pred Value               0.9917     0.9970            0.9941
## Prevalence                   0.1000     0.1000            0.1000
## Detection Rate               0.0925     0.0973            0.0947
## Detection Prevalence         0.1006     0.1027            0.1015
## Balanced Accuracy            0.9580     0.9835            0.9697
For multiclass classification, we will need to explicitly convert our vectors of labels to matrices of dummy variables.
y_train <- dummy_cols(train_labels_char, remove_selected_columns = TRUE) %>% setNames(class_names)
y_test <- dummy_cols(test_labels_char, remove_selected_columns = TRUE) %>% setNames(class_names)
head(y_train)
Here are our datasets:
set.seed(158)
train_df <- fashion_mnist_train %>%
  select(-label) %>%
  cbind(y_train)
test_df <- fashion_mnist_test %>%
  select(-label) %>%
  cbind(y_test)
train_df %>%
  select(pixel305:pixel310, Trouser, Shirt, Bag) %>%
  sample_n(10)
Since the response variable is 10-dimensional, we need to create a special formula for prediction:
f <- class_names %>% 
  paste(collapse = " + ") %>%
  paste("~ .") %>%
  as.formula()
f
## T_shirt_top + Trouser + Pullover + Dress + Coat + Sandal + Shirt + 
##     Sneaker + Bag + Ankle_boot ~ .
## <environment: 0x11dde8538>Now we will train a model with one hidden layer with 10 units.
mod_nn <- neuralnet(f, 
                    data = train_df, hidden = c(10),
                    stepmax = 1000,
                    threshold = 50,
                    lifesign = "full",
                    lifesign.step = 100,
                    err.fct = "ce",
                    linear.output = FALSE)
## hidden: 10    thresh: 50    rep: 1/1    steps:     100   min thresh: 355.029772451359
##                                                    200   min thresh: 174.398683265989
##                                                    300   min thresh: 55.8617951596633
##                                                    400   min thresh: 55.8617951596633
##                                                    500   min thresh: 55.8617951596633
##                                                    600   min thresh: 55.8617951596633
##                                                    700   min thresh: 54.9082071779552
##                                                    714   error: 107062.88718 time: 10.07 mins
summary(mod_nn)
##                     Length   Class      Mode    
## call                      10 -none-     call    
## response              600000 -none-     numeric 
## covariate           47040000 -none-     numeric 
## model.list                 2 -none-     list    
## err.fct                    1 -none-     function
## act.fct                    1 -none-     function
## linear.output              1 -none-     logical 
## data                     794 data.frame list    
## exclude                    0 -none-     NULL    
## net.result                 1 -none-     list    
## weights                    1 -none-     list    
## generalized.weights        1 -none-     list    
## startweights               1 -none-     list    
## result.matrix           7963 -none-     numeric
How many parameters does our model have? First, find the answer based on your knowledge of the lecture material and then verify it by printing the summary of the model.
Answer: The input has \(28\times 28=784\) units, i.e., the layer dimensions are \[ n_1=784,\quad n_2=10,\quad n_3=10 \] Thus, the total number of parameters is \[ (784+1)\times 10 + (10+1)\times 10=7960 \]
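As a quick sanity check of this formula in R (the numbers are just the layer sizes above):
layer_in  <- c(784, 10)          # units feeding each weight matrix
layer_out <- c(10, 10)           # units receiving the connections
sum((layer_in + 1) * layer_out)  # the +1 accounts for the bias in each layer
## [1] 7960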
This is how we can find the dimensions of these weight matrices:
mod_nn$weights[[1]] %>% lapply(dim)## [[1]]
## [1] 785  10
## 
## [[2]]
## [1] 11 10
And the total number of parameters:
mod_nn$weights[[1]] %>% lapply(dim) %>% sapply(prod) %>% sum## [1] 7960This is what raw predictions look like:
mod_nn %>% predict(fashion_mnist_test) %>% head##             [,1]         [,2]        [,3]        [,4]        [,5]         [,6]
## [1,] 0.640174955 0.0658487596 0.010963269 0.087205895 0.008027419 0.0004322397
## [2,] 0.040296847 0.8531541214 0.018011822 0.019602809 0.069062228 0.0001437993
## [3,] 0.004420297 0.0087816599 0.375128071 0.007880877 0.326293704 0.0009430098
## [4,] 0.004420297 0.0087816599 0.375128071 0.007880877 0.326293704 0.0009430098
## [5,] 0.003400619 0.2932075764 0.008434255 0.357872946 0.280149888 0.0009832037
## [6,] 0.158338363 0.0001074794 0.266215500 0.036566973 0.050180557 0.0028300134
##            [,7]         [,8]        [,9]        [,10]
## [1,] 0.15011072 2.198296e-34 0.008089964 6.418063e-05
## [2,] 0.06559408 2.365127e-33 0.003558540 1.420052e-04
## [3,] 0.21954092 3.936055e-33 0.033918875 7.309403e-04
## [4,] 0.21954092 3.936055e-33 0.033918875 7.309403e-04
## [5,] 0.02968224 3.086418e-34 0.007676339 2.964850e-05
## [6,] 0.41443740 3.658414e-34 0.074230896 3.304623e-04
Raw predictions are probabilities rather than labels. To identify the predicted labels, we will construct a special function. It first finds the position of the largest predicted probability and then looks up the corresponding label. Below is the test confusion matrix of our model.
predict_nn <- function(nn, dataset = fashion_mnist_test, label_vector = class_names) {
  predicted_positions <- nn %>%
    predict(dataset) %>%
    apply(1, which.max)
  as.factor(label_vector[predicted_positions])
}
mod_nn %>% predict_nn(test_df) %>%
  confusionMatrix(test_labels_char)
## Warning in confusionMatrix.default(., test_labels_char): Levels are not in the
## same order for reference and data. Refactoring data to match.
## Confusion Matrix and Statistics
## 
##              Reference
## Prediction    T_shirt_top Trouser Pullover Dress Coat Sandal Shirt Sneaker Bag
##   T_shirt_top         745      26       14    69   11      5   213       0   3
##   Trouser               9     835       13    29   53      2    12       0   1
##   Pullover             20      21      841    19  725      0   522       0  82
##   Dress               125     112       11   852  173      0   104       0  37
##   Coat                  0       0       10     2    8      0     7       0   2
##   Sandal                0       0        0     0    1    291     0      28   5
##   Shirt                50       4       94    24   16      2   100       0  45
##   Sneaker               0       1        0     0    0    452     0     809  15
##   Bag                  49       1       14     4    3     94    31       5 429
##   Ankle_boot            2       0        3     1   10    154    11     158 381
##              Reference
## Prediction    Ankle_boot
##   T_shirt_top          3
##   Trouser              1
##   Pullover             2
##   Dress                0
##   Coat                 0
##   Sandal               1
##   Shirt                0
##   Sneaker            108
##   Bag                 57
##   Ankle_boot         828
## 
## Overall Statistics
##                                          
##                Accuracy : 0.5738         
##                  95% CI : (0.564, 0.5835)
##     No Information Rate : 0.1            
##     P-Value [Acc > NIR] : < 2.2e-16      
##                                          
##                   Kappa : 0.5264         
##                                          
##  Mcnemar's Test P-Value : NA             
## 
## Statistics by Class:
## 
##                      Class: T_shirt_top Class: Trouser Class: Pullover
## Sensitivity                      0.7450         0.8350          0.8410
## Specificity                      0.9618         0.9867          0.8454
## Pos Pred Value                   0.6841         0.8743          0.3768
## Neg Pred Value                   0.9714         0.9818          0.9795
## Prevalence                       0.1000         0.1000          0.1000
## Detection Rate                   0.0745         0.0835          0.0841
## Detection Prevalence             0.1089         0.0955          0.2232
## Balanced Accuracy                0.8534         0.9108          0.8432
##                      Class: Dress Class: Coat Class: Sandal Class: Shirt
## Sensitivity                0.8520      0.0080        0.2910       0.1000
## Specificity                0.9376      0.9977        0.9961       0.9739
## Pos Pred Value             0.6025      0.2759        0.8926       0.2985
## Neg Pred Value             0.9828      0.9005        0.9267       0.9069
## Prevalence                 0.1000      0.1000        0.1000       0.1000
## Detection Rate             0.0852      0.0008        0.0291       0.0100
## Detection Prevalence       0.1414      0.0029        0.0326       0.0335
## Balanced Accuracy          0.8948      0.5028        0.6436       0.5369
##                      Class: Sneaker Class: Bag Class: Ankle_boot
## Sensitivity                  0.8090     0.4290            0.8280
## Specificity                  0.9360     0.9713            0.9200
## Pos Pred Value               0.5841     0.6245            0.5349
## Neg Pred Value               0.9778     0.9387            0.9796
## Prevalence                   0.1000     0.1000            0.1000
## Detection Rate               0.0809     0.0429            0.0828
## Detection Prevalence         0.1385     0.0687            0.1548
## Balanced Accuracy            0.8725     0.7002            0.8740
As we see, the model’s performance is not exactly spectacular. It is actually possible to create a deep learning model that outperforms the random forest, but it has to be a convolutional neural network (a special architecture) trained in a modern deep learning framework (such as keras or torch), and it will require a lot of tuning. This goes beyond our course, especially since finding the right combination of hyperparameters for a deep learning model requires a lot of time and effort. In the next section, you will see the tip of the iceberg of tuning a deep learning model.
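For reference only, a convolutional architecture for 28 x 28 grayscale images might look roughly like this in keras (a sketch that assumes a working keras installation; the layer sizes are illustrative, not tuned, and the code is kept commented out):
# library(keras)
# model <- keras_model_sequential() %>%
#   layer_conv_2d(filters = 32, kernel_size = c(3, 3), activation = "relu",
#                 input_shape = c(28, 28, 1)) %>%  # convolutions scan local pixel patches
#   layer_max_pooling_2d(pool_size = c(2, 2)) %>%  # downsample the feature maps
#   layer_flatten() %>%
#   layer_dense(units = 128, activation = "relu") %>%
#   layer_dense(units = 10, activation = "softmax")  # one probability per class
# model %>% compile(optimizer = "adam",
#                   loss = "categorical_crossentropy",
#                   metrics = "accuracy")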
Artificial neural networks usually have a large number of parameters (the weights and biases defining the model layers) and a large number of hyperparameters that control the training process. One can choose the number of layers and layer dimensions, the learning rate, the optimization algorithm, the activation functions in hidden layers, etc. You can read the manual for the function neuralnet to find the list of available controls.
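For example, a call exercising several of these controls could look like this (a sketch with illustrative, untuned values; see ?neuralnet for the full list of arguments):
# neuralnet(f, data = train_df,
#           hidden = c(20, 10),    # two hidden layers: 20 and 10 units
#           act.fct = "tanh",      # hidden-layer activation: "logistic" or "tanh"
#           algorithm = "rprop-",  # also "backprop", "rprop+", "sag", "slr"
#           learningrate = 0.01,   # used only by algorithm = "backprop"
#           stepmax = 5000,
#           threshold = 20)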
Below we train a new model with two hidden layers (10 units and 5 units, respectively), using the algorithm called “slr” rather than “rprop+”. If you are super-patient, you can also try a smaller value of threshold and a larger value of stepmax.
set.seed(158)
mod_nn_2 <- neuralnet(f,
                    data = train_df, hidden = c(10, 5),
                    stepmax = 1000,
                    threshold = 50,
                    lifesign = "full",
                    lifesign.step = 10,
                    err.fct = "ce",
                    algorithm = "slr",
                    linear.output = FALSE)
## hidden: 10, 5    thresh: 50    rep: 1/1    steps:      10    min thresh: 2196.46647194059
##                                                        20    min thresh: 1763.73885525966
##                                                        30    min thresh: 1305.29781962469
##                                                        40    min thresh: 444.983399972875
##                                                        50    min thresh: 381.514121566101
##                                                        60    min thresh: 308.199536800102
##                                                        70    min thresh: 260.695932471294
##                                                        80    min thresh: 260.695932471294
##                                                        90    min thresh: 260.695932471294
##                                                       100    min thresh: 260.695932471294
##                                                       110    min thresh: 222.386316954971
##                                                       120    min thresh: 222.386316954971
##                                                       130    min thresh: 222.386316954971
##                                                       140    min thresh: 222.386316954971
##                                                       150    min thresh: 222.386316954971
##                                                       160    min thresh: 222.386316954971
##                                                       170    min thresh: 173.733595514207
##                                                       180    min thresh: 173.733595514207
##                                                       190    min thresh: 173.733595514207
##                                                       200    min thresh: 173.733595514207
##                                                       210    min thresh: 173.733595514207
##                                                       220    min thresh: 173.733595514207
##                                                       230    min thresh: 173.733595514207
##                                                       240    min thresh: 173.733595514207
##                                                       250    min thresh: 169.459673421583
##                                                       260    min thresh: 159.75651651691
##                                                       270    min thresh: 159.75651651691
##                                                       280    min thresh: 111.642730676291
##                                                       290    min thresh: 111.642730676291
##                                                       300    min thresh: 111.642730676291
##                                                       310    min thresh: 111.642730676291
##                                                       320    min thresh: 111.642730676291
##                                                       330    min thresh: 111.642730676291
##                                                       340    min thresh: 111.642730676291
##                                                       350    min thresh: 111.642730676291
##                                                       360    min thresh: 96.9121750173971
##                                                       370    min thresh: 60.2778975343813
##                                                       380    min thresh: 60.2778975343813
##                                                       390    min thresh: 60.2778975343813
##                                                       400    min thresh: 60.2778975343813
##                                                       410    min thresh: 60.2778975343813
##                                                       420    min thresh: 60.2778975343813
##                                                       430    min thresh: 60.2778975343813
##                                                       440    min thresh: 60.2778975343813
##                                                       450    min thresh: 60.2778975343813
##                                                       460    min thresh: 60.2778975343813
##                                                       470    min thresh: 60.2778975343813
##                                                       480    min thresh: 60.2778975343813
##                                                       490    min thresh: 60.2778975343813
##                                                       500    min thresh: 60.2778975343813
##                                                       510    min thresh: 60.2778975343813
##                                                       520    min thresh: 60.2778975343813
##                                                       530    min thresh: 60.2778975343813
##                                                       540    min thresh: 60.2778975343813
##                                                       550    min thresh: 60.2778975343813
##                                                       560    min thresh: 60.2778975343813
##                                                       570    min thresh: 60.2778975343813
##                                                       580    min thresh: 60.2778975343813
##                                                       590    min thresh: 60.2778975343813
##                                                       600    min thresh: 60.2778975343813
##                                                       610    error: 108487.54983 time: 8.83 mins
summary(mod_nn_2)
##                     Length   Class      Mode    
## call                      11 -none-     call    
## response              600000 -none-     numeric 
## covariate           47040000 -none-     numeric 
## model.list                 2 -none-     list    
## err.fct                    1 -none-     function
## act.fct                    1 -none-     function
## linear.output              1 -none-     logical 
## data                     794 data.frame list    
## exclude                    0 -none-     NULL    
## net.result                 1 -none-     list    
## weights                    1 -none-     list    
## generalized.weights        1 -none-     list    
## startweights               1 -none-     list    
## result.matrix           7968 -none-     numeric
And here is the confusion matrix of the new model:
mod_nn_2 %>% predict_nn(test_df) %>%
  confusionMatrix(test_labels_char)
## Warning in confusionMatrix.default(., test_labels_char): Levels are not in the
## same order for reference and data. Refactoring data to match.
## Confusion Matrix and Statistics
## 
##              Reference
## Prediction    T_shirt_top Trouser Pullover Dress Coat Sandal Shirt Sneaker Bag
##   T_shirt_top           0       0        0     0    0      0     1       0   0
##   Trouser              88     932       14   787  101      0    48       0   1
##   Pullover            789      29      903    58  721      1   774       0  22
##   Dress                13       6        6    42   13      4    12       0   5
##   Coat                 73      29       34    95  136      1    99       0   1
##   Sandal                3       1        1     0    1    571     1      48  22
##   Shirt                 3       0        0     0    0      1     2       0   1
##   Sneaker               0       1        0    13    1    104     0     626   4
##   Bag                  30       2       41     5   27     14    63       2 937
##   Ankle_boot            1       0        1     0    0    304     0     324   7
##              Reference
## Prediction    Ankle_boot
##   T_shirt_top          1
##   Trouser              0
##   Pullover             0
##   Dress                0
##   Coat                 0
##   Sandal              46
##   Shirt                0
##   Sneaker             13
##   Bag                  2
##   Ankle_boot         938
## 
## Overall Statistics
##                                           
##                Accuracy : 0.5087          
##                  95% CI : (0.4989, 0.5185)
##     No Information Rate : 0.1             
##     P-Value [Acc > NIR] : < 2.2e-16       
##                                           
##                   Kappa : 0.4541          
##                                           
##  Mcnemar's Test P-Value : NA              
## 
## Statistics by Class:
## 
##                      Class: T_shirt_top Class: Trouser Class: Pullover
## Sensitivity                      0.0000         0.9320          0.9030
## Specificity                      0.9998         0.8846          0.7340
## Pos Pred Value                   0.0000         0.4729          0.2739
## Neg Pred Value                   0.9000         0.9915          0.9855
## Prevalence                       0.1000         0.1000          0.1000
## Detection Rate                   0.0000         0.0932          0.0903
## Detection Prevalence             0.0002         0.1971          0.3297
## Balanced Accuracy                0.4999         0.9083          0.8185
##                      Class: Dress Class: Coat Class: Sandal Class: Shirt
## Sensitivity                0.0420      0.1360        0.5710       0.0020
## Specificity                0.9934      0.9631        0.9863       0.9994
## Pos Pred Value             0.4158      0.2906        0.8228       0.2857
## Neg Pred Value             0.9032      0.9094        0.9539       0.9001
## Prevalence                 0.1000      0.1000        0.1000       0.1000
## Detection Rate             0.0042      0.0136        0.0571       0.0002
## Detection Prevalence       0.0101      0.0468        0.0694       0.0007
## Balanced Accuracy          0.5177      0.5496        0.7787       0.5007
##                      Class: Sneaker Class: Bag Class: Ankle_boot
## Sensitivity                  0.6260     0.9370            0.9380
## Specificity                  0.9849     0.9793            0.9292
## Pos Pred Value               0.8215     0.8344            0.5956
## Neg Pred Value               0.9595     0.9929            0.9926
## Prevalence                   0.1000     0.1000            0.1000
## Detection Rate               0.0626     0.0937            0.0938
## Detection Prevalence         0.0762     0.1123            0.1575
## Balanced Accuracy            0.8054     0.9582            0.9336
Find the test image for which the model is least confident in its prediction. Was it correct in predicting that image’s class?
conf_level <- mod_nn_2 %>% predict(test_df) %>% apply(1, max)
ind_of_least_confidence <- which.min(conf_level)
cat("The model is least confident in predicting test image",
      ind_of_least_confidence, "\n")
## The model is least confident in predicting test image 1977
cat("Below are class probabilities for test image", ind_of_least_confidence, ":\n")
## Below are class probabilities for test image 1977 :
mod_nn_2 %>% predict(test_df[ind_of_least_confidence , ]) %>% round(3)
##       [,1]  [,2]  [,3]  [,4]  [,5] [,6]  [,7] [,8]  [,9] [,10]
## 1977 0.018 0.002 0.005 0.011 0.004    0 0.004    0 0.001 0.018
cat("Predicted class is", as.character(predict_nn(mod_nn_2, test_df[ind_of_least_confidence , ])),
    "\n")
## Predicted class is T_shirt_top
cat("The actual class is", as.character(test_labels_char[ind_of_least_confidence]), "\n")
## The actual class is Ankle_boot
And here is the image the model is least confident about:
plot_fashion_mnist_image(ind_of_least_confidence, fashion_mnist_test)
If our deep learning model is too flexible, i.e., if it has too many parameters, it may be prone to overfitting (which clearly doesn’t happen here). One possible way to prevent overfitting is not to have a super-flexible model, but then the model may underfit. A common method to construct flexible models that neither overfit nor underfit is regularization.
We are not going to learn about regularization in this lab (because, again, it takes a lot of time and effort). Below is just some basic information for you.
The most common techniques for regularizing neural networks are \(l_1/l_2\) regularization (similar to ridge / LASSO) and dropout. Both techniques are applied at a particular layer.
Let’s say we introduce \(l_1\)-regularization in the first hidden layer with coefficient \(0.005\). It means that the loss function will get the extra term \[ +0.005\sum_{i=1}^{n_2}\sum_{j=1}^{n_1}|w^{(1)}_{i,j}| \]
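For instance, this penalty can be computed from the fitted weights of our first model (an illustration only; to our knowledge, neuralnet itself does not offer built-in regularization):
W1 <- mod_nn$weights[[1]][[1]]  # 785 x 10; the first row holds the biases (intercepts)
0.005 * sum(abs(W1[-1, ]))      # l1 penalty over the weights, biases excluded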
Let’s say that we introduce dropout regularization in the second hidden layer with \(p=0.05\). It means that when calculating the model’s output, the output of every unit in the second hidden layer is replaced with 0 with probability \(0.05\). This prevents the model’s output from depending too heavily on any particular unit.
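A minimal simulation of what dropout does to one layer’s outputs (the unit outputs here are made up for the illustration):
set.seed(1)
h2 <- runif(5)                                     # hypothetical outputs of 5 hidden units
mask <- rbinom(length(h2), size = 1, prob = 0.95)  # keep each unit with probability 1 - p
h2 * mask                                          # each output zeroed with probability 0.05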
The point of this section is to see that regularizing a deep learning model takes a lot of effort: we need to choose which layers to regularize, the type of regularization, and the regularization constants (these are also hyperparameters of the model).