library(keras)

# load the MNIST handwritten digit dataset and pull out the arrays
mnist <- dataset_mnist()
x_train <- mnist$train$x
y_train <- mnist$train$y
x_test <- mnist$test$x
y_test <- mnist$test$y
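
You can verify the structure of the arrays with dim() (MNIST provides 60,000 training images and 10,000 test images, each 28x28 pixels):

dim(x_train)
[1] 60000    28    28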

The x data is a 3-d array (images, width, height) of grayscale values. To prepare the data for training, we convert the 3-d arrays into matrices by reshaping width and height into a single dimension (28x28 images are flattened into length 784 vectors). Then we convert the grayscale values from integers ranging between 0 and 255 into floating point values ranging between 0 and 1:

# reshape
dim(x_train) <- c(nrow(x_train), 784)
dim(x_test) <- c(nrow(x_test), 784)
# rescale
x_train <- x_train / 255
x_test <- x_test / 255

The y data is an integer vector with values ranging from 0 to 9. To prepare this data for training, we one-hot encode the vectors into binary class matrices using the Keras to_categorical() function:

y_train <- to_categorical(y_train, 10)
y_test <- to_categorical(y_test, 10)
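
For example, encoding the labels 0, 1, and 3 with 4 classes yields one row per label, with a 1 in the column for that class:

to_categorical(c(0, 1, 3), 4)
     [,1] [,2] [,3] [,4]
[1,]    1    0    0    0
[2,]    0    1    0    0
[3,]    0    0    0    1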

Defining the Model

The core data structure of Keras is a model, a way to organize layers. The simplest type of model is the Sequential model, a linear stack of layers.

We begin by creating a sequential model and then adding layers using the pipe (%>%) operator:

model <- keras_model_sequential() 
model %>% 
  layer_dense(units = 256, activation = 'relu', input_shape = c(784)) %>% 
  layer_dropout(rate = 0.4) %>% 
  layer_dense(units = 128, activation = 'relu') %>%
  layer_dropout(rate = 0.3) %>%
  layer_dense(units = 10, activation = 'softmax')

The input_shape argument to the first layer specifies the shape of the input data (a length 784 numeric vector representing a grayscale image). The final layer outputs a length 10 numeric vector (probabilities for each digit) using a softmax activation function.

Use the summary() function to print the details of the model:

summary(model)
________________________________________________________________________
Layer (type)                     Output Shape                  Param #
========================================================================
dense_7 (Dense)                  (None, 256)                   200960
________________________________________________________________________
dropout_5 (Dropout)              (None, 256)                   0
________________________________________________________________________
dense_8 (Dense)                  (None, 128)                   32896
________________________________________________________________________
dropout_6 (Dropout)              (None, 128)                   0
________________________________________________________________________
dense_9 (Dense)                  (None, 10)                    1290
========================================================================
Total params: 235,146
Trainable params: 235,146
Non-trainable params: 0
________________________________________________________________________
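
The parameter counts follow directly from the layer sizes: a dense layer with n inputs and m units has n * m weights plus m biases. For example, for the first layer:

784 * 256 + 256
[1] 200960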

Next, compile the model with an appropriate loss function, optimizer, and metrics (note that compile() modifies the model object in place):

model %>% compile(
  loss = 'categorical_crossentropy',
  optimizer = optimizer_rmsprop(),
  metrics = c('accuracy')
)

Training and Evaluation

Use the fit() function to train the model for 30 epochs using batches of 128 images:

history <- model %>% fit(
  x_train, y_train, 
  epochs = 30, batch_size = 128, 
  validation_split = 0.2
)

The history object returned by fit() includes loss and accuracy metrics, which we can plot:

plot(history)
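
The per-epoch values are also available directly in history$metrics if you want to inspect them or build a custom plot:

str(history$metrics)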

Evaluate the model’s performance on the test data:

model %>% evaluate(x_test, y_test)

$loss
[1] 0.1094651

$acc
[1] 0.9785
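
The returned metrics can also be captured in a variable for later use (passing verbose = 0 suppresses the progress output):

scores <- model %>% evaluate(x_test, y_test, verbose = 0)
scores$acc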

Generating Predictions on New Data

model %>% predict_classes(x_test)
   [1] 7 2 1 0 4 1 4 9 6 9 0 6 9 0 1 5 9 7 3 4 9 6 6 5 4 0 7 4 0 1 3 1 3 4 7 2 7 1 2 1 1 7 4 2 3 5 1 2 4 4 6 3 5 5 6 0 4 1 9 5 7 8 9 3 7 4 6 4 3 0 7 0 2 9 1 7 3 2 9 7 7 6 2 7 8 4 7 3 6 1 3 6 9 3 1 4 1
  [98] 7 6 9 6 0 5 4 9 9 2 1 9 4 8 7 3 9 7 4 4 4 9 2 5 4 7 6 7 9 0 5 8 5 6 6 5 7 8 1 0 1 6 4 6 7 3 1 7 1 8 2 0 3 9 9 5 5 1 5 6 0 3 4 4 6 5 4 6 5 4 5 1 4 4 7 2 3 2 7 1 8 1 8 1 8 5 0 8 9 2 5 0 1 1 1 0 9
 [195] 0 3 1 6 4 2 3 6 1 1 1 3 9 5 2 9 4 5 9 3 9 0 3 6 5 5 7 2 2 7 1 2 8 4 1 7 3 3 8 8 7 9 2 2 4 1 5 9 8 7 2 3 0 6 4 2 4 1 9 5 7 7 2 8 2 0 8 5 7 7 9 1 8 1 8 0 3 0 1 9 9 4 1 8 2 1 2 9 7 5 9 2 6 4 1 5 8
 [292] 2 9 2 0 4 0 0 2 8 4 7 1 2 4 0 2 7 4 3 3 0 0 3 1 9 6 5 2 5 9 7 9 3 0 4 2 0 7 1 1 2 1 5 3 3 9 7 8 6 3 6 1 3 8 1 0 5 1 3 1 5 5 6 1 8 5 1 7 9 4 6 2 2 5 0 6 5 6 3 7 2 0 8 8 5 4 1 1 4 0 3 3 7 6 1 6 2
 [389] 1 9 2 8 6 1 9 5 2 5 4 4 2 8 3 8 2 4 5 0 3 1 7 7 5 7 9 7 1 9 2 1 4 2 9 2 0 4 9 1 4 8 1 8 4 5 9 8 8 3 7 6 0 0 3 0 2 0 6 9 9 3 3 3 2 3 9 1 2 6 8 0 5 6 6 6 3 8 8 2 7 5 8 9 6 1 8 4 1 2 5 9 1 9 7 5 4
 [486] 0 8 9 9 1 0 5 2 3 7 0 9 4 0 6 3 9 5 2 1 3 1 3 6 5 7 4 2 2 6 3 2 6 5 4 8 9 7 1 3 0 3 8 3 1 9 3 4 4 6 4 2 1 8 2 5 4 8 8 4 0 0 2 3 2 7 3 0 8 7 4 4 7 9 6 9 0 9 8 0 4 6 0 6 3 5 4 8 3 3 9 3 3 3 7 8 0
 [583] 2 2 1 7 0 6 5 4 3 8 0 9 6 3 8 0 9 9 6 8 6 8 5 7 8 6 0 2 4 0 2 2 3 1 9 7 5 2 0 8 4 6 2 6 7 9 3 2 9 8 2 2 9 2 7 3 5 9 1 8 0 2 0 5 2 1 3 7 6 7 1 2 5 8 0 3 7 2 4 0 9 1 8 6 7 7 4 3 4 9 1 9 5 1 7 3 9
 [680] 7 6 9 1 3 3 8 3 3 6 7 2 4 5 8 5 1 1 4 4 3 1 0 7 7 0 7 9 9 4 8 5 5 4 0 8 2 1 0 8 4 8 0 4 0 6 1 9 3 2 6 7 2 6 9 3 1 4 6 2 5 9 2 0 6 2 1 7 3 4 1 0 5 4 3 1 1 7 4 9 9 4 8 4 0 2 4 5 1 1 6 4 7 1 9 4 2
 [777] 4 1 5 5 3 8 3 1 4 5 6 8 9 4 1 5 3 8 0 3 2 5 1 2 8 3 4 4 0 8 8 3 3 1 2 3 5 9 6 3 2 6 1 3 6 0 7 2 1 7 1 4 2 4 2 1 7 9 6 1 1 2 4 3 1 7 7 4 8 0 7 3 1 3 1 0 7 7 0 3 5 5 2 7 6 6 9 2 8 3 5 2 2 5 6 0 8
 [874] 2 9 2 8 8 8 8 7 4 9 3 0 6 6 3 2 1 3 2 2 9 3 0 0 5 7 8 3 4 4 6 0 2 9 1 4 7 4 7 3 9 8 8 4 7 1 2 1 2 2 3 2 3 2 3 9 1 7 4 0 3 5 5 8 6 3 2 6 7 6 6 3 2 7 8 1 1 7 4 6 4 9 5 1 3 3 4 7 8 9 1 1 0 9 1 4 4
 [971] 5 4 0 6 2 2 3 1 5 1 2 0 3 8 1 2 6 7 1 6 2 3 9 0 1 2 2 0 8 9
 [ reached getOption("max.print") -- omitted 9000 entries ]
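
Since predict_classes() returns the predicted digit for each test image, you can cross-check the accuracy reported by evaluate() against the original integer labels (mnist$test$y, which was kept intact when y_test was one-hot encoded):

predictions <- model %>% predict_classes(x_test)
mean(predictions == mnist$test$y)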