The code below works with the keras and psycholate libraries. Most of the code was taken from Chollet and Allaire's (2018) book, Deep Learning with R.
library(tensorflow)
library(keras)
## The keras package is deprecated. Please use the keras3 package instead.
## Alternatively, to continue using legacy keras, call `py_require_legacy_keras()`.
library(ggplot2)
options(width=1000)
options(digits=3)
options(scipen=999)
par(mai=c(.001,.001,.001,.001),mar=c(.001,.001,.001,.001),oma=c(.001,.001,.001,.001))
The dataset is obtained with the keras::dataset_mnist() function. Usually
we would then have to:
- Split the dataset into train and test datasets.
- Scale the dataset (values should range from -1 to 1 or from 0 to 1).
- Scale the test dataset using the mean and sd of the training dataset.
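The last point is easy to get wrong, so here is a minimal base R sketch of it. The data are made up for illustration, and the sketch standardises to mean 0 and sd 1 rather than forcing a fixed range; the variable names are our own, not part of keras.

```r
# Toy data standing in for a real dataset (values are made up).
set.seed(1)
x <- matrix(rnorm(200, mean = 50, sd = 10), ncol = 2)

# Split: 80 rows for training, the remaining 20 for testing.
train_idx <- sample(nrow(x), 80)
x_train <- x[train_idx, ]
x_test  <- x[-train_idx, ]

# Column means and standard deviations come from the training data only.
train_mean <- colMeans(x_train)
train_sd   <- apply(x_train, 2, sd)

# The same centring and scaling is applied to both sets.
x_train_scaled <- scale(x_train, center = train_mean, scale = train_sd)
x_test_scaled  <- scale(x_test,  center = train_mean, scale = train_sd)

round(colMeans(x_train_scaled), 10)  # training columns are centred at 0
colMeans(x_test_scaled)              # test columns are close to 0, not exactly 0
```

The test set is deliberately not centred exactly at 0: using its own mean and sd would leak information about the test data into the preprocessing.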
In this case, much of the aforementioned work is already done. In the
mnist dataset, the data are already split into train and test, input and
output. What is done below is simply to rename some variables to make our
life a bit easier. The dataset consists of 60000 train images and 10000
test images, together with their corresponding labels.
mnist<-dataset_mnist()
train_images<-mnist$train$x
train_labels<-mnist$train$y
test_images<-mnist$test$x
test_labels<-mnist$test$y
Each image is a 28*28 pixel grayscale picture of a handwritten digit, and the digit it shows is stored in the corresponding labels dataset. Let's see the first image in the test dataset:
test_images[1,,] # this is how the image looks as a matrix representation
##       [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10] [,11] [,12] [,13] [,14] [,15] [,16] [,17] [,18] [,19] [,20] [,21] [,22] [,23] [,24] [,25] [,26] [,27] [,28]
##  [1,]    0    0    0    0    0    0    0    0    0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0
##  [2,]    0    0    0    0    0    0    0    0    0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0
##  [3,]    0    0    0    0    0    0    0    0    0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0
##  [4,]    0    0    0    0    0    0    0    0    0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0
##  [5,]    0    0    0    0    0    0    0    0    0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0
##  [6,]    0    0    0    0    0    0    0    0    0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0
##  [7,]    0    0    0    0    0    0    0    0    0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0
##  [8,]    0    0    0    0    0    0   84  185  159   151    60    36     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0
##  [9,]    0    0    0    0    0    0  222  254  254   254   254   241   198   198   198   198   198   198   198   198   170    52     0     0     0     0     0     0
## [10,]    0    0    0    0    0    0   67  114   72   114   163   227   254   225   254   254   254   250   229   254   254   140     0     0     0     0     0     0
## [11,]    0    0    0    0    0    0    0    0    0     0     0    17    66    14    67    67    67    59    21   236   254   106     0     0     0     0     0     0
## [12,]    0    0    0    0    0    0    0    0    0     0     0     0     0     0     0     0     0     0    83   253   209    18     0     0     0     0     0     0
## [13,]    0    0    0    0    0    0    0    0    0     0     0     0     0     0     0     0     0    22   233   255    83     0     0     0     0     0     0     0
## [14,]    0    0    0    0    0    0    0    0    0     0     0     0     0     0     0     0     0   129   254   238    44     0     0     0     0     0     0     0
## [15,]    0    0    0    0    0    0    0    0    0     0     0     0     0     0     0     0    59   249   254    62     0     0     0     0     0     0     0     0
## [16,]    0    0    0    0    0    0    0    0    0     0     0     0     0     0     0     0   133   254   187     5     0     0     0     0     0     0     0     0
## [17,]    0    0    0    0    0    0    0    0    0     0     0     0     0     0     0     9   205   248    58     0     0     0     0     0     0     0     0     0
## [18,]    0    0    0    0    0    0    0    0    0     0     0     0     0     0     0   126   254   182     0     0     0     0     0     0     0     0     0     0
## [19,]    0    0    0    0    0    0    0    0    0     0     0     0     0     0    75   251   240    57     0     0     0     0     0     0     0     0     0     0
## [20,]    0    0    0    0    0    0    0    0    0     0     0     0     0    19   221   254   166     0     0     0     0     0     0     0     0     0     0     0
## [21,]    0    0    0    0    0    0    0    0    0     0     0     0     3   203   254   219    35     0     0     0     0     0     0     0     0     0     0     0
## [22,]    0    0    0    0    0    0    0    0    0     0     0     0    38   254   254    77     0     0     0     0     0     0     0     0     0     0     0     0
## [23,]    0    0    0    0    0    0    0    0    0     0     0    31   224   254   115     1     0     0     0     0     0     0     0     0     0     0     0     0
## [24,]    0    0    0    0    0    0    0    0    0     0     0   133   254   254    52     0     0     0     0     0     0     0     0     0     0     0     0     0
## [25,]    0    0    0    0    0    0    0    0    0     0    61   242   254   254    52     0     0     0     0     0     0     0     0     0     0     0     0     0
## [26,]    0    0    0    0    0    0    0    0    0     0   121   254   254   219    40     0     0     0     0     0     0     0     0     0     0     0     0     0
## [27,]    0    0    0    0    0    0    0    0    0     0   121   254   207    18     0     0     0     0     0     0     0     0     0     0     0     0     0     0
## [28,]    0    0    0    0    0    0    0    0    0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0
plot(as.raster(test_images[1,,],max=255)) # this is how the image looks as a raster representation
test_labels[1] # and this is its corresponding value
## [1] 7
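Each image is a 28 by 28 matrix with values ranging from 0 to 255, so the input of the network has to be 28*28=784 units. The inputs therefore have to be reshaped into a format the network input layer accepts: one row of 784 values per image. The input data should also be scaled by dividing by 255, which results in a 0 to 1 range. Finally, the labels are one-hot encoded with to_categorical(), so that each label becomes a row of 10 values with a single 1.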
train_images<-array_reshape(train_images,c(60000,28*28))
train_images<-train_images/255
test_images<-array_reshape(test_images,c(10000,28*28))
test_images<-test_images/255
train_labels<-to_categorical(train_labels)
test_labels<-to_categorical(test_labels)
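To see what these two transformations do without loading keras, here is a small base R sketch on made-up data. It assumes that array_reshape() fills in row-major order (its default, as far as we know), which is why each image is transposed before being flattened; the toy arrays and the `diag()` trick are our own illustration, not what keras runs internally.

```r
# Two made-up 2x3 "images" stored as a 2 x 2 x 3 array (image, row, column).
imgs <- array(0, dim = c(2, 2, 3))
imgs[1, , ] <- matrix(1:6,  nrow = 2, byrow = TRUE)  # rows: 1 2 3 / 4 5 6
imgs[2, , ] <- matrix(7:12, nrow = 2, byrow = TRUE)  # rows: 7 8 9 / 10 11 12

# Row-major flattening: one image per row, pixels read row by row.
# t(m) makes as.vector() walk along the rows of m instead of its columns.
flat <- t(apply(imgs, 1, function(m) as.vector(t(m))))
flat  # a 2 x 6 matrix: first row is 1..6, second row is 7..12

# One-hot encoding of digit labels: row k of the identity matrix has
# its single 1 in column k, so digit d maps to row d + 1.
labels  <- c(7, 2, 0)
one_hot <- diag(10)[labels + 1, ]
colnames(one_hot) <- 0:9
one_hot
```

In the encoded matrix the first column stands for digit 0 and the last for digit 9, which matters later when the predictions are converted back to digits.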
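The network specified for our task is a simple sequential network: the 784 inputs feed three hidden layers of 100, 1000 and 1000 units with ReLU activations, followed by an output layer of 10 units with a softmax activation. The layers are dense, that is, each unit connects to every unit in the neighboring layers. Each unit in the output layer represents one of the 10 digits, ordered from 0 to 9 [0 1 2 3 4 5 6 7 8 9], which is the column order produced by to_categorical(). The network is trained with the RMSprop optimizer, using categorical cross entropy as the loss function and accuracy as the evaluation metric. The choice of loss and metric partly depends on the nature of the output. In our case the targets are one-hot vectors of 0s and 1s, and the softmax output units return values between 0 and 1 that sum to 1 and can be read as class probabilities.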
network<-keras_model_sequential() %>%
  layer_dense(units=100,activation="relu",input_shape=c(28*28)) %>%
  layer_dense(units=1000,activation="relu") %>%
  layer_dense(units=1000,activation="relu") %>%
  layer_dense(units=10,activation="softmax")
network %>% compile(optimizer="rmsprop",
                    loss="categorical_crossentropy",
                    metrics=c("accuracy"))
network %>% fit(train_images,train_labels,epochs=30,batch_size=784)
## Epoch 1/30
## 77/77 - 1s - loss: 0.5938 - accuracy: 0.8103 - 934ms/epoch - 12ms/step
## Epoch 2/30
## 77/77 - 1s - loss: 0.2171 - accuracy: 0.9325 - 711ms/epoch - 9ms/step
## Epoch 3/30
## 77/77 - 1s - loss: 0.1452 - accuracy: 0.9547 - 730ms/epoch - 9ms/step
## Epoch 4/30
## 77/77 - 1s - loss: 0.1096 - accuracy: 0.9663 - 687ms/epoch - 9ms/step
## Epoch 5/30
## 77/77 - 1s - loss: 0.0867 - accuracy: 0.9723 - 806ms/epoch - 10ms/step
## Epoch 6/30
## 77/77 - 1s - loss: 0.0717 - accuracy: 0.9777 - 678ms/epoch - 9ms/step
## Epoch 7/30
## 77/77 - 1s - loss: 0.0593 - accuracy: 0.9814 - 715ms/epoch - 9ms/step
## Epoch 8/30
## 77/77 - 1s - loss: 0.0497 - accuracy: 0.9844 - 783ms/epoch - 10ms/step
## Epoch 9/30
## 77/77 - 1s - loss: 0.0421 - accuracy: 0.9872 - 688ms/epoch - 9ms/step
## Epoch 10/30
## 77/77 - 1s - loss: 0.0335 - accuracy: 0.9898 - 828ms/epoch - 11ms/step
## Epoch 11/30
## 77/77 - 1s - loss: 0.0323 - accuracy: 0.9901 - 797ms/epoch - 10ms/step
## Epoch 12/30
## 77/77 - 1s - loss: 0.0241 - accuracy: 0.9923 - 773ms/epoch - 10ms/step
## Epoch 13/30
## 77/77 - 1s - loss: 0.0216 - accuracy: 0.9930 - 763ms/epoch - 10ms/step
## Epoch 14/30
## 77/77 - 1s - loss: 0.0188 - accuracy: 0.9940 - 766ms/epoch - 10ms/step
## Epoch 15/30
## 77/77 - 1s - loss: 0.0170 - accuracy: 0.9948 - 745ms/epoch - 10ms/step
## Epoch 16/30
## 77/77 - 1s - loss: 0.0138 - accuracy: 0.9958 - 787ms/epoch - 10ms/step
## Epoch 17/30
## 77/77 - 1s - loss: 0.0122 - accuracy: 0.9962 - 773ms/epoch - 10ms/step
## Epoch 18/30
## 77/77 - 1s - loss: 0.0103 - accuracy: 0.9974 - 806ms/epoch - 10ms/step
## Epoch 19/30
## 77/77 - 1s - loss: 0.0117 - accuracy: 0.9966 - 830ms/epoch - 11ms/step
## Epoch 20/30
## 77/77 - 1s - loss: 0.0072 - accuracy: 0.9981 - 813ms/epoch - 11ms/step
## Epoch 21/30
## 77/77 - 1s - loss: 0.0068 - accuracy: 0.9980 - 790ms/epoch - 10ms/step
## Epoch 22/30
## 77/77 - 1s - loss: 0.0081 - accuracy: 0.9975 - 801ms/epoch - 10ms/step
## Epoch 23/30
## 77/77 - 1s - loss: 0.0049 - accuracy: 0.9987 - 733ms/epoch - 10ms/step
## Epoch 24/30
## 77/77 - 1s - loss: 0.0035 - accuracy: 0.9991 - 765ms/epoch - 10ms/step
## Epoch 25/30
## 77/77 - 1s - loss: 0.0068 - accuracy: 0.9981 - 667ms/epoch - 9ms/step
## Epoch 26/30
## 77/77 - 1s - loss: 0.0048 - accuracy: 0.9986 - 756ms/epoch - 10ms/step
## Epoch 27/30
## 77/77 - 1s - loss: 0.0046 - accuracy: 0.9987 - 761ms/epoch - 10ms/step
## Epoch 28/30
## 77/77 - 1s - loss: 0.0033 - accuracy: 0.9992 - 785ms/epoch - 10ms/step
## Epoch 29/30
## 77/77 - 1s - loss: 0.0027 - accuracy: 0.9993 - 806ms/epoch - 10ms/step
## Epoch 30/30
## 77/77 - 1s - loss: 0.0057 - accuracy: 0.9987 - 786ms/epoch - 10ms/step
network %>% evaluate(test_images,test_labels)
## 313/313 - 0s - loss: 0.0864 - accuracy: 0.9812 - 337ms/epoch - 1ms/step
##     loss accuracy
##   0.0864   0.9812
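To make the two reported numbers more concrete, the following base R sketch computes a softmax, the categorical cross entropy and the accuracy by hand. It uses made-up scores for 3 images and only 3 classes instead of 10, and the helper `softmax()` is our own definition, not a keras function.

```r
# Softmax turns a vector of raw scores into probabilities that sum to 1.
softmax <- function(z) {
  e <- exp(z - max(z))  # subtracting max(z) avoids overflow, result unchanged
  e / sum(e)
}

# Made-up raw scores for 3 images and 3 classes (one row per image).
scores <- rbind(c(2.0, 0.5, 0.1),
                c(0.2, 0.3, 3.0),
                c(1.0, 1.2, 0.9))
probs <- t(apply(scores, 1, softmax))

# One-hot targets: the true classes are 1, 3 and 1.
targets <- rbind(c(1, 0, 0),
                 c(0, 0, 1),
                 c(1, 0, 0))

# Categorical cross entropy: mean over images of
# -log(probability assigned to the true class).
loss <- mean(-rowSums(targets * log(probs)))

# Accuracy: share of images whose most probable class is the true class.
hits <- max.col(probs, ties.method = "first") == max.col(targets, ties.method = "first")
accuracy <- mean(hits)

loss
accuracy  # 2 of the 3 toy images are classified correctly
```

The loss depends on how confident the network is about the true class, while the accuracy only counts hits, which is why the two can move in different directions from one epoch to the next.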
In order to see how well the network performs, we use the test dataset. The test labels are in matrix form: since each output unit represents a digit and we have to classify 10 digits, the matrix has 10000 rows (images) and 10 columns (digits). However, after a short transformation it is possible to obtain a vector of integers from 0 to 9.
# position of the 1 in each row, shifted by one because column 1 stands for digit 0
rtl<-apply(test_labels,1,which.max)-1
summary(rtl)
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max.
##    0.00    2.00    4.00    4.44    7.00    9.00
deepviz::plot_model(network)
## Warning: The `x` argument of `as_tibble.matrix()` must have unique column names if `.name_repair` is omitted as of tibble 2.0.0.
## ℹ Using compatibility `.name_repair`.
## ℹ The deprecated feature was likely used in the deepviz package.
## Please report the issue at <https://github.com/andrie/deepviz/issues>.
## This warning is displayed once per session.
## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was generated.
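The size of the network can also be checked by hand. The sketch below counts the trainable parameters implied by the layer widths used in the model definition (784 inputs, then 100, 1000, 1000 and 10 units); calling summary(network) should report the same total.

```r
# Layer widths: 784 inputs, then dense layers of 100, 1000, 1000 and 10 units.
units <- c(784, 100, 1000, 1000, 10)

# A dense layer has (inputs x outputs) weights plus one bias per output unit.
n_in   <- head(units, -1)
n_out  <- tail(units, -1)
params <- n_in * n_out + n_out

data.frame(layer = seq_along(params), inputs = n_in, units = n_out, params = params)
sum(params)  # total number of trainable parameters: 1190510
```

Most of the parameters sit in the connection between the two 1000-unit layers, which alone accounts for about a million weights.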
An ideal model should produce a confusion matrix in which every observed value matches the predicted value. This shows up as high frequencies on the diagonal, with all off-diagonal values ideally equal to 0. Here the confusion matrix indicates around 98% prediction accuracy over the 10000 test images.
predicted <- as.integer(predict(network, test_images) %>% k_argmax())
## 313/313 - 0s - 284ms/epoch - 907us/step
observed <- apply(test_labels, 1, which.max) - 1
op <- data.frame(predict=predicted, observe=observed)
rwf::confusion_matrix_percent(op$predict, op$observe)
## Registered S3 method overwritten by 'lme4':
## method from
## na.action.merMod car
##             0        1        2        3        4        5        6        7        8        9       sum    p
## 0      973.00     0.00     0.00     1.00     0.00     1.00     2.00     0.00     3.00     0.00    980.00 0.99
## 1        1.00  1122.00     3.00     1.00     0.00     1.00     2.00     1.00     4.00     0.00   1135.00 0.99
## 2        2.00     1.00  1008.00     5.00     1.00     0.00     4.00     4.00     7.00     0.00   1032.00 0.98
## 3        0.00     0.00     3.00   994.00     0.00     4.00     0.00     1.00     5.00     3.00   1010.00 0.98
## 4        2.00     0.00     2.00     0.00   962.00     0.00     4.00     4.00     1.00     7.00    982.00 0.98
## 5        2.00     0.00     0.00     8.00     1.00   872.00     4.00     0.00     4.00     1.00    892.00 0.98
## 6        3.00     1.00     2.00     0.00     5.00     5.00   940.00     0.00     2.00     0.00    958.00 0.98
## 7        0.00     2.00    11.00     2.00     0.00     0.00     0.00  1005.00     3.00     5.00   1028.00 0.98
## 8        0.00     1.00     4.00     3.00     1.00     4.00     3.00     3.00   952.00     3.00    974.00 0.98
## 9        2.00     2.00     0.00     2.00     6.00     6.00     1.00     2.00     4.00   984.00   1009.00 0.98
## sum    985.00  1129.00  1033.00  1016.00   976.00   893.00   960.00  1020.00   985.00  1003.00  10000.00 1.00
## p        0.99     0.99     0.98     0.98     0.99     0.98     0.98     0.99     0.97     0.98      1.00 0.98
cm <- as.data.frame(table(observe=observed, predict=predicted))
library(ggplot2)
ggplot(cm, aes(x=predict, y=observe, fill=Freq)) +
  geom_tile() +
  geom_text(aes(label=Freq), size=3) +
  scale_fill_gradient(low="white", high="steelblue") +
  labs(x="Predicted", y="Observed", fill="Count") +
  theme_minimal()
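The summary figures of a confusion matrix can be derived with a few lines of base R. The sketch below uses made-up labels for 8 observations and 3 classes; the names `recall` and `precision` are the usual terms for the row-wise and column-wise proportions, and we assume these roughly correspond to the `p` margins printed above.

```r
# Made-up labels for 8 observations and 3 classes.
observed  <- factor(c(0, 0, 1, 1, 1, 2, 2, 2), levels = 0:2)
predicted <- factor(c(0, 1, 1, 1, 1, 2, 0, 2), levels = 0:2)

# Rows are observed classes, columns are predicted classes.
cm <- table(observe = observed, predict = predicted)
cm

accuracy  <- sum(diag(cm)) / sum(cm)  # share of observations on the diagonal
recall    <- diag(cm) / rowSums(cm)   # per observed class: share predicted correctly
precision <- diag(cm) / colSums(cm)   # per predicted class: share that is correct

accuracy   # 0.75, since 6 of the 8 toy observations lie on the diagonal
recall
precision
```

Applied to the real confusion matrix, the diagonal sums to 9812 out of 10000 images, which matches the accuracy of 0.9812 reported by evaluate().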
Another descriptive representation is a scatterplot of observed against predicted values. For a perfect classifier, all observations would fall on the regression line, which would then coincide with the diagonal where predicted equals observed.
rwf::plot_scatterplot(op)
## Warning: Using `size` aesthetic for lines was deprecated in ggplot2 3.4.0.
## ℹ Please use `linewidth` instead.
## ℹ The deprecated feature was likely used in the rwf package.
## Please report the issue to the authors.
## This warning is displayed once per session.
## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was generated.
differences<-which(op$predict!=op$observe) # indices of the misclassified test images
Of the 10000 images in the test dataset, the network failed to classify 188 correctly. It may be interesting to visualize these misclassified images.
length(differences)
## [1] 188
for(i in differences)
  plot(as.raster(mnist$test$x[i,,],max=255))
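Plotting 188 images one by one is hard to read, so a grid with the observed and predicted digit in each title may be more convenient. The sketch below is one possible way to do it; `plot_misclassified()` is a hypothetical helper of our own, and the demo uses random noise images with made-up labels so that it runs without keras.

```r
# Hypothetical helper: show up to `n` misclassified images in a grid,
# each titled with its observed and predicted label.
plot_misclassified <- function(images, observed, predicted, n = 12) {
  wrong <- which(observed != predicted)
  shown <- head(wrong, n)
  if (length(shown) == 0) return(invisible(shown))
  # Four images per row; restore the graphical settings on exit.
  old <- par(mfrow = c(ceiling(length(shown) / 4), 4), mar = c(0.5, 0.5, 2, 0.5))
  on.exit(par(old))
  for (i in shown) {
    plot(as.raster(images[i, , ], max = 255))
    title(main = paste0("obs ", observed[i], " / pred ", predicted[i]), cex.main = 0.9)
  }
  invisible(shown)
}

# Synthetic stand-in data: 6 random 28x28 "images" with made-up labels.
set.seed(1)
images    <- array(sample(0:255, 6 * 28 * 28, replace = TRUE), dim = c(6, 28, 28))
observed  <- c(7, 2, 1, 0, 4, 1)
predicted <- c(7, 2, 7, 0, 9, 1)

shown <- plot_misclassified(images, observed, predicted)
shown  # indices of the misclassified toy images: 3 and 5
```

With the objects from this document, the call would be plot_misclassified(mnist$test$x, op$observe, op$predict), which shows the first 12 of the 188 misclassified digits.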