if (!require(keras3))
{
  install.packages("keras3")
  library(keras3)
  install_keras()
}

if (!require(plotly)) install.packages("plotly")
d <- dataset_mnist()
str(d)
## List of 2
##  $ train:List of 2
##   ..$ x: int [1:60000, 1:28, 1:28] 0 0 0 0 0 0 0 0 0 0 ...
##   ..$ y: int [1:60000(1d)] 5 0 4 1 9 2 1 3 1 4 ...
##  $ test :List of 2
##   ..$ x: int [1:10000, 1:28, 1:28] 0 0 0 0 0 0 0 0 0 0 ...
##   ..$ y: int [1:10000(1d)] 7 2 1 0 4 1 4 9 5 9 ...
# クラスラベル
LAB.CLASS <- c('0','1','2','3','4','5','6','7','8','9')

(nclass <- length(LAB.CLASS))
## [1] 10
# 訓練データ
d.tr <- d$train
d.tr$fig <- d.tr$x / 255 # 規格化(0~255階調を0~1に変換)
d.tr$lab <- LAB.CLASS[d.tr$y + 1]
n.tr <- length(d.tr$lab)

# テストデータ
d.te <- d$test
d.te$fig <- d.te$x / 255 # 規格化(0~255階調を0~1に変換)
d.te$lab <- LAB.CLASS[d.te$y + 1]
n.te <- length(d.te$lab)

# 画像サイズ 28x28 ピクセル
(dx <- dim(d.tr$x)[2])
## [1] 28
(dy <- dim(d.tr$x)[3])
## [1] 28
# 描画自作関数
draw.images <- function(d, i.fr, i.to,
                        labhat = NA, p = NA, is.pred = F)
{
  par(mfrow = c(3, 6), # 行優先に3x6マスでプロット
      mar = c(4.5, 0, 1, 0) + 0.1, # 図周りのマージン設定
      cex.main = 0.9)

  for (i in i.fr:i.to)
  {
    plot(NA, xlim = c(0, dx), ylim = c(0, dy), axes = F,
        type = 'n', xlab = '', ylab = '', 
        main = paste('Fig.', i))

    rasterImage(d$fig[i, , ], 0, 0, dx, dy) 

    mtext(d$lab[i], side = 1, line = 0.2, adj = 0.5)
    
    # 予測時(is.pred == TRUE)は予測ラベルを貼り付ける。
    if (is.pred)
    {
      mtext(labhat[i], side = 1, line = 1.4, adj = 0.5, col = 4)
      mtext(sprintf('%3d%%', as.integer(p[i])), 
            col = 4, side = 1, line = 2.8, adj = 0.5)
    }
  }
}

# 描画
draw.images(d.tr, i.fr = 1, i.to = 18)

clear_session() # 古いモデルを削除

model <- keras_model_sequential(input_shape = c(dx, dy)) |> # 入力層
  layer_flatten() |> # 画像を1列のピクセルに変換
  layer_dense(units = 128,    activation = 'relu') |> # 中間層(ReLU)
  layer_dense(units =  64,    activation = 'relu') |> # 中間層(ReLU)
  layer_dense(units = nclass, activation = 'softmax') # 出力層(ソフトマックス)

# モデル概要
summary(model)
## Model: "sequential"
## ┌───────────────────────────────────┬──────────────────────────┬───────────────
## │ Layer (type)                      │ Output Shape             │       Param # 
## ├───────────────────────────────────┼──────────────────────────┼───────────────
## │ flatten (Flatten)                 │ (None, 784)              │             0 
## ├───────────────────────────────────┼──────────────────────────┼───────────────
## │ dense (Dense)                     │ (None, 128)              │       100,480 
## ├───────────────────────────────────┼──────────────────────────┼───────────────
## │ dense_1 (Dense)                   │ (None, 64)               │         8,256 
## ├───────────────────────────────────┼──────────────────────────┼───────────────
## │ dense_2 (Dense)                   │ (None, 10)               │           650 
## └───────────────────────────────────┴──────────────────────────┴───────────────
##  Total params: 109,386 (427.29 KB)
##  Trainable params: 109,386 (427.29 KB)
##  Non-trainable params: 0 (0.00 B)
# 高速演算のためのコンパイル(PCが素早く理解できる機械語に翻訳)
compile(model,
        loss      = 'sparse_categorical_crossentropy',     # 交差エントロピー関数
        optimizer = optimizer_adam(learning_rate = 0.001), # 最適化アルゴリズム 
        metrics   = c('accuracy'))                         # 評価指標:精度
# コールバック設定
callbacks <- list(
  # 早期停止(検証データでの損失値の改善が20エポック以上なかったら停止)
  callback_early_stopping(patience = 20, monitor = "val_loss"),
  
  # 検証データでの損失が改善されない限りモデルを上書きしない設定
  # (early_stoppingとセットで使用する)
  callback_model_checkpoint(filepath = "bestmodel.keras",
                            monitor = "val_loss", save_best_only = T),
  
  # 検証データでの損失が改善せず停滞した時(判定:5エポック)
  # に局所解を抜け出すため学習率を0.1倍に下げる設定。 
  callback_reduce_lr_on_plateau(monitor = "val_loss", 
                                factor = 0.1, patience = 5)
)

# フィッティング
fit(model,                  # モデル
    d.tr$x,                 # 入力(28x28画素データx60000)
    d.tr$y,                 # 目的変数
    verbose    = 0,         # 1:出力表示(低速),0:出力表示抑制
    batch_size = 2^5,       # バッチサイズ(要調整)
    epochs     = 100,       # エポック数
    validation_split = 0.2, # 検証用データ割合(訓練には不使用)
    callbacks  = callbacks) # コールバック設定

pred <- predict(model, d.te$x) # 予測結果(確率)
## 313/313 - 0s - 1ms/step
yhat <- max.col(pred) - 1      # 予測結果(エンコーディング値)
labhat <- LAB.CLASS[yhat + 1]  # 予測結果(クラスラベル)

p <- rep(NA, n.te)
for (i in 1:n.te) p[i] <- pred[i, yhat[i] + 1] * 100 # 確率 [%]

draw.images(d.te, i.fr = 1, i.to = 36, labhat, p, is.pred = T)

options(digits = 2)
evaluate(model, d.te$x, d.te$y)
## 313/313 - 0s - 1ms/step - accuracy: 0.9752 - loss: 0.1620
## $accuracy
## [1] 0.98
## 
## $loss
## [1] 0.16
# 混同行列
library(caret)
cm <- confusionMatrix(data = as.factor(yhat),
                      ref  = as.factor(d.te$y))
cm
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction    0    1    2    3    4    5    6    7    8    9
##          0  968    0    3    1    2    2    6    0    7    4
##          1    0 1116    5    0    1    0    2    4    2    3
##          2    1    4 1003    8    2    1    1    8    3    1
##          3    1    2    5  988    1   13    1    3    5    6
##          4    3    0    2    0  955    0    1    0    5    5
##          5    2    0    0    4    0  864    6    1    3    4
##          6    3    1    3    0    6    6  940    0    4    0
##          7    1    4    6    4    2    1    0 1006    4    7
##          8    1    8    5    2    2    1    1    2  937    4
##          9    0    0    0    3   11    4    0    4    4  975
## 
## Overall Statistics
##                                         
##                Accuracy : 0.975         
##                  95% CI : (0.972, 0.978)
##     No Information Rate : 0.114         
##     P-Value [Acc > NIR] : <2e-16        
##                                         
##                   Kappa : 0.972         
##                                         
##  Mcnemar's Test P-Value : NA            
## 
## Statistics by Class:
## 
##                      Class: 0 Class: 1 Class: 2 Class: 3 Class: 4 Class: 5
## Sensitivity            0.9878    0.983    0.972   0.9782   0.9725   0.9686
## Specificity            0.9972    0.998    0.997   0.9959   0.9982   0.9978
## Pos Pred Value         0.9748    0.985    0.972   0.9639   0.9835   0.9774
## Neg Pred Value         0.9987    0.998    0.997   0.9975   0.9970   0.9969
## Prevalence             0.0980    0.114    0.103   0.1010   0.0982   0.0892
## Detection Rate         0.0968    0.112    0.100   0.0988   0.0955   0.0864
## Detection Prevalence   0.0993    0.113    0.103   0.1025   0.0971   0.0884
## Balanced Accuracy      0.9925    0.991    0.984   0.9871   0.9854   0.9832
##                      Class: 6 Class: 7 Class: 8 Class: 9
## Sensitivity            0.9812    0.979   0.9620   0.9663
## Specificity            0.9975    0.997   0.9971   0.9971
## Pos Pred Value         0.9761    0.972   0.9730   0.9740
## Neg Pred Value         0.9980    0.998   0.9959   0.9962
## Prevalence             0.0958    0.103   0.0974   0.1009
## Detection Rate         0.0940    0.101   0.0937   0.0975
## Detection Prevalence   0.0963    0.103   0.0963   0.1001
## Balanced Accuracy      0.9893    0.988   0.9796   0.9817