Basic Idea of Generative Adversarial Networks (GAN) Machine Learning: Torch in R

1 Generative Adversarial Networks

A generative adversarial network (GAN) is a class of machine learning system that couples two learning agents in a (positive or negative) feedback loop. The first agent is the generative model, which serves as the prediction model. The second agent is the discriminative model, which serves as the verification model; a basic version of it can be implemented as a supervised classifier trained on labeled objects. In each minibatch iteration, the output that the generative model produced in the previous minibatch becomes the input of the discriminative model in the current minibatch, and the discriminator's result is used for error correction or error adjustment. This idea of integrating two agents was first proposed by Goodfellow et al. (2014). The GAN family now has many variants, such as Conditional GAN (Mirza and Osindero 2014), Progressive GAN (Karras et al. 2017), and StyleGAN (Karras, Laine, and Aila 2018). In this section, we only discuss how to implement the fundamental class of GAN model in RStudio. Readers interested in the other variants should consult the references.
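For reference, Goodfellow et al. (2014) formalize the interaction between the two agents as a minimax game between a generator G and a discriminator D. The notation below follows that paper (x is a real sample, z is a noise vector fed to the generator) and is not used elsewhere in this section.

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
```

The discriminator tries to make both terms large, while the generator tries to make the second term small by producing samples that the discriminator scores as real.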

The GAN architecture translates the idea of an engineering feedback control loop into a deep learning architecture. Simply put, a GAN integrates feedback loop control with a CNN architecture, a design that is nothing new from the point of view of control engineering. The following figure illustrates the basic idea of a feedback control-loop GAN architecture. With positive feedback loop control, the discriminator's classification output should be dominated by true positives (called Deeptrue here); otherwise it should be dominated by true negatives (called Deepfaked). Whether the discriminator's image output should be Deeptrue or Deepfaked depends on the application. For example, recognition of medical MRI images should be designed for true-positive dominance (Deeptrue), whereas digital cloning effects in the entertainment industry should be designed for true-negative dominance (Deepfaked). Deeptrue and Deepfaked technologies based on GANs are now widely applied in medical science and the entertainment industry, respectively.
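To make the feedback loop concrete, here is a minimal sketch of one adversarial update step in torch for R. It is an illustration under stated assumptions, not the architecture used later in this section: the names `latent_dim`, `generator` and `discriminator` are made up for the example, small fully connected layers stand in for a CNN, and a random tensor stands in for a minibatch of real images.

```r
library(torch)

torch_manual_seed(1)

latent_dim <- 16    # assumed size of the noise vector z
img_dim <- 28 * 28  # a flattened 28 x 28 image

# Generator: noise vector -> fake image with pixel values in [0, 1]
generator <- nn_sequential(
  nn_linear(latent_dim, 64),
  nn_relu(),
  nn_linear(64, img_dim),
  nn_sigmoid()
)

# Discriminator: image -> one logit (larger means "looks real")
discriminator <- nn_sequential(
  nn_linear(img_dim, 64),
  nn_leaky_relu(0.2),
  nn_linear(64, 1)
)

opt_g <- optim_adam(generator$parameters, lr = 2e-4)
opt_d <- optim_adam(discriminator$parameters, lr = 2e-4)

# Stand-in for one minibatch of real images (random values, illustration only)
real <- torch_rand(32, img_dim)
ones <- torch_ones(32, 1)
zeros <- torch_zeros(32, 1)

# 1) Discriminator step: push real images toward 1, generated images toward 0.
#    detach() keeps this step from updating the generator.
opt_d$zero_grad()
fake <- generator(torch_randn(32, latent_dim))
loss_d <- nnf_binary_cross_entropy_with_logits(discriminator(real), ones) +
  nnf_binary_cross_entropy_with_logits(discriminator(fake$detach()), zeros)
loss_d$backward()
opt_d$step()

# 2) Generator step: the discriminator's verdict is fed back to the generator,
#    which is rewarded when its fakes are scored as real.
opt_g$zero_grad()
loss_g <- nnf_binary_cross_entropy_with_logits(discriminator(fake), ones)
loss_g$backward()
opt_g$step()

cat(sprintf("discriminator loss: %.3f, generator loss: %.3f\n",
            loss_d$item(), loss_g$item()))
```

In a full training run these two steps would be repeated for every minibatch of real images, with each agent correcting itself against the other's latest output.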

2 Install and Load the torch Library

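  • The torch package for R is based on PyTorch (Paszke et al. 2019)
# Install the released versions from CRAN
install.packages("torch")
install.packages("torchvision")
# Optional: install the development version of torch from GitHub instead
devtools::install_github("mlverse/torch")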

2.1 Check That the CPU Device Is Ready Using torch

library(torch)
library(torchvision)
torch_tensor(1, device = "cpu")
## torch_tensor
##  1
## [ CPUFloatType{1} ]
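A small sketch of how the device could be chosen at run time instead of being hard-coded. The variable name `device` is an assumption of this example; `cuda_is_available()` reports whether torch can see a CUDA GPU.

```r
library(torch)

# Fall back to the CPU when no CUDA GPU is visible to torch
device <- if (cuda_is_available()) "cuda" else "cpu"

# Create a tensor on the selected device and print it
x <- torch_tensor(1, device = device)
x
```

The printed tensor type (for example `CPUFloatType`) shows where the tensor actually lives.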

3 Using Kuzushiji-MNIST dataset (KMNIST)

Figure 3.1: Samples of Kuzushiji-MNIST

4 Load Data (Real Input) and Transform to Tensor Format

Figure 4.1: Input Real Image Asian Characters

5 transform_to_tensor() Function

  • rescales the input image pixel values to the range between 0 and 1
  • then adds one extra channel dimension, so that each 28 x 28 image becomes a 3-dimensional tensor of shape 1 x 28 x 28.
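The `train_ds` and `test_ds` objects inspected in this section could be created along the following lines. This is a sketch, assuming the torchvision function `kmnist_dataset()` with the working directory as root (the printed `root_path` is `.`); `download = TRUE` is an assumption and fetches the data on first use.

```r
library(torch)
library(torchvision)

# Training split; transform_to_tensor is applied to each image when it is accessed
train_ds <- kmnist_dataset(
  ".",
  train = TRUE,
  download = TRUE,
  transform = transform_to_tensor
)

# Test split, prepared the same way
test_ds <- kmnist_dataset(
  ".",
  train = FALSE,
  download = TRUE,
  transform = transform_to_tensor
)
```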
train_ds
## <kminst_dataset>
##   Inherits from: <mnist>
##   Public:
##     .getitem: function (index) 
##     .length: function () 
##     check_exists: function () 
##     classes: o ki su tsu na ha ma ya re wo
##     clone: function (deep = FALSE) 
##     data: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0  ...
##     download: function () 
##     initialize: function (root, train = TRUE, transform = NULL, target_transform = NULL, 
##     processed_folder: active binding
##     raw_folder: active binding
##     resources: list
##     root_path: .
##     target_transform: NULL
##     targets: 9 8 1 2 5 3 5 9 2 2 6 2 1 6 8 7 2 8 10 6 8 4 8 6 7 7 3 8 ...
##     test_file: test.rds
##     train: TRUE
##     training_file: training.rds
##     transform: function (img)
test_ds
## <kminst_dataset>
##   Inherits from: <mnist>
##   Public:
##     .getitem: function (index) 
##     .length: function () 
##     check_exists: function () 
##     classes: o ki su tsu na ha ma ya re wo
##     clone: function (deep = FALSE) 
##     data: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0  ...
##     download: function () 
##     initialize: function (root, train = TRUE, transform = NULL, target_transform = NULL, 
##     processed_folder: active binding
##     raw_folder: active binding
##     resources: list
##     root_path: .
##     target_transform: NULL
##     targets: 3 10 4 9 4 4 9 4 3 6 7 4 4 4 2 6 5 9 7 4 8 6 8 6 8 1 4 6 ...
##     test_file: test.rds
##     train: FALSE
##     training_file: training.rds
##     transform: function (img)

6 Display Tensor Output Format After Transformation

train_ds[1]
## $x
## torch_tensor
## (1,.,.) = 
##  Columns 1 to 9  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.2000
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0431  0.8314
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.5804  1.0000
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.8039  1.0000
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0275  0.9608  1.0000
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0902  1.0000  0.9373
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.1882  1.0000  0.7529
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.2078  1.0000  0.5804
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.3725  1.0000  0.5608
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.5843  1.0000  0.4745
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.7922  1.0000  0.8353
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0627  0.9647  1.0000  1.0000
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.3490  1.0000  1.0000  1.0000
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.7529  1.0000  1.0000  0.9961
##   0.0000  0.0000  0.0000  0.0000  0.1765  0.9882  1.0000  1.0000  0.9882
##   0.0000  0.0000  0.0000  0.0000  0.6314  1.0000  1.0000  1.0000  0.9843
##   0.0000  0.0000  0.0000  0.2196  0.9882  1.0000  1.0000  1.0000  0.9765
##   0.0000  0.0000  0.0000  0.7059  1.0000  1.0000  1.0000  1.0000  0.9804
##   0.0000  0.0000  0.0588  0.9882  1.0000  1.0000  1.0000  1.0000  0.9294
##   0.0000  0.0000  0.0667  0.9961  0.9529  0.8784  1.0000  1.0000  0.9373
##   0.0000  0.0000  0.0196  0.5882  0.2745  0.4039  1.0000  1.0000  0.6588
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.2549  1.0000  0.9961  0.2431
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.1176  0.9922  0.7294  0.0078
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0314  0.5647  0.1020  0.0000
## 
## ... [the output was truncated (use n=-1 to disable)]
## [ CPUFloatType{1,28,28} ]
## 
## $y
## [1] 9
test_ds[1]
## $x
## torch_tensor
## (1,.,.) = 
##  Columns 1 to 9  0.0000  0.3804  0.1373  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
##   0.0000  0.1333  0.5608  0.0196  0.0000  0.0000  0.0000  0.0000  0.0000
##   0.0000  0.0000  0.1922  0.3490  0.0000  0.0000  0.0000  0.0000  0.0000
##   0.0000  0.0000  0.0000  0.5255  0.0000  0.0000  0.0000  0.0000  0.0000
##   0.0000  0.0000  0.0000  0.6784  0.0196  0.0000  0.0000  0.0000  0.0000
##   0.0000  0.0000  0.0000  0.6431  0.0314  0.0000  0.0000  0.0000  0.0000
##   0.0000  0.0000  0.0000  0.8000  0.1176  0.0000  0.0000  0.0000  0.0000
##   0.0000  0.0000  0.0000  0.8549  0.0784  0.0000  0.0000  0.0000  0.0000
##   0.0000  0.0000  0.0824  0.8235  0.0078  0.0000  0.0000  0.0000  0.0000
##   0.0000  0.0000  0.3647  0.7412  0.0000  0.0000  0.0000  0.0000  0.2078
##   0.0000  0.0000  0.6667  0.5529  0.0000  0.0000  0.0000  0.0353  0.7137
##   0.0000  0.0314  0.8902  0.3255  0.0000  0.0000  0.0000  0.4353  0.5647
##   0.0000  0.2118  0.9882  0.1373  0.0000  0.0000  0.0588  0.8314  0.0902
##   0.0000  0.4627  0.8549  0.0118  0.0000  0.0000  0.3412  0.7686  0.0039
##   0.0000  0.7059  0.6314  0.0000  0.0000  0.0039  0.7569  0.4941  0.0000
##   0.0275  0.9176  0.5216  0.0000  0.0000  0.1098  0.9373  0.1294  0.0000
##   0.1725  1.0000  0.3843  0.0000  0.0000  0.5216  0.7608  0.0078  0.0000
##   0.2431  1.0000  0.3098  0.0000  0.0314  0.8784  0.6000  0.0000  0.0000
##   0.4667  0.9882  0.1804  0.0000  0.2745  1.0000  0.2745  0.0000  0.0000
##   0.5922  0.9961  0.2000  0.0000  0.5922  0.9333  0.0431  0.0000  0.0000
##   0.7216  1.0000  0.2039  0.0941  0.9451  0.7137  0.0000  0.0000  0.0000
##   0.7020  0.9765  0.1373  0.4431  1.0000  0.4471  0.0000  0.0000  0.0000
##   0.6745  0.9922  0.2863  0.8314  1.0000  0.2471  0.0000  0.0000  0.0000
##   0.4314  1.0000  0.9725  1.0000  0.9804  0.0745  0.0000  0.0000  0.0000
##   0.0902  0.9451  1.0000  1.0000  0.8745  0.0000  0.0000  0.0000  0.0000
##   0.0000  0.4941  0.9961  1.0000  0.6275  0.0000  0.0000  0.0000  0.0000
##   0.0000  0.0588  0.8784  0.9490  0.2157  0.0000  0.0000  0.0000  0.0000
##   0.0000  0.0000  0.2667  0.4118  0.0000  0.0000  0.0000  0.0000  0.0000
## 
## ... [the output was truncated (use n=-1 to disable)]
## [ CPUFloatType{1,28,28} ]
## 
## $y
## [1] 3
train_ds[1][[1]]$size()
## [1]  1 28 28
test_ds[1][[1]]$size()
## [1]  1 28 28
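The value range and the 1 x 28 x 28 shape can be reproduced by hand in base R. The sketch below is only an illustration of the two steps; the raw intensities are chosen so that they yield three of the values visible in the first training image, and the all-zero image is a placeholder.

```r
# Three raw 8-bit pixel intensities (0 to 255), chosen for illustration
raw <- c(51, 11, 212)

# Step 1: rescale to the range [0, 1]
scaled <- raw / 255
round(scaled, 4)  # 0.2000 0.0431 0.8314

# Step 2: a 28 x 28 image gains a leading channel dimension
img <- matrix(0, nrow = 28, ncol = 28)  # placeholder image
x <- array(img / 255, dim = c(1, 28, 28))
dim(x)  # 1 28 28
```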

7 Data Loader

train_dl <- dataloader(train_ds, batch_size = 32, shuffle = TRUE)
test_dl <- dataloader(test_ds, batch_size = 32)
train_iter <- train_dl$.iter()
train_iter$.next()
## $x
## torch_tensor
## (1,1,.,.) = 
##  Columns 1 to 9  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0157  0.0000
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.3961
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.2980  0.6314
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0078  0.2549  0.0118
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0627  0.1020
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.2510  0.9804  1.0000
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.1333  0.9804  0.9922
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0196  0.6510  0.4000
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0863  0.4667
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0157  0.7176  0.9882
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.3725  0.9882  0.4745
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.7922  1.0000  0.2471
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.7294  1.0000  0.5098
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.2902  0.9922  0.9294
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0118  0.6941  1.0000
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0667  0.8314
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.2353
##   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0039
## 
## ... [the output was truncated (use n=-1 to disable)]
## [ CPUFloatType{32,1,28,28} ]
## 
## $y
## torch_tensor
##   9
##   1
##  10
##   6
##  10
##   6
##   9
##   6
##   1
##   7
##   2
##   6
##   7
##   5
##   8
##   1
##   8
##   5
##   1
##   2
##   8
##   7
##   1
##  10
##  10
##   8
##   4
##   9
##   3
##   5
## ... [the output was truncated (use n=-1 to disable)]
## [ CPULongType{32} ]
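With a batch size of 32, the number of minibatches per pass over the data follows from the dataset sizes. The sketch below uses the published KMNIST split sizes (60,000 training and 10,000 test images), which are not printed in this section; the variable names are made up for the example. Note also that the labels in `$y` run from 1 to 10, since torch for R uses 1-based class indices.

```r
batch_size <- 32
n_train <- 60000  # KMNIST training images
n_test  <- 10000  # KMNIST test images

# Number of minibatches per pass over each split
n_train_batches <- ceiling(n_train / batch_size)  # 1875
n_test_batches  <- ceiling(n_test / batch_size)   # 313

# Size of the final (smaller) test minibatch
last_test_batch <- n_test - (n_test_batches - 1) * batch_size  # 16
```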

8 Display Sample Images from the Training Set

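# 4 x 8 grid so that all 32 images of the minibatch fit on one page
par(mfrow = c(4, 8), mar = rep(0, 4))
# first minibatch, channel 1: a 32 x 28 x 28 tensor, converted to an R array
images <- as_array(train_dl$.iter()$.next()[[1]][1:32, 1, , ])
images %>%
  purrr::array_tree(1) %>%
  purrr::map(as.raster) %>%
  purrr::iwalk(~{plot(.x)})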

Figure 8.1: Samples of Kuzushiji-MNIST

9 Create Network

net <- nn_module(
  
  "KMNIST-CNN",
  
  initialize = function() {
    # in_channels, out_channels, kernel_size, stride = 1, padding = 0
    self$conv1 <- nn_conv2d(1, 32, 3)
    self$conv2 <- nn_conv2d(32, 64, 3)
    self$dropout1 <- nn_dropout2d(0.25)
    self$dropout2 <- nn_dropout2d(0.5)
    self$fc1 <- nn_linear(9216, 128)
    self$fc2 <- nn_linear(128, 10)
  },
  
  forward = function(x) {
    x %>% 
      self$conv1() %>%
      nnf_relu() %>%
      self$conv2() %>%
      nnf_relu() %>%
      nnf_max_pool2d(2) %>%
      self$dropout1() %>%
      torch_flatten(start_dim = 2) %>%
      self$fc1() %>%
      nnf_relu() %>%
      self$dropout2() %>%
      self$fc2()
  }
)
model <- net()
model$to(device = "cpu")
model
## An `nn_module` containing 1,199,882 parameters.
## 
## -- Modules ---------------------------------------------------------------------
## * conv1: <nn_conv2d> #320 parameters
## * conv2: <nn_conv2d> #18,496 parameters
## * dropout1: <nn_dropout2d> #0 parameters
## * dropout2: <nn_dropout2d> #0 parameters
## * fc1: <nn_linear> #1,179,776 parameters
## * fc2: <nn_linear> #1,290 parameters
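The input size of `fc1` (9216) and the parameter counts in the summary can be checked with a little arithmetic. The helper functions below are made up for this sketch; they assume no padding, a stride of 1 for the convolutions, and a pooling stride equal to the pooling window, which matches the calls in the network definition.

```r
# Output side length of a convolution or pooling layer without padding
out_size <- function(n, kernel, stride = 1) (n - kernel) %/% stride + 1

s <- 28                                   # input is 1 x 28 x 28
s <- out_size(s, kernel = 3)              # conv1: 32 x 26 x 26
s <- out_size(s, kernel = 3)              # conv2: 64 x 24 x 24
s <- out_size(s, kernel = 2, stride = 2)  # max pool: 64 x 12 x 12
flat <- 64 * s * s                        # 9216 inputs to fc1

# Parameters = weights + biases
conv_params <- function(c_in, c_out, k) c_in * c_out * k * k + c_out
fc_params <- function(n_in, n_out) n_in * n_out + n_out

total <- conv_params(1, 32, 3) + conv_params(32, 64, 3) +
  fc_params(flat, 128) + fc_params(128, 10)
total  # 1199882
```

Almost all of the parameters sit in `fc1`, which connects the 9216 flattened features to 128 hidden units.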

10 Training the Model

optimizer <- optim_adam(model$parameters)
optimizer
## <optim_adam>
##   Inherits from: <torch_Optimizer>
##   Public:
##     add_param_group: function (param_group) 
##     clone: function (deep = FALSE) 
##     defaults: list
##     initialize: function (params, lr = 0.001, betas = c(0.9, 0.999), eps = 1e-08, 
##     load_state_dict: function (state_dict) 
##     param_groups: list
##     state: State, R6
##     state_dict: function () 
##     step: function (closure = NULL) 
##     zero_grad: function () 
##   Private:
##     step_helper: function (closure, loop_fun)
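n_epochs <- 5  # assumed number of epochs; the original value is not shown

for (epoch in seq_len(n_epochs)) {
  minloss <- c()

  coro::loop(for (indexnumber in train_dl) {
    # reset the gradients accumulated from the previous minibatch
    optimizer$zero_grad()
    # forward pass: indexnumber[[1]] holds the images, indexnumber[[2]] the labels
    output <- model(indexnumber[[1]]$to(device = "cpu"))
    loss <- nnf_cross_entropy(output, indexnumber[[2]]$to(device = "cpu"))
    # backward pass and weight update
    loss$backward()
    optimizer$step()
    # record the minibatch loss
    minloss <- c(minloss, loss$item())
  })

  cat(sprintf("Loss at epoch %d: %3f\n", epoch, mean(minloss)))
}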
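The test dataloader created in Section 7 is never used by the training loop. A possible way to use it is sketched below: a hypothetical helper `evaluate()` (not part of torch) that reports the mean loss and the accuracy over a dataloader. The toy model and data at the end only demonstrate the call; with the objects from this section the call would be `evaluate(model, test_dl)`.

```r
library(torch)

# Hypothetical helper: mean cross-entropy loss and accuracy over a dataloader
evaluate <- function(model, dl, device = "cpu") {
  model$eval()  # switch dropout layers to evaluation mode
  losses <- c()
  correct <- 0
  total <- 0
  with_no_grad({
    coro::loop(for (b in dl) {
      x <- b[[1]]$to(device = device)
      y <- b[[2]]$to(device = device)
      output <- model(x)
      losses <- c(losses, nnf_cross_entropy(output, y)$item())
      # torch_max() returns a list: values first, 1-based class indices second
      predicted <- torch_max(output, dim = 2)[[2]]
      correct <- correct + (predicted == y)$sum()$item()
      total <- total + y$size(1)
    })
  })
  list(loss = mean(losses), accuracy = correct / total)
}

# Toy demonstration: random inputs with two classes and an untrained linear model
toy_ds <- tensor_dataset(
  torch_randn(8, 4),
  torch_tensor(rep(1:2, 4), dtype = torch_long())
)
toy_dl <- dataloader(toy_ds, batch_size = 4)
res <- evaluate(nn_linear(4, 2), toy_dl)
res
```

Calling `model$train()` afterwards would switch dropout back on before any further training.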

References

Goodfellow, Ian J., Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, and Yoshua Bengio. 2014. “Generative Adversarial Networks,” June. https://arxiv.org/abs/1406.2661v1.
Karras, Tero, Timo Aila, Samuli Laine, and Jaakko Lehtinen. 2017. “Progressive Growing of GANs for Improved Quality, Stability, and Variation.” CoRR abs/1710.10196. http://arxiv.org/abs/1710.10196.
Karras, Tero, Samuli Laine, and Timo Aila. 2018. “A Style-Based Generator Architecture for Generative Adversarial Networks.” CoRR abs/1812.04948. http://arxiv.org/abs/1812.04948.
Mirza, Mehdi, and Simon Osindero. 2014. “Conditional Generative Adversarial Nets.” CoRR abs/1411.1784. http://arxiv.org/abs/1411.1784.
Paszke, Adam, Sam Gross, Francisco Massa, Adam Lerer, James Bradbury, Gregory Chanan, Trevor Killeen, et al. 2019. “PyTorch: An Imperative Style, High-Performance Deep Learning Library.” CoRR abs/1912.01703. http://arxiv.org/abs/1912.01703.