Install the keras package and set up the Keras environment

# Install the keras R package only when it is missing, so re-running or
# re-knitting this document does not reinstall it every time.
if (!requireNamespace("keras", quietly = TRUE)) {
  install.packages("keras")
}
# Install the Python-side TensorFlow/Keras backend used by the R package.
# NOTE(review): TensorFlow 1.12 is a very old release and may no longer be
# installable on current platforms; consider updating this pin.
keras::install_keras(tensorflow = "1.12")

Loading the keras package

# Attach keras so that dataset_mnist(), the layer_*() builders, fit(),
# evaluate() and the %>% / %<-% operators used below are available unqualified.
library("keras")

Simple dense ANN example

Load data

library(keras)

# Load MNIST and split it into image arrays and digit labels.
mnist <- dataset_mnist()
x_train <- mnist$train$x
y_train <- mnist$train$y
x_test <- mnist$test$x
y_test <- mnist$test$y

# reshape: flatten each 28 x 28 image into one 784-element row.
# array_reshape() uses row-major (C-style) order, like Keras/NumPy and the
# CNN section of this document; `dim<-` would lay pixels out column-major.
x_train <- array_reshape(x_train, c(nrow(x_train), 784))
x_test <- array_reshape(x_test, c(nrow(x_test), 784))

# rescale: map pixel values (assumed 0-255 grey levels) into [0, 1].
x_train <- x_train / 255
x_test <- x_test / 255

# One-hot encode the digit labels into 10 indicator columns.
y_train <- to_categorical(y_train, 10)
y_test <- to_categorical(y_test, 10)

Initialize the ANN structure

# Feed-forward network: 784 inputs -> 256 ReLU units -> 40% dropout
# -> 128 ReLU units -> 10-way softmax over the digit classes.
model <- keras_model_sequential() %>%
  layer_dense(units = 256, activation = "relu", input_shape = c(784)) %>%
  layer_dropout(rate = 0.4) %>%
  layer_dense(units = 128, activation = "relu") %>%
  # layer_dropout(rate = 0.3) %>%
  layer_dense(units = 10, activation = "softmax")

summary(model)
## ___________________________________________________________________________
## Layer (type)                     Output Shape                  Param #     
## ===========================================================================
## dense (Dense)                    (None, 256)                   200960      
## ___________________________________________________________________________
## dropout (Dropout)                (None, 256)                   0           
## ___________________________________________________________________________
## dense_1 (Dense)                  (None, 128)                   32896       
## ___________________________________________________________________________
## dense_2 (Dense)                  (None, 10)                    1290        
## ===========================================================================
## Total params: 235,146
## Trainable params: 235,146
## Non-trainable params: 0
## ___________________________________________________________________________

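Configure the compilation step, i.e. set the loss function, the optimizer, and the metrics to report. `loss` is the loss function; here we chose `categorical_crossentropy`, which suits one-hot encoded class labels. `optimizer` is the algorithm used to train the model. Keras offers many options (e.g. `optimizer_sgd()`, `optimizer_rmsprop()`, `optimizer_adam()`); here we use Adam.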

# Adam optimizer, cross-entropy loss for the one-hot labels, and accuracy
# reported as a metric during training and evaluation.
compile(
  model,
  optimizer = optimizer_adam(),
  loss = "categorical_crossentropy",
  metrics = "accuracy"
)

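Train the model, holding out 20% of the training data as a validation set so we can monitor overfitting. No early-stopping callback is used here, so training always runs for the full number of epochs.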

`batch_size` is the number of samples per gradient update; popular batch sizes include 32, 64, and 128. `epochs` is the number of times the learning algorithm works through the entire training dataset.

# 10 passes over the data in mini-batches of 64; 20% of the training set is
# held out and scored after every epoch.
history <- fit(
  model,
  x = x_train,
  y = y_train,
  batch_size = 64,
  epochs = 10,
  validation_split = 0.2
)

# Training vs. validation loss and accuracy per epoch.
plot(history)

Classification accuracy on the test dataset

# Loss and accuracy of the dense network on the held-out test set.
results <- model %>% evaluate(x_test, y_test)

The test classification accuracy of this first deep feedforward network is 0.9792.

Small convolutional network

Loading data

# Load MNIST again and unpack it with keras' multi-assignment operator %<-%.
mnist <- dataset_mnist()
c(c(train_images, train_labels), c(test_images, test_labels)) %<-% mnist

# Convolution layers need a trailing channel axis: (samples, 28, 28, 1) for
# single-channel images. The sample count is taken from the data instead of
# being hard-coded, so subsets of MNIST work too.
train_images <- array_reshape(train_images, c(nrow(train_images), 28, 28, 1))
train_images <- train_images / 255

test_images <- array_reshape(test_images, c(nrow(test_images), 28, 28, 1))
test_images <- test_images / 255

# One-hot encode with an explicit class count so both label matrices always
# have 10 columns, even if a (sub)set happens to lack the highest digit.
train_labels <- to_categorical(train_labels, 10)
test_labels <- to_categorical(test_labels, 10)

Initialize the structure of the small CNN for classification, which contains three convolution layers, two max-pooling layers, and two dense layers. This first part defines the convolutional base that extracts features from the images (3 convolution layers and 2 max-pooling layers).

`filters` is the number of output filters in the convolution; `kernel_size` is the width and height of the 2D convolution window.

# Convolutional base for 28 x 28 x 1 images: three 3 x 3 ReLU convolutions
# (32, 64 and 64 filters) with 2 x 2 max pooling after the first two.
model <- keras_model_sequential() %>%
  layer_conv_2d(
    filters = 32, kernel_size = c(3, 3), activation = "relu",
    input_shape = c(28, 28, 1)
  ) %>%
  layer_max_pooling_2d(pool_size = c(2, 2)) %>%
  layer_conv_2d(filters = 64, kernel_size = c(3, 3), activation = "relu") %>%
  layer_max_pooling_2d(pool_size = c(2, 2)) %>%
  layer_conv_2d(filters = 64, kernel_size = c(3, 3), activation = "relu")

This part adds a classifier on top of the convolutional base: a flatten layer followed by 2 dense layers.

# Classifier head: flatten the convolutional feature maps into a vector,
# then a 64-unit ReLU layer and a 10-way softmax output. The layers are
# appended to the same model object; assigning back keeps it from printing.
model <- model %>%
  layer_flatten() %>%
  layer_dense(units = 64, activation = "relu") %>%
  layer_dense(units = 10, activation = "softmax")

Compile the model (set the optimizer, loss function, and metrics)

# Same training setup as the dense network: cross-entropy loss, the Adam
# optimizer (given by name), and accuracy as the reported metric.
compile(
  model,
  loss = "categorical_crossentropy",
  optimizer = "adam",
  metrics = "accuracy"
)
summary(model)
## ___________________________________________________________________________
## Layer (type)                     Output Shape                  Param #     
## ===========================================================================
## conv2d (Conv2D)                  (None, 26, 26, 32)            320         
## ___________________________________________________________________________
## max_pooling2d (MaxPooling2D)     (None, 13, 13, 32)            0           
## ___________________________________________________________________________
## conv2d_1 (Conv2D)                (None, 11, 11, 64)            18496       
## ___________________________________________________________________________
## max_pooling2d_1 (MaxPooling2D)   (None, 5, 5, 64)              0           
## ___________________________________________________________________________
## conv2d_2 (Conv2D)                (None, 3, 3, 64)              36928       
## ___________________________________________________________________________
## flatten (Flatten)                (None, 576)                   0           
## ___________________________________________________________________________
## dense_3 (Dense)                  (None, 64)                    36928       
## ___________________________________________________________________________
## dense_4 (Dense)                  (None, 10)                    650         
## ===========================================================================
## Total params: 93,322
## Trainable params: 93,322
## Non-trainable params: 0
## ___________________________________________________________________________

Train the model

# Train the CNN (10 epochs, mini-batches of 64, 20% of the training data held
# out for validation) and record how long training takes.
start_time <- Sys.time()
history <- fit(
  model,
  train_images, train_labels,
  epochs = 10,
  batch_size = 64,
  validation_split = 0.2
)
# Fix the unit to seconds: subtracting two POSIXct values picks units
# automatically, e.g. minutes once training exceeds 60 seconds.
training_time <- difftime(Sys.time(), start_time, units = "secs")
plot(history)

Test the model

# Loss and accuracy of the CNN on the held-out test images.
results <- model %>% evaluate(test_images, test_labels)

The test classification accuracy of the simple CNN is 0.9919. Training took 177.2239001 seconds.

How to get the model parameters: `model$get_weights()` returns the trained weights as a list of arrays. Below, the first element is inspected.

# Re-display the CNN architecture (layer output shapes and parameter counts).
summary(object = model)
## ___________________________________________________________________________
## Layer (type)                     Output Shape                  Param #     
## ===========================================================================
## conv2d (Conv2D)                  (None, 26, 26, 32)            320         
## ___________________________________________________________________________
## max_pooling2d (MaxPooling2D)     (None, 13, 13, 32)            0           
## ___________________________________________________________________________
## conv2d_1 (Conv2D)                (None, 11, 11, 64)            18496       
## ___________________________________________________________________________
## max_pooling2d_1 (MaxPooling2D)   (None, 5, 5, 64)              0           
## ___________________________________________________________________________
## conv2d_2 (Conv2D)                (None, 3, 3, 64)              36928       
## ___________________________________________________________________________
## flatten (Flatten)                (None, 576)                   0           
## ___________________________________________________________________________
## dense_3 (Dense)                  (None, 64)                    36928       
## ___________________________________________________________________________
## dense_4 (Dense)                  (None, 10)                    650         
## ===========================================================================
## Total params: 93,322
## Trainable params: 93,322
## Non-trainable params: 0
## ___________________________________________________________________________
# Take the first weight array returned by get_weights(). Judging from the
# printed slices, it is the 3 x 3 kernel of the first convolution layer,
# indexed [row, column, input channel, filter] with 32 filters.
cnn_filter <- model$get_weights()[[1]]
print(cnn_filter)
## , , 1, 1
## 
##            [,1]        [,2]       [,3]
## [1,]  0.1841683  0.11331795  0.2710708
## [2,]  0.1267301  0.05273076 -0.1137537
## [3,] -0.1574383 -0.31324214 -0.2109357
## 
## , , 1, 2
## 
##            [,1]        [,2]       [,3]
## [1,] 0.22965741 -0.02441255 -0.2239949
## [2,] 0.23970620 -0.03335947 -0.2931992
## [3,] 0.05471641 -0.19844154 -0.2709639
## 
## , , 1, 3
## 
##           [,1]      [,2]       [,3]
## [1,] 0.1277004 0.1556847 -0.1649468
## [2,] 0.1521157 0.0530666 -0.1987787
## [3,] 0.1334760 0.0234578 -0.2350789
## 
## , , 1, 4
## 
##           [,1]       [,2]        [,3]
## [1,] 0.1424268 -0.1017003 -0.15652415
## [2,] 0.1230620 -0.1150752 -0.33047736
## [3,] 0.3613366  0.0458409 -0.02334702
## 
## , , 1, 5
## 
##             [,1]       [,2]        [,3]
## [1,]  0.07792772 0.10102654 -0.22467726
## [2,]  0.10278948 0.25653377  0.06007000
## [3,] -0.26681778 0.05844313  0.07538705
## 
## , , 1, 6
## 
##             [,1]       [,2]        [,3]
## [1,] -0.12720755 0.01505563 -0.30181950
## [2,]  0.09151746 0.20518117 -0.08951084
## [3,] -0.12117168 0.09558871  0.19717427
## 
## , , 1, 7
## 
##            [,1]        [,2]        [,3]
## [1,]  0.2123220 -0.39900032 -0.07959107
## [2,] -0.1399887 -0.30633488  0.20100906
## [3,] -0.3191103 -0.07501857  0.19341511
## 
## , , 1, 8
## 
##              [,1]        [,2]       [,3]
## [1,] -0.091615714 -0.20263189 0.12467928
## [2,] -0.264670879  0.09833495 0.19120076
## [3,]  0.003032753  0.08403318 0.03388591
## 
## , , 1, 9
## 
##             [,1]        [,2]        [,3]
## [1,]  0.32390502  0.14198238 -0.20695715
## [2,] -0.04360836 -0.24334477 -0.21066180
## [3,] -0.36657792 -0.05387761 -0.05092364
## 
## , , 1, 10
## 
##            [,1]        [,2]        [,3]
## [1,]  0.2717163  0.03925578 -0.18874428
## [2,] -0.2442554 -0.37970689  0.08406924
## [3,] -0.1587704  0.22025964  0.14980406
## 
## , , 1, 11
## 
##             [,1]       [,2]        [,3]
## [1,]  0.05034082  0.1275936  0.22875997
## [2,]  0.18511258  0.1039342  0.04833355
## [3,] -0.17168005 -0.3089989 -0.27627045
## 
## , , 1, 12
## 
##            [,1]       [,2]        [,3]
## [1,]  0.1259424  0.1766799  0.05804553
## [2,] -0.0860180  0.1303663  0.19702147
## [3,] -0.3527958 -0.2472924 -0.06472023
## 
## , , 1, 13
## 
##             [,1]         [,2]        [,3]
## [1,]  0.04744837 0.1184724122 -0.15636331
## [2,] -0.09501453 0.3082179129  0.16113307
## [3,] -0.39738721 0.0002436367 -0.02059797
## 
## , , 1, 14
## 
##             [,1]        [,2]       [,3]
## [1,] -0.09275295 -0.17898135 0.03981522
## [2,] -0.02026006  0.07267290 0.04671063
## [3,]  0.12693621  0.09999003 0.06377400
## 
## , , 1, 15
## 
##             [,1]        [,2]       [,3]
## [1,] -0.02662293  0.15563160 -0.2069693
## [2,]  0.16036426 -0.06659439 -0.2490396
## [3,]  0.13115890 -0.02441042  0.1691620
## 
## , , 1, 16
## 
##             [,1]         [,2]        [,3]
## [1,] 0.003019769 -0.005169238  0.07612002
## [2,] 0.136856481  0.125336140  0.18249507
## [3,] 0.067267701 -0.110435426 -0.17970800
## 
## , , 1, 17
## 
##             [,1]         [,2]         [,3]
## [1,]  0.27913409  0.069171369 -0.084366485
## [2,]  0.12596641 -0.007582621 -0.395501554
## [3,] -0.09593075 -0.439188570 -0.005470724
## 
## , , 1, 18
## 
##             [,1]        [,2]       [,3]
## [1,]  0.03109886 -0.30162510 -0.3466071
## [2,] -0.28173411 -0.24839474  0.2117487
## [3,] -0.32404360 -0.01638841  0.1819495
## 
## , , 1, 19
## 
##              [,1]        [,2]        [,3]
## [1,] -0.324461520 -0.37082568 -0.08539703
## [2,]  0.007754751 -0.02006943  0.04113802
## [3,]  0.272813588  0.29671478  0.13505119
## 
## , , 1, 20
## 
##            [,1]        [,2]        [,3]
## [1,]  0.1791219  0.13492802  0.19200100
## [2,] -0.1543401  0.03824738  0.06823149
## [3,] -0.1522459 -0.29696718 -0.09996627
## 
## , , 1, 21
## 
##            [,1]       [,2]       [,3]
## [1,] 0.01572081 0.21929409 -0.2354469
## [2,] 0.10163661 0.05239917 -0.1912061
## [3,] 0.11799429 0.13303457 -0.1710693
## 
## , , 1, 22
## 
##             [,1]       [,2]        [,3]
## [1,] -0.12968151 -0.2553517 -0.04766534
## [2,] -0.06058308 -0.3328483  0.28604093
## [3,] -0.30992079 -0.1711237  0.06773672
## 
## , , 1, 23
## 
##            [,1]         [,2]        [,3]
## [1,]  0.2209622  0.002531538 -0.26543039
## [2,]  0.1453079 -0.265582532 -0.21101233
## [3,] -0.1092457 -0.143584728  0.01495782
## 
## , , 1, 24
## 
##            [,1]        [,2]      [,3]
## [1,] -0.2968422 -0.02486045 0.3005786
## [2,] -0.2499849 -0.12577796 0.3150137
## [3,] -0.2217832 -0.13249880 0.1498997
## 
## , , 1, 25
## 
##            [,1]        [,2]       [,3]
## [1,] -0.2376153 -0.02401282 -0.0495364
## [2,] -0.1005466  0.16798191  0.1943555
## [3,]  0.0953801  0.15950544 -0.1368785
## 
## , , 1, 26
## 
##            [,1]        [,2]       [,3]
## [1,] 0.10566790 -0.10734998 0.10212740
## [2,] 0.06916958  0.04824210 0.02064957
## [3,] 0.10526152  0.02690731 0.10656796
## 
## , , 1, 27
## 
##            [,1]      [,2]       [,3]
## [1,] 0.08560913 0.1169470 -0.1710785
## [2,] 0.11514548 0.1530292 -0.1579792
## [3,] 0.02983252 0.1308075 -0.1988781
## 
## , , 1, 28
## 
##             [,1]        [,2]       [,3]
## [1,]  0.01515060 -0.06721498 -0.1894968
## [2,] -0.02067504  0.21945311 -0.2427155
## [3,] -0.12837163  0.16743895  0.1641244
## 
## , , 1, 29
## 
##            [,1]        [,2]        [,3]
## [1,]  0.1465970  0.08220476 -0.30966446
## [2,]  0.3616001 -0.16875626 -0.29371810
## [3,] -0.1944017 -0.32792708 -0.04309321
## 
## , , 1, 30
## 
##            [,1]       [,2]        [,3]
## [1,] 0.16257612 0.18305269  0.08303798
## [2,] 0.07917408 0.14489332 -0.02815112
## [3,] 0.02105355 0.05869713  0.09788346
## 
## , , 1, 31
## 
##            [,1]        [,2]       [,3]
## [1,]  0.1871261  0.25106141  0.1206731
## [2,] -0.1784246 -0.21292765 -0.1594458
## [3,] -0.1832155 -0.02868943  0.1094614
## 
## , , 1, 32
## 
##             [,1]       [,2]       [,3]
## [1,] -0.24961968 -0.1156288 0.19642599
## [2,] -0.06554846  0.2535471 0.08002317
## [3,]  0.09769563 -0.1490398 0.06434184