Building a working neural net on the MNIST dataset is the “hello world” of deep learning.
MNIST consists of 28 × 28 greyscale images of handwritten digits from 0 to 9. The images form the network’s input, and the digit each image depicts, stored as an integer, is the label the network is trained to predict.
In this example, we build and train a functional neural net from this dataset.
Images are in x and labels are in y:
x is a 3D array: 60,000 2D sub-arrays of 28 × 28 integers with values in [0, 255], where each integer is the greyscale intensity of one pixel in the image grid.
y is a 1D array containing, for each image in x, the integer that image depicts.
library(keras)
# Training data
mnist <- dataset_mnist()
train_images <- mnist$train$x
train_labels <- mnist$train$y
# Plot the image formed by the 2nd subtensor in the dataset
digit <- train_images[2,,]
plot(as.raster(digit, max = 255))
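# As a sanity check, print the label for the plotted digit;
# per the str() output below, the 2nd label is 0
train_labels[2]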
# Testing data
test_images <- mnist$test$x
test_labels <- mnist$test$y
str(train_images)
int [1:60000, 1:28, 1:28] 0 0 0 0 0 0 0 0 0 0 ...
str(train_labels)
int [1:60000(1d)] 5 0 4 1 9 2 1 3 1 4 ...
str(test_images)
int [1:10000, 1:28, 1:28] 0 0 0 0 0 0 0 0 0 0 ...
str(test_labels)
int [1:10000(1d)] 7 2 1 0 4 1 4 9 5 9 ...
Set up the neural network architecture. We will feed the training data into this architecture to yield a model, which we will then evaluate with the testing data.
The tensor operation applied to the input (training) data is designated by “relu” in the model, where “relu” translates, in R, to the following: output = pmax((W %*% input) + b, 0).
In words, for this case: the input tensor (training data) is multiplied by a weight tensor W (the kernel), a bias tensor b is added, and every negative entry of the result is clipped to zero. These operations are then repeated for successive network layers (i.e., forward passes), minimizing a loss function, until the data is effectively “unfolded”. The weights (or trainable parameters), W and b, contain the information learned by the network from exposure to the training data, and they are adjusted slightly after the network loss is computed between iterations.
Taking the derivative of this tensor operation (i.e., the gradient of the loss) gives the direction of increasing loss and can be computed efficiently (i.e., a backward pass), allowing new weights to be determined between iterations by stepping against that gradient.
This is the heart of training neural networks. The general method is called gradient descent, sketched below.
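To make this concrete, here is a minimal base-R sketch (not Keras code) of one dense layer’s forward pass with relu, followed by a single gradient-descent weight update. The learning rate lr and the gradient grad_W are placeholders for values the backward pass and optimizer would supply:
relu <- function(x) pmax(x, 0)
# One dense layer's forward pass: multiply by the kernel W, add the bias b,
# then clip negative values to zero
dense_forward <- function(input, W, b) relu((W %*% input) + b)
# Toy dimensions: 3 inputs -> 2 outputs
set.seed(1)
W <- matrix(rnorm(2 * 3), nrow = 2)  # weight tensor (kernel)
b <- c(0.1, -0.2)                    # bias tensor
x <- c(1, 0.5, -1)                   # one input sample
dense_forward(x, W, b)
# One gradient-descent step: move W slightly against the gradient of the
# loss (grad_W would come from the backward pass; here it is a placeholder)
lr <- 0.01
grad_W <- matrix(0, nrow = 2, ncol = 3)
W <- W - lr * grad_W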
network <- keras_model_sequential() %>%
layer_dense(units = 512, activation = "relu", input_shape = c(28 * 28)) %>%
layer_dense(units = 10, activation = "softmax")
To make the network ready for training, we specify three more settings in the compilation step: an optimizer, a loss function, and the metrics to monitor. The network is modified in place (note the use of %>%) rather than returned as a new network object. Keras models are modified in place because they are directed acyclic graphs of layers whose state is updated during training.
categorical_crossentropy is the loss function used as the feedback signal for learning the weight tensors; it is what the training phase will attempt to minimize. The reduction of the loss happens via mini-batch stochastic gradient descent. The exact rules governing this use of gradient descent are defined by the rmsprop optimizer passed as the first argument.
network %>% compile(
optimizer = "rmsprop",
loss = "categorical_crossentropy",
metrics = c("accuracy")
)
Preprocess the data by reshaping it into the shape the network expects and scaling it so that all values are in [0, 1].
array_reshape() is used instead of dim<-() so that the data is reinterpreted using row-major semantics (as opposed to R’s default column-major semantics), which is compatible with the way the numerical libraries called by Keras (NumPy, TensorFlow, and so on) interpret array dimensions. You should always use array_reshape() when reshaping R arrays that will be passed to Keras; the sketch after the preprocessing code illustrates the difference.
train_images <- array_reshape(train_images, c(60000, 28 * 28))
train_images <- train_images / 255
test_images <- array_reshape(test_images, c(10000, 28 * 28))
test_images <- test_images / 255
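To see why the ordering matters, here is a small sketch (using a 2 × 2 toy matrix, not the MNIST data) contrasting array_reshape() with dim<-():
m <- matrix(1:4, nrow = 2, byrow = TRUE)  # rows: 1 2 and 3 4
array_reshape(m, c(1, 4))  # row-major: 1 2 3 4
dim(m) <- c(1, 4)          # R's column-major reshape
m                          # column-major: 1 3 2 4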
Categorically encode the labels.
train_labels <- to_categorical(train_labels)
test_labels <- to_categorical(test_labels)
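to_categorical() one-hot encodes each label as a length-10 binary vector with a single 1 in the position indexed by the digit. A small sketch, separate from the pipeline above:
to_categorical(c(5, 0, 4), num_classes = 10)
# a 3 x 10 matrix; e.g., the first row has its 1 in the column for digit 5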
In Keras, training is done by calling the fit method. That is, the model is fit to its training data.
The network will start to iterate on the training data in mini-batches of 128 samples, 5 times over (each iteration over all the training data is called an epoch). At each iteration, the network will compute the gradients of the weights with regard to the loss on the batch, and update the weights accordingly.
network %>% fit(train_images, train_labels, epochs = 5, batch_size = 128)
Epoch 1/5
469/469 [==============================] - 2s 4ms/step - loss: 0.2601 - accuracy: 0.9242
Epoch 2/5
469/469 [==============================] - 1s 2ms/step - loss: 0.1032 - accuracy: 0.9697
Epoch 3/5
469/469 [==============================] - 1s 2ms/step - loss: 0.0676 - accuracy: 0.9799
Epoch 4/5
469/469 [==============================] - 1s 2ms/step - loss: 0.0488 - accuracy: 0.9857
Epoch 5/5
469/469 [==============================] - 1s 2ms/step - loss: 0.0371 - accuracy: 0.9882
Two quantities are displayed during training: the loss of the network over the training data and the accuracy of the network over the training data. The final training accuracy is about 98.8%.
After the 5 epochs, the network has performed 2,345 gradient updates (469 per epoch). The loss of the network is sufficiently low that the network is capable of classifying handwritten digits with high accuracy.
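The batch count follows directly from the sizes above; a quick arithmetic check in R:
ceiling(60000 / 128)  # 469 gradient updates (batches) per epoch
469 * 5               # 2345 updates over 5 epochs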
Check that the model performs well on the test set too.
metrics <- network %>% evaluate(test_images, test_labels)
313/313 [==============================] - 0s 568us/step - loss: 0.0670 - accuracy: 0.9794
metrics
loss accuracy
0.06699257 0.97939998
The test set accuracy turns out to be about 97.9%, somewhat lower than the training accuracy. This gap between training accuracy and test accuracy is an example of overfitting.
Generate predictions for the first 10 samples of the test set.
network %>% predict(test_images[1:10,]) %>% k_argmax() %>% as.numeric()
[1] 7 2 1 0 4 1 4 9 5 9
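These match the first 10 true labels. Because test_labels was one-hot encoded earlier, a direct comparison should use the raw labels still held in the mnist object loaded at the start:
mnist$test$y[1:10]
# [1] 7 2 1 0 4 1 4 9 5 9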