Keras on the MNIST dataset

The code below uses the keras and psycholatefunctions packages. Most of the code is adapted from Chollet and Allaire's (2018) book: Deep Learning with R.

library(keras)
options(width=1000)
options(digits=3)
options(scipen=999)
par(mai=c(.001,.001,.001,.001),mar=c(.001,.001,.001,.001),oma=c(.001,.001,.001,.001))

Obtaining and preparing the data

The dataset is obtained with the keras::dataset_mnist() function. In general, we usually have to
- Split the dataset into train and test sets.
- Scale the data (values should lie in the -1 to 1 or 0 to 1 range).
- Scale the test set using the mean and standard deviation of the training set (see the sketch below).
In this case, much of the aforementioned work is already done: in the mnist dataset, the data come pre-split into train and test inputs and outputs. All that is done below is to rename some variables to make our life a bit easier. The dataset consists of 60000 training images and 10000 test images together with their corresponding labels.
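
MNIST itself does not need this last step, but as a minimal sketch of the third point, for hypothetical numeric matrices x_train and x_test the scaling would look like this:

mu<-mean(x_train)           # mean computed on the training data only
sdev<-sd(x_train)           # standard deviation computed on the training data only
x_train<-(x_train-mu)/sdev  # scale the training set
x_test<-(x_test-mu)/sdev    # scale the test set with the training mean and sd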

mnist<-dataset_mnist()
tri<-train_images<-mnist$train$x
train_labels<-mnist$train$y
tei<-test_images<-mnist$test$x
test_labels<-mnist$test$y
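
As a quick sanity check, the array dimensions confirm the number and size of the images (the values in the comments are the expected output):

dim(tri)  # 60000 28 28: 60000 training images of 28 by 28 pixels
dim(tei)  # 10000 28 28: 10000 test images of 28 by 28 pixels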

What the data look like

Each image is a 28*28 pixel grayscale representation of a handwritten digit, and the corresponding digit values are stored in the label datasets. Let's look at the first image in the test dataset:

tei[1,,] # this is how the image looks as a matrix representation
##       [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10] [,11] [,12] [,13] [,14] [,15] [,16] [,17] [,18] [,19] [,20] [,21] [,22] [,23] [,24] [,25] [,26] [,27] [,28]
##  [1,]    0    0    0    0    0    0    0    0    0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0
##  [2,]    0    0    0    0    0    0    0    0    0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0
##  [3,]    0    0    0    0    0    0    0    0    0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0
##  [4,]    0    0    0    0    0    0    0    0    0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0
##  [5,]    0    0    0    0    0    0    0    0    0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0
##  [6,]    0    0    0    0    0    0    0    0    0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0
##  [7,]    0    0    0    0    0    0    0    0    0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0
##  [8,]    0    0    0    0    0    0   84  185  159   151    60    36     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0
##  [9,]    0    0    0    0    0    0  222  254  254   254   254   241   198   198   198   198   198   198   198   198   170    52     0     0     0     0     0     0
## [10,]    0    0    0    0    0    0   67  114   72   114   163   227   254   225   254   254   254   250   229   254   254   140     0     0     0     0     0     0
## [11,]    0    0    0    0    0    0    0    0    0     0     0    17    66    14    67    67    67    59    21   236   254   106     0     0     0     0     0     0
## [12,]    0    0    0    0    0    0    0    0    0     0     0     0     0     0     0     0     0     0    83   253   209    18     0     0     0     0     0     0
## [13,]    0    0    0    0    0    0    0    0    0     0     0     0     0     0     0     0     0    22   233   255    83     0     0     0     0     0     0     0
## [14,]    0    0    0    0    0    0    0    0    0     0     0     0     0     0     0     0     0   129   254   238    44     0     0     0     0     0     0     0
## [15,]    0    0    0    0    0    0    0    0    0     0     0     0     0     0     0     0    59   249   254    62     0     0     0     0     0     0     0     0
## [16,]    0    0    0    0    0    0    0    0    0     0     0     0     0     0     0     0   133   254   187     5     0     0     0     0     0     0     0     0
## [17,]    0    0    0    0    0    0    0    0    0     0     0     0     0     0     0     9   205   248    58     0     0     0     0     0     0     0     0     0
## [18,]    0    0    0    0    0    0    0    0    0     0     0     0     0     0     0   126   254   182     0     0     0     0     0     0     0     0     0     0
## [19,]    0    0    0    0    0    0    0    0    0     0     0     0     0     0    75   251   240    57     0     0     0     0     0     0     0     0     0     0
## [20,]    0    0    0    0    0    0    0    0    0     0     0     0     0    19   221   254   166     0     0     0     0     0     0     0     0     0     0     0
## [21,]    0    0    0    0    0    0    0    0    0     0     0     0     3   203   254   219    35     0     0     0     0     0     0     0     0     0     0     0
## [22,]    0    0    0    0    0    0    0    0    0     0     0     0    38   254   254    77     0     0     0     0     0     0     0     0     0     0     0     0
## [23,]    0    0    0    0    0    0    0    0    0     0     0    31   224   254   115     1     0     0     0     0     0     0     0     0     0     0     0     0
## [24,]    0    0    0    0    0    0    0    0    0     0     0   133   254   254    52     0     0     0     0     0     0     0     0     0     0     0     0     0
## [25,]    0    0    0    0    0    0    0    0    0     0    61   242   254   254    52     0     0     0     0     0     0     0     0     0     0     0     0     0
## [26,]    0    0    0    0    0    0    0    0    0     0   121   254   254   219    40     0     0     0     0     0     0     0     0     0     0     0     0     0
## [27,]    0    0    0    0    0    0    0    0    0     0   121   254   207    18     0     0     0     0     0     0     0     0     0     0     0     0     0     0
## [28,]    0    0    0    0    0    0    0    0    0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0
plot(as.raster(tei[1,,],max=255)) # this is how the image looks as a raster representation

test_labels[1] # and this is its corresponding value
## [1] 7
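
To get a broader feel for the data, several test images can be plotted at once; a minimal sketch using the raw tei array defined above:

par(mfrow=c(3,3),mar=c(.5,.5,.5,.5))  # 3 by 3 grid of plots with small margins
for(i in 1:9)
  plot(as.raster(tei[i,,],max=255))   # raster plot of the i-th test image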

Data Wrangling

Each image is a 28 by 28 matrix with values ranging from 0 to 255, so the input layer of the network has to have 28*28=784 units. The images are reshaped into this flat format and scaled by dividing by 255, which brings the values into the 0 - 1 range. The labels are also converted to one-hot (dummy) vectors with to_categorical().

train_images<-array_reshape(train_images,c(60000,28*28))
train_images<-train_images/255
test_images<-array_reshape(test_images,c(10000,28*28))
test_images<-test_images/255
train_labels<-to_categorical(train_labels)
test_labels<-to_categorical(test_labels)
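
After reshaping and one-hot encoding, the shapes can be checked as a quick sanity check (the values in the comments are the expected output):

dim(train_images)  # 60000 784: one row of 784 pixel values per training image
dim(train_labels)  # 60000 10: one dummy-coded (one-hot) row per training label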

Network specification

The network specified for our task is a simple sequential network with an input layer, two hidden layers, and an output layer of 10 units. Each unit in each layer connects with all units in its neighboring layers (the layers are dense, or fully connected). Each unit in the output layer represents one of the 10 digits [0 1 2 3 4 5 6 7 8 9]. The network is compiled with the rmsprop (root mean square propagation) optimizer, categorical cross-entropy as the loss function, and accuracy as the evaluation metric. The choice of loss and metric partly depends on the nature of the output: in our case the output layer uses a softmax activation, so each output unit returns a probability between 0 and 1, and the unit with the highest probability gives the predicted digit.

network<-keras_model_sequential() %>%
  layer_dense(units=1000,activation="relu",input_shape=c(28*28)) %>%
  layer_dense(units=1000,activation="relu") %>%
  layer_dense(units=10,activation="softmax")
network %>% compile(optimizer="rmsprop",
                    loss="categorical_crossentropy",
                    metrics=c("accuracy"))

Network training

network %>% fit(train_images,train_labels,epochs=10,batch_size=784,verbose=0)
network %>% evaluate(test_images,test_labels)
## $loss
## [1] 0.0883
## 
## $acc
## [1] 0.985

Network evaluation

In order to see how well the network performs, we use the test dataset. After to_categorical() the test labels are in matrix form: each column represents one of the 10 digits, so the matrix has 10000 rows and 10 columns. However, after a short transformation it is possible to obtain a vector of integers from 0 to 9.

rtl<-c()  # real (observed) test labels as a vector of integers 0-9
for(i in 1:nrow(test_labels))
  rtl<-c(rtl,which(test_labels[i,]==1)-1)  # position of the 1 in each one-hot row, shifted to 0-9
summary(rtl)
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##    0.00    2.00    4.00    4.44    7.00    9.00
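
The same vector can be obtained without the loop; a vectorized alternative that gives the same result as long as test_labels is still the one-hot matrix:

rtl<-max.col(test_labels)-1  # column position of the 1 in each row, shifted to the 0-9 range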

Confusion matrix

An ideal model would produce a confusion matrix in which every observed label matches the predicted label, that is, all of the high frequencies lie on the diagonal while all off-diagonal values are 0. Here, the confusion matrix indicates around 98% prediction accuracy over the 10000 test images.

op<-data.frame(predict=predict_classes(network,test_images),observe=rtl)
psycholatefunctions::confusion_matrix_percent(op$predict,op$observe)
##                      0       1       2       3      4      5      6       7      8       9  row_sum row_percent
## 0               975.00    0.00    2.00    0.00   1.00   3.00   4.00    0.00   6.00    1.00   992.00       98.29
## 1                 1.00 1129.00    1.00    0.00   1.00   0.00   2.00    1.00   0.00    3.00  1138.00       99.21
## 2                 0.00    1.00 1013.00    1.00   4.00   0.00   0.00    5.00   2.00    0.00  1026.00       98.73
## 3                 0.00    1.00    2.00  990.00   0.00   4.00   1.00    0.00   1.00    1.00  1000.00       99.00
## 4                 1.00    0.00    1.00    0.00 970.00   1.00   5.00    0.00   3.00   11.00   992.00       97.78
## 5                 0.00    0.00    0.00    4.00   0.00 877.00   1.00    0.00   2.00    1.00   885.00       99.10
## 6                 0.00    1.00    1.00    0.00   2.00   2.00 944.00    0.00   1.00    0.00   951.00       99.26
## 7                 1.00    1.00    6.00    5.00   0.00   1.00   0.00 1013.00   2.00    2.00  1031.00       98.25
## 8                 2.00    2.00    6.00    4.00   0.00   4.00   1.00    5.00 953.00    2.00   979.00       97.34
## 9                 0.00    0.00    0.00    6.00   4.00   0.00   0.00    4.00   4.00  988.00  1006.00       98.21
## collumn_sum     980.00 1135.00 1032.00 1010.00 982.00 892.00 958.00 1028.00 974.00 1009.00 10000.00      100.00
## collumn_percent  99.49   99.47   98.16   98.02  98.78  98.32  98.54   98.54  97.84   97.92   100.00       98.52
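
If the psycholatefunctions package is not available, a plain base R cross-tabulation gives the same counts (without the row and column sums and percentages):

table(predicted=op$predict,observed=op$observe)  # counts of predicted versus observed digits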

Scatterplot

Another descriptive representation can be obtained with a scatterplot. In that case, all observations (observed and predicted pairs) should fall on the diagonal line where predicted equals observed.

psycholatefunctions::plot_scatterplot(op)
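
A rough base R equivalent, assuming plot_scatterplot simply plots predicted against observed values; jitter is added because the labels are discrete:

plot(jitter(op$observe),jitter(op$predict),xlab="observed",ylab="predicted",pch=20)  # jittered scatterplot
abline(0,1)  # points on this diagonal are correctly classified images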

differences<-setdiff(1:length(op$predict),which(op$predict==op$observe))  # indices of the misclassified test images

Errors

From the 10000-image test dataset, the network failed to classify 148 images correctly. It may be interesting to visualize these misclassified images.

length(differences)
## [1] 148
for(i in differences)
  plot(as.raster(tei[i,,],max=255))
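
To see at a glance what went wrong, the first few misclassified images can be plotted together with their predicted and observed labels (a minimal sketch, using the objects defined above):

par(mfrow=c(3,3),mar=c(.5,.5,2,.5))  # 3 by 3 grid, extra space at the top for the titles
for(i in differences[1:9]){
  plot(as.raster(tei[i,,],max=255))
  title(main=paste("predicted",op$predict[i],"- observed",op$observe[i]))
}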