# Install keras and set up the keras environment.
# One-time setup: installs the R package, then the backend pinned to
# TensorFlow 1.12, and finally attaches keras for the rest of the script.
install.packages("keras")
keras::install_keras(tensorflow = "1.12")
library(keras)
Load data
library(keras)
# Load the MNIST digits and pull out the training and test images/labels.
mnist <- dataset_mnist()
x_train <- mnist$train$x
y_train <- mnist$train$y
x_test <- mnist$test$x
y_test <- mnist$test$y

# Flatten every image into a vector of 784 pixels. array_reshape()
# re-interprets the data with row-major semantics (the layout Keras uses),
# whereas `dim<-` would refill the array in R's column-major order.
x_train <- array_reshape(x_train, c(nrow(x_train), 784))
x_test <- array_reshape(x_test, c(nrow(x_test), 784))

# Divide by 255 so that 8-bit pixel intensities fall in [0, 1].
x_train <- x_train / 255
x_test <- x_test / 255

# One-hot encode the digit labels into 10 classes.
y_train <- to_categorical(y_train, 10)
y_test <- to_categorical(y_test, 10)
Initialize the ANN structure
# Feed-forward classifier: 784 inputs -> 256 (relu) -> dropout 0.4 ->
# 128 (relu) -> 10 (softmax). Layers are appended to the sequential model
# in the order they appear in the pipeline.
model <- keras_model_sequential() %>%
  layer_dense(units = 256, activation = "relu", input_shape = c(784)) %>%
  layer_dropout(rate = 0.4) %>%
  layer_dense(units = 128, activation = "relu") %>%
  # layer_dropout(rate = 0.3) %>%  # second dropout layer is left disabled
  layer_dense(units = 10, activation = "softmax")

# Print the layer-by-layer architecture with parameter counts.
summary(model)
## ___________________________________________________________________________
## Layer (type) Output Shape Param #
## ===========================================================================
## dense (Dense) (None, 256) 200960
## ___________________________________________________________________________
## dropout (Dropout) (None, 256) 0
## ___________________________________________________________________________
## dense_1 (Dense) (None, 128) 32896
## ___________________________________________________________________________
## dense_2 (Dense) (None, 10) 1290
## ===========================================================================
## Total params: 235,146
## Trainable params: 235,146
## Non-trainable params: 0
## ___________________________________________________________________________
Configure the compilation step, e.g. set the loss function, the optimizer, etc. `loss` is the loss function; here we choose 'categorical_crossentropy'. `optimizer` is the optimizer used to train the model; keras offers many options, and here we use Adam.
# Attach the training configuration to the model: categorical cross-entropy
# loss, the Adam optimizer with its default settings, and accuracy as the
# reported metric.
compile(
  model,
  loss = "categorical_crossentropy",
  optimizer = optimizer_adam(),
  metrics = "accuracy"
)
Train the model, using validation data to decide when to stop training and avoid overfitting.
`batch_size` is the number of samples per gradient update; popular batch sizes include 32, 64, and 128. `epochs` is the number of times that the learning algorithm works through the entire training dataset.
# Fit for 10 epochs with mini-batches of 64 samples; 20% of the training
# data is held out for validation.
history <- fit(
  model,
  x = x_train,
  y = y_train,
  batch_size = 64,
  epochs = 10,
  validation_split = 0.2
)

# Plot the per-epoch training and validation metrics.
plot(history)
The classification accuracy on testing dataset
# Score the trained network on the test images and labels.
results <- model %>% evaluate(x_test, y_test)
The classification accuracy is 0.9792 for the first deep feedforward network.
Loading data
# Reload MNIST and unpack it into image/label arrays for both splits.
mnist <- dataset_mnist()
c(c(train_images, train_labels), c(test_images, test_labels)) %<-% mnist

# Add a trailing channel axis so each image is 28 x 28 x 1, then divide by
# 255. The sample count is read from the data instead of being hard-coded,
# so the code keeps working on a subset of the images.
train_images <- array_reshape(train_images, c(nrow(train_images), 28, 28, 1))
train_images <- train_images / 255
test_images <- array_reshape(test_images, c(nrow(test_images), 28, 28, 1))
test_images <- test_images / 255

# One-hot encode the labels. The class count is passed explicitly so both
# splits always get 10 columns, even if a split lacks the highest digit.
train_labels <- to_categorical(train_labels, num_classes = 10)
test_labels <- to_categorical(test_labels, num_classes = 10)
Initialize the structure of a small CNN for classification, which contains three convolution layers, two max-pooling layers, and two dense layers. This part initializes the portion of the CNN that extracts features from the image (three convolution layers and two max-pooling layers).
`filters` is the number of output filters in the convolution; `kernel_size` is the width and height of the 2D convolution window.
# Convolutional feature extractor: three 3 x 3 convolutions (32, 64 and 64
# filters, relu) with 2 x 2 max pooling after the first two of them.
# The input is a 28 x 28 image with a single channel.
model <- keras_model_sequential() %>%
  layer_conv_2d(
    filters = 32,
    kernel_size = c(3, 3),
    activation = "relu",
    input_shape = c(28, 28, 1)
  ) %>%
  layer_max_pooling_2d(pool_size = c(2, 2)) %>%
  layer_conv_2d(filters = 64, kernel_size = c(3, 3), activation = "relu") %>%
  layer_max_pooling_2d(pool_size = c(2, 2)) %>%
  layer_conv_2d(filters = 64, kernel_size = c(3, 3), activation = "relu")
This part is to add classifier on top of the CNN (2 dense layers)
# Classifier head: flatten the feature maps into a vector, then a 64-unit
# relu layer followed by a 10-way softmax output.
model <- model %>%
  layer_flatten() %>%
  layer_dense(units = 64, activation = "relu") %>%
  layer_dense(units = 10, activation = "softmax")
Setting model parameters
# Configure training: Adam optimizer (selected by name), categorical
# cross-entropy loss, and accuracy as the reported metric.
model %>% compile(
  loss = "categorical_crossentropy",
  optimizer = "adam",
  metrics = "accuracy"
)

# Print the layer-by-layer architecture with parameter counts.
summary(model)
## ___________________________________________________________________________
## Layer (type) Output Shape Param #
## ===========================================================================
## conv2d (Conv2D) (None, 26, 26, 32) 320
## ___________________________________________________________________________
## max_pooling2d (MaxPooling2D) (None, 13, 13, 32) 0
## ___________________________________________________________________________
## conv2d_1 (Conv2D) (None, 11, 11, 64) 18496
## ___________________________________________________________________________
## max_pooling2d_1 (MaxPooling2D) (None, 5, 5, 64) 0
## ___________________________________________________________________________
## conv2d_2 (Conv2D) (None, 3, 3, 64) 36928
## ___________________________________________________________________________
## flatten (Flatten) (None, 576) 0
## ___________________________________________________________________________
## dense_3 (Dense) (None, 64) 36928
## ___________________________________________________________________________
## dense_4 (Dense) (None, 10) 650
## ===========================================================================
## Total params: 93,322
## Trainable params: 93,322
## Non-trainable params: 0
## ___________________________________________________________________________
Training the model
# Time the training run: 10 epochs, mini-batches of 64 samples, with 20% of
# the training data held out for validation.
start_time <- Sys.time()
history <- model %>% fit(
  train_images, train_labels,
  batch_size = 64,
  epochs = 10,
  validation_split = 0.2
)
# Elapsed wall-clock time as a difftime object.
training_time <- difftime(Sys.time(), start_time)

# Plot the per-epoch training and validation metrics.
plot(history)
Testing models
# Score the trained CNN on the test images and labels.
results <- model %>% evaluate(test_images, test_labels)
The classification accuracy is 0.9919 for the simple CNN. The training time is 177.2239001 seconds.
How to get the parameters: model$get_weights()
# Show the architecture again to see which layer each weight array belongs to.
model %>% summary()
## ___________________________________________________________________________
## Layer (type) Output Shape Param #
## ===========================================================================
## conv2d (Conv2D) (None, 26, 26, 32) 320
## ___________________________________________________________________________
## max_pooling2d (MaxPooling2D) (None, 13, 13, 32) 0
## ___________________________________________________________________________
## conv2d_1 (Conv2D) (None, 11, 11, 64) 18496
## ___________________________________________________________________________
## max_pooling2d_1 (MaxPooling2D) (None, 5, 5, 64) 0
## ___________________________________________________________________________
## conv2d_2 (Conv2D) (None, 3, 3, 64) 36928
## ___________________________________________________________________________
## flatten (Flatten) (None, 576) 0
## ___________________________________________________________________________
## dense_3 (Dense) (None, 64) 36928
## ___________________________________________________________________________
## dense_4 (Dense) (None, 10) 650
## ===========================================================================
## Total params: 93,322
## Trainable params: 93,322
## Non-trainable params: 0
## ___________________________________________________________________________
# get_weights() returns the model's weight arrays as a list; take the first
# one and print it.
cnn_filter <- model$get_weights()[[1L]]
print(cnn_filter)
## , , 1, 1
##
## [,1] [,2] [,3]
## [1,] 0.1841683 0.11331795 0.2710708
## [2,] 0.1267301 0.05273076 -0.1137537
## [3,] -0.1574383 -0.31324214 -0.2109357
##
## , , 1, 2
##
## [,1] [,2] [,3]
## [1,] 0.22965741 -0.02441255 -0.2239949
## [2,] 0.23970620 -0.03335947 -0.2931992
## [3,] 0.05471641 -0.19844154 -0.2709639
##
## , , 1, 3
##
## [,1] [,2] [,3]
## [1,] 0.1277004 0.1556847 -0.1649468
## [2,] 0.1521157 0.0530666 -0.1987787
## [3,] 0.1334760 0.0234578 -0.2350789
##
## , , 1, 4
##
## [,1] [,2] [,3]
## [1,] 0.1424268 -0.1017003 -0.15652415
## [2,] 0.1230620 -0.1150752 -0.33047736
## [3,] 0.3613366 0.0458409 -0.02334702
##
## , , 1, 5
##
## [,1] [,2] [,3]
## [1,] 0.07792772 0.10102654 -0.22467726
## [2,] 0.10278948 0.25653377 0.06007000
## [3,] -0.26681778 0.05844313 0.07538705
##
## , , 1, 6
##
## [,1] [,2] [,3]
## [1,] -0.12720755 0.01505563 -0.30181950
## [2,] 0.09151746 0.20518117 -0.08951084
## [3,] -0.12117168 0.09558871 0.19717427
##
## , , 1, 7
##
## [,1] [,2] [,3]
## [1,] 0.2123220 -0.39900032 -0.07959107
## [2,] -0.1399887 -0.30633488 0.20100906
## [3,] -0.3191103 -0.07501857 0.19341511
##
## , , 1, 8
##
## [,1] [,2] [,3]
## [1,] -0.091615714 -0.20263189 0.12467928
## [2,] -0.264670879 0.09833495 0.19120076
## [3,] 0.003032753 0.08403318 0.03388591
##
## , , 1, 9
##
## [,1] [,2] [,3]
## [1,] 0.32390502 0.14198238 -0.20695715
## [2,] -0.04360836 -0.24334477 -0.21066180
## [3,] -0.36657792 -0.05387761 -0.05092364
##
## , , 1, 10
##
## [,1] [,2] [,3]
## [1,] 0.2717163 0.03925578 -0.18874428
## [2,] -0.2442554 -0.37970689 0.08406924
## [3,] -0.1587704 0.22025964 0.14980406
##
## , , 1, 11
##
## [,1] [,2] [,3]
## [1,] 0.05034082 0.1275936 0.22875997
## [2,] 0.18511258 0.1039342 0.04833355
## [3,] -0.17168005 -0.3089989 -0.27627045
##
## , , 1, 12
##
## [,1] [,2] [,3]
## [1,] 0.1259424 0.1766799 0.05804553
## [2,] -0.0860180 0.1303663 0.19702147
## [3,] -0.3527958 -0.2472924 -0.06472023
##
## , , 1, 13
##
## [,1] [,2] [,3]
## [1,] 0.04744837 0.1184724122 -0.15636331
## [2,] -0.09501453 0.3082179129 0.16113307
## [3,] -0.39738721 0.0002436367 -0.02059797
##
## , , 1, 14
##
## [,1] [,2] [,3]
## [1,] -0.09275295 -0.17898135 0.03981522
## [2,] -0.02026006 0.07267290 0.04671063
## [3,] 0.12693621 0.09999003 0.06377400
##
## , , 1, 15
##
## [,1] [,2] [,3]
## [1,] -0.02662293 0.15563160 -0.2069693
## [2,] 0.16036426 -0.06659439 -0.2490396
## [3,] 0.13115890 -0.02441042 0.1691620
##
## , , 1, 16
##
## [,1] [,2] [,3]
## [1,] 0.003019769 -0.005169238 0.07612002
## [2,] 0.136856481 0.125336140 0.18249507
## [3,] 0.067267701 -0.110435426 -0.17970800
##
## , , 1, 17
##
## [,1] [,2] [,3]
## [1,] 0.27913409 0.069171369 -0.084366485
## [2,] 0.12596641 -0.007582621 -0.395501554
## [3,] -0.09593075 -0.439188570 -0.005470724
##
## , , 1, 18
##
## [,1] [,2] [,3]
## [1,] 0.03109886 -0.30162510 -0.3466071
## [2,] -0.28173411 -0.24839474 0.2117487
## [3,] -0.32404360 -0.01638841 0.1819495
##
## , , 1, 19
##
## [,1] [,2] [,3]
## [1,] -0.324461520 -0.37082568 -0.08539703
## [2,] 0.007754751 -0.02006943 0.04113802
## [3,] 0.272813588 0.29671478 0.13505119
##
## , , 1, 20
##
## [,1] [,2] [,3]
## [1,] 0.1791219 0.13492802 0.19200100
## [2,] -0.1543401 0.03824738 0.06823149
## [3,] -0.1522459 -0.29696718 -0.09996627
##
## , , 1, 21
##
## [,1] [,2] [,3]
## [1,] 0.01572081 0.21929409 -0.2354469
## [2,] 0.10163661 0.05239917 -0.1912061
## [3,] 0.11799429 0.13303457 -0.1710693
##
## , , 1, 22
##
## [,1] [,2] [,3]
## [1,] -0.12968151 -0.2553517 -0.04766534
## [2,] -0.06058308 -0.3328483 0.28604093
## [3,] -0.30992079 -0.1711237 0.06773672
##
## , , 1, 23
##
## [,1] [,2] [,3]
## [1,] 0.2209622 0.002531538 -0.26543039
## [2,] 0.1453079 -0.265582532 -0.21101233
## [3,] -0.1092457 -0.143584728 0.01495782
##
## , , 1, 24
##
## [,1] [,2] [,3]
## [1,] -0.2968422 -0.02486045 0.3005786
## [2,] -0.2499849 -0.12577796 0.3150137
## [3,] -0.2217832 -0.13249880 0.1498997
##
## , , 1, 25
##
## [,1] [,2] [,3]
## [1,] -0.2376153 -0.02401282 -0.0495364
## [2,] -0.1005466 0.16798191 0.1943555
## [3,] 0.0953801 0.15950544 -0.1368785
##
## , , 1, 26
##
## [,1] [,2] [,3]
## [1,] 0.10566790 -0.10734998 0.10212740
## [2,] 0.06916958 0.04824210 0.02064957
## [3,] 0.10526152 0.02690731 0.10656796
##
## , , 1, 27
##
## [,1] [,2] [,3]
## [1,] 0.08560913 0.1169470 -0.1710785
## [2,] 0.11514548 0.1530292 -0.1579792
## [3,] 0.02983252 0.1308075 -0.1988781
##
## , , 1, 28
##
## [,1] [,2] [,3]
## [1,] 0.01515060 -0.06721498 -0.1894968
## [2,] -0.02067504 0.21945311 -0.2427155
## [3,] -0.12837163 0.16743895 0.1641244
##
## , , 1, 29
##
## [,1] [,2] [,3]
## [1,] 0.1465970 0.08220476 -0.30966446
## [2,] 0.3616001 -0.16875626 -0.29371810
## [3,] -0.1944017 -0.32792708 -0.04309321
##
## , , 1, 30
##
## [,1] [,2] [,3]
## [1,] 0.16257612 0.18305269 0.08303798
## [2,] 0.07917408 0.14489332 -0.02815112
## [3,] 0.02105355 0.05869713 0.09788346
##
## , , 1, 31
##
## [,1] [,2] [,3]
## [1,] 0.1871261 0.25106141 0.1206731
## [2,] -0.1784246 -0.21292765 -0.1594458
## [3,] -0.1832155 -0.02868943 0.1094614
##
## , , 1, 32
##
## [,1] [,2] [,3]
## [1,] -0.24961968 -0.1156288 0.19642599
## [2,] -0.06554846 0.2535471 0.08002317
## [3,] 0.09769563 -0.1490398 0.06434184