library(keras)
# load the MNIST dataset (downloaded on first use) and extract the
# training and test images (x) and labels (y)
mnist <- dataset_mnist()
x_train <- mnist$train$x
y_train <- mnist$train$y
x_test <- mnist$test$x
y_test <- mnist$test$y
The x data is a 3-d array (images, width, height) of grayscale values. To prepare the data for training, we convert the 3-d arrays into matrices by reshaping width and height into a single dimension (28x28 images are flattened into length-784 vectors). Then we convert the grayscale values from integers ranging from 0 to 255 into floating-point values ranging from 0 to 1:
# reshape
dim(x_train) <- c(nrow(x_train), 784)
dim(x_test) <- c(nrow(x_test), 784)
# rescale
x_train <- x_train / 255
x_test <- x_test / 255
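As a quick sanity check (an optional step, not part of the original walkthrough), confirm the new shapes and value range before moving on:
# 60,000 training rows and 10,000 test rows, each of length 784
dim(x_train)    # 60000 784
dim(x_test)     # 10000 784
range(x_train)  # 0 1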
The y data is an integer vector with values ranging from 0 to 9. To prepare this data for training, we one-hot encode the vectors into binary class matrices using the Keras to_categorical() function:
y_train <- to_categorical(y_train, 10)
y_test <- to_categorical(y_test, 10)
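To see what the encoding does, try it on a single label (a minimal illustration; the printed layout may differ slightly across versions):
# the digit 5 becomes a length-10 row with a 1 in column 6
to_categorical(5, 10)
#      [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10]
# [1,]    0    0    0    0    0    1    0    0    0     0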
The core data structure of Keras is a model, a way to organize layers. The simplest type of model is the Sequential model, a linear stack of layers.
We begin by creating a sequential model and then adding layers using the pipe (%>%) operator:
model <- keras_model_sequential()
model %>%
  layer_dense(units = 256, activation = 'relu', input_shape = c(784)) %>%
  layer_dropout(rate = 0.4) %>%
  layer_dense(units = 128, activation = 'relu') %>%
  layer_dropout(rate = 0.3) %>%
  layer_dense(units = 10, activation = 'softmax')
The input_shape argument to the first layer specifies the shape of the input data (a length 784 numeric vector representing a grayscale image). The final layer outputs a length 10 numeric vector (probabilities for each digit) using a softmax activation function.
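For reference, softmax simply exponentiates the 10 raw scores and normalizes them so they sum to 1 (a minimal sketch of the formula, not code from the package):
softmax <- function(z) exp(z) / sum(exp(z))
sum(softmax(rnorm(10)))  # always 1, so the outputs form a probability distribution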
Use the summary() function to print the details of the model:
summary(model)
____________________________________________________________________
Layer (type)                     Output Shape                Param #
====================================================================
dense_7 (Dense)                  (None, 256)                 200960
____________________________________________________________________
dropout_5 (Dropout)              (None, 256)                 0
____________________________________________________________________
dense_8 (Dense)                  (None, 128)                 32896
____________________________________________________________________
dropout_6 (Dropout)              (None, 128)                 0
____________________________________________________________________
dense_9 (Dense)                  (None, 10)                  1290
====================================================================
Total params: 235,146
Trainable params: 235,146
Non-trainable params: 0
____________________________________________________________________
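The parameter counts follow directly from the layer sizes: a dense layer with n inputs and m units has n * m weights plus m biases, and dropout layers add no parameters. Checking the arithmetic:
784 * 256 + 256  # 200960, first hidden layer
256 * 128 + 128  # 32896, second hidden layer
128 * 10 + 10    # 1290, output layer
# 200960 + 32896 + 1290 = 235146 total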
Next, compile the model with an appropriate loss function, optimizer, and metrics:
model %>% compile(
  loss = 'categorical_crossentropy',
  optimizer = optimizer_rmsprop(),
  metrics = c('accuracy')
)
Use the fit() function to train the model for 30 epochs using batches of 128 images:
history <- model %>% fit(
  x_train, y_train,
  epochs = 30, batch_size = 128,
  validation_split = 0.2
)
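Note that validation_split = 0.2 holds out the last 20% of the training samples (taken before shuffling) for per-epoch validation; those images are never used to update the weights.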
The history object returned by fit() includes loss and accuracy metrics, which we can plot:
plot(history)
Evaluate the model’s performance on the test data:
model %>% evaluate(x_test, y_test)
10000/10000 [==============================] - 0s
$loss
[1] 0.1094651
$acc
[1] 0.9785
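evaluate() returns these metrics as a named list, so you can also capture them for later use (in this version of the package the accuracy element is named acc, matching the output above):
scores <- model %>% evaluate(x_test, y_test)
scores$loss  # test-set cross-entropy
scores$acc   # test-set accuracy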
Generate predictions on new data:
model %>% predict_classes(x_test)
[1] 7 2 1 0 4 1 4 9 6 9 0 6 9 0 1 5 9 7 3 4 9 6 6 5 4 0 7 4 0 1 3 1 3 4 7 2 7 1 2 1 1 7 4 2 3 5 1 2 4 4 6 3 5 5 6 0 4 1 9 5 7 8 9 3 7 4 6 4 3 0 7 0 2 9 1 7 3 2 9 7 7 6 2 7 8 4 7 3 6 1 3 6 9 3 1 4 1
... (remaining predictions omitted)
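In more recent versions of the keras package predict_classes() is no longer available; if that's the case for you, an equivalent (a small sketch, not the original tutorial's code) is to take the most probable class from predict() yourself:
probs <- model %>% predict(x_test)               # 10000 x 10 matrix of class probabilities
pred_classes <- apply(probs, 1, which.max) - 1L  # argmax, shifted back to digit labels 0-9
head(pred_classes)  # 7 2 1 0 4 1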