Basic Idea of Generative Adversarial Networks (GANs) in Machine Learning: torch in R
1 Generative Adversarial Networks
A generative adversarial network (GAN) is a fundamental class of machine learning system that couples two machine learning agents in a (positive or negative) feedback loop. The first player is the generative model (the generator), which produces candidate samples and plays the role of the prediction model. The second player is the discriminative model (the discriminator), which plays the role of a model-verification model. In its most basic form, the discriminator is a supervised classifier that labels each input as real (drawn from the training data) or fake (produced by the generator). In each minibatch iteration, the samples produced by the generator are fed into the discriminator, and the discriminator's verdict is fed back as an error-correction signal for the next update. This idea of integrating two competing players was first proposed by Goodfellow et al. (2014). Today the GAN family has many modified versions, such as Conditional GAN (Mirza and Osindero 2014), Progressive GAN (Karras et al. 2017) and StyleGAN (Karras, Laine, and Aila 2018). In this section, we only discuss how to implement the fundamental class of GAN model in RStudio; readers interested in the modified versions should consult the references.
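Formally, Goodfellow et al. (2014) describe the two players as a minimax game: the discriminator D tries to maximize, and the generator G tries to minimize, the value function V(D, G), where D(x) is the probability that x came from the real data and G(z) maps a random noise vector z to a synthetic sample:

$$
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
$$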
The figure below shows a feedback control-loop engineering system, an idea that GANs translate into a deep machine learning architecture. Simply put, for the image applications discussed here, a GAN integrates feedback-loop control with CNN architectures. In a positive-feedback GAN, the discriminator's classification output should be dominated by true positives (called Deeptrue); otherwise it should be dominated by true negatives (called Deepfaked). From a control-engineering point of view, this integrated design is nothing new. Whether the discriminator's classification output should be Deeptrue or Deepfaked depends entirely on the application. For example, a system for recognizing medical MRI images should be designed for true-positive dominance (Deeptrue), whereas a digital-cloning effect in the entertainment industry should be designed for true-negative dominance (Deepfaked). Today, Deeptrue and Deepfaked GAN technologies are widely applied in medical science and the entertainment industry, respectively.
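To make the feedback loop concrete, here is a minimal sketch of one adversarial update in torch for R. It is not the model trained later in this document (which is a CNN classifier). It assumes small fully connected networks, a hypothetical noise length `latent_dim = 64`, Adam with learning rate 2e-4 and binary cross-entropy losses; the names `generator`, `discriminator` and `gan_step` are illustrative only. The generator update uses the common non-saturating variant (train G so that D labels its samples as real) rather than directly minimizing log(1 − D(G(z))).

```r
library(torch)

torch_manual_seed(1)
latent_dim <- 64  # hypothetical length of the noise vector z

# Generator G: noise z -> flattened 28 x 28 image with pixels in [0, 1]
generator <- nn_sequential(
  nn_linear(latent_dim, 256),
  nn_relu(),
  nn_linear(256, 28 * 28),
  nn_sigmoid()
)

# Discriminator D: flattened image -> logit of "this image is real"
discriminator <- nn_sequential(
  nn_linear(28 * 28, 256),
  nn_leaky_relu(0.2),
  nn_linear(256, 1)
)

opt_g <- optim_adam(generator$parameters, lr = 2e-4)
opt_d <- optim_adam(discriminator$parameters, lr = 2e-4)

# One adversarial update on a minibatch of real images (n x 784)
gan_step <- function(real) {
  n <- real$size(1)
  real_label <- torch_ones(n, 1)
  fake_label <- torch_zeros(n, 1)

  # 1) Discriminator: push D(real) towards "real" and D(G(z)) towards "fake";
  #    detach() stops gradients from flowing back into the generator here
  opt_d$zero_grad()
  fake <- generator(torch_randn(n, latent_dim))
  loss_d <- nnf_binary_cross_entropy_with_logits(discriminator(real), real_label) +
    nnf_binary_cross_entropy_with_logits(discriminator(fake$detach()), fake_label)
  loss_d$backward()
  opt_d$step()

  # 2) Generator: the discriminator's verdict is fed back as the error signal;
  #    G is trained so that D labels its samples as real
  opt_g$zero_grad()
  loss_g <- nnf_binary_cross_entropy_with_logits(discriminator(fake), real_label)
  loss_g$backward()
  opt_g$step()

  c(d = loss_d$item(), g = loss_g$item())
}

# Usage: a random stand-in batch of 32 flattened "real" images
losses <- gan_step(torch_rand(32, 28 * 28))
print(losses)
```

Calling `gan_step()` once per minibatch of real images (flattened to 784 values) repeats the loop; the random batch above only checks that a single step runs.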
2 Install and Load the torch Library
- The torch package for R is based on PyTorch (Paszke et al. 2019); it binds directly to the PyTorch C++ library (libtorch), so no Python installation is needed.
install.packages("torch")
install.packages("torchvision")
devtools::install_github("mlverse/torch")  # alternatively, the development version from GitHub
2.1 Check That the CPU Device Is Ready Using torch
library(torch)
library(torchvision)
torch_tensor(1, device = "cpu")
## torch_tensor
## 1
## [ CPUFloatType{1} ]
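Rather than hard-coding a device, a script can ask torch whether a CUDA GPU is visible and fall back to the CPU otherwise. A minimal sketch using `cuda_is_available()`; note that a device string names a device type such as `"cpu"` or `"cuda"`, not a CUDA toolkit version:

```r
library(torch)

# Pick the GPU if torch can see one, otherwise fall back to the CPU
device <- if (cuda_is_available()) "cuda" else "cpu"

x <- torch_tensor(1, device = device)
print(x$device)  # the device the tensor was actually allocated on
```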
3 Using the Kuzushiji-MNIST Dataset (KMNIST)
- https://www.simonwenkel.com/2018/12/18/Kuzushiji-MNIST.html#kuzushiji-mnist
- Downloaded from: https://github.com/rois-codh/kmnist
Figure 3.1: Samples of Kuzushiji-MNIST
4 Load Data (Real Input) and Transform to Tensor Format
Figure 4.1: Input Real Image Asian Characters
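The code that creates `train_ds` and `test_ds` is not shown above. A minimal sketch, assuming torchvision's `kmnist_dataset()` (its `initialize()` arguments `root`, `train`, `transform` and `target_transform` appear in the printed objects below) and `root = "."` as in the printed `root_path`; `download = TRUE` is an assumption:

```r
library(torch)
library(torchvision)

# Download KMNIST into the current directory and convert every image
# to a 1 x 28 x 28 float tensor with transform_to_tensor()
train_ds <- kmnist_dataset(".", train = TRUE, download = TRUE,
                           transform = transform_to_tensor)
test_ds <- kmnist_dataset(".", train = FALSE, download = TRUE,
                          transform = transform_to_tensor)

length(train_ds)  # number of training images
length(test_ds)   # number of test images
```

KMNIST ships 60,000 training and 10,000 test images, so the two lengths should match those counts.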
5 transform_to_tensor() Function
- normalizes the input image pixels to the range between 0 and 1 (8-bit values 0 to 255 are divided by 255)
- then adds an extra channel dimension, turning each 28 x 28 image into a 3-dimensional tensor of shape 1 x 28 x 28 (channel x height x width).
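The two steps above can be mimicked by hand on a synthetic image. This is an illustrative sketch, not torchvision's source code; the random 8-bit image is hypothetical:

```r
library(torch)

set.seed(1)
# A hypothetical 28 x 28 grayscale image with 8-bit pixel values 0..255
img <- matrix(sample(0:255, 28 * 28, replace = TRUE), nrow = 28)

# Scale pixels to [0, 1], then prepend a channel dimension (dim 1, 1-based)
x <- torch_tensor(img / 255, dtype = torch_float())$unsqueeze(1)

x$size()      # 1 28 28  (channel x height x width)
x$min()$item()
x$max()$item()
```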
train_ds
## <kminst_dataset>
## Inherits from: <mnist>
## Public:
## .getitem: function (index)
## .length: function ()
## check_exists: function ()
## classes: o ki su tsu na ha ma ya re wo
## clone: function (deep = FALSE)
## data: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 ...
## download: function ()
## initialize: function (root, train = TRUE, transform = NULL, target_transform = NULL,
## processed_folder: active binding
## raw_folder: active binding
## resources: list
## root_path: .
## target_transform: NULL
## targets: 9 8 1 2 5 3 5 9 2 2 6 2 1 6 8 7 2 8 10 6 8 4 8 6 7 7 3 8 ...
## test_file: test.rds
## train: TRUE
## training_file: training.rds
## transform: function (img)
test_ds
## <kminst_dataset>
## Inherits from: <mnist>
## Public:
## .getitem: function (index)
## .length: function ()
## check_exists: function ()
## classes: o ki su tsu na ha ma ya re wo
## clone: function (deep = FALSE)
## data: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 ...
## download: function ()
## initialize: function (root, train = TRUE, transform = NULL, target_transform = NULL,
## processed_folder: active binding
## raw_folder: active binding
## resources: list
## root_path: .
## target_transform: NULL
## targets: 3 10 4 9 4 4 9 4 3 6 7 4 4 4 2 6 5 9 7 4 8 6 8 6 8 1 4 6 ...
## test_file: test.rds
## train: FALSE
## training_file: training.rds
## transform: function (img)
6 Display Tensor Output Format After Transformation
train_ds[1]
## $x
## torch_tensor
## (1,.,.) =
## Columns 1 to 9 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.2000
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0431 0.8314
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.5804 1.0000
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.8039 1.0000
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0275 0.9608 1.0000
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0902 1.0000 0.9373
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.1882 1.0000 0.7529
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.2078 1.0000 0.5804
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.3725 1.0000 0.5608
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.5843 1.0000 0.4745
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.7922 1.0000 0.8353
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0627 0.9647 1.0000 1.0000
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.3490 1.0000 1.0000 1.0000
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.7529 1.0000 1.0000 0.9961
## 0.0000 0.0000 0.0000 0.0000 0.1765 0.9882 1.0000 1.0000 0.9882
## 0.0000 0.0000 0.0000 0.0000 0.6314 1.0000 1.0000 1.0000 0.9843
## 0.0000 0.0000 0.0000 0.2196 0.9882 1.0000 1.0000 1.0000 0.9765
## 0.0000 0.0000 0.0000 0.7059 1.0000 1.0000 1.0000 1.0000 0.9804
## 0.0000 0.0000 0.0588 0.9882 1.0000 1.0000 1.0000 1.0000 0.9294
## 0.0000 0.0000 0.0667 0.9961 0.9529 0.8784 1.0000 1.0000 0.9373
## 0.0000 0.0000 0.0196 0.5882 0.2745 0.4039 1.0000 1.0000 0.6588
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.2549 1.0000 0.9961 0.2431
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.1176 0.9922 0.7294 0.0078
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0314 0.5647 0.1020 0.0000
##
## ... [the output was truncated (use n=-1 to disable)]
## [ CPUFloatType{1,28,28} ]
##
## $y
## [1] 9
test_ds[1]
## $x
## torch_tensor
## (1,.,.) =
## Columns 1 to 9 0.0000 0.3804 0.1373 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
## 0.0000 0.1333 0.5608 0.0196 0.0000 0.0000 0.0000 0.0000 0.0000
## 0.0000 0.0000 0.1922 0.3490 0.0000 0.0000 0.0000 0.0000 0.0000
## 0.0000 0.0000 0.0000 0.5255 0.0000 0.0000 0.0000 0.0000 0.0000
## 0.0000 0.0000 0.0000 0.6784 0.0196 0.0000 0.0000 0.0000 0.0000
## 0.0000 0.0000 0.0000 0.6431 0.0314 0.0000 0.0000 0.0000 0.0000
## 0.0000 0.0000 0.0000 0.8000 0.1176 0.0000 0.0000 0.0000 0.0000
## 0.0000 0.0000 0.0000 0.8549 0.0784 0.0000 0.0000 0.0000 0.0000
## 0.0000 0.0000 0.0824 0.8235 0.0078 0.0000 0.0000 0.0000 0.0000
## 0.0000 0.0000 0.3647 0.7412 0.0000 0.0000 0.0000 0.0000 0.2078
## 0.0000 0.0000 0.6667 0.5529 0.0000 0.0000 0.0000 0.0353 0.7137
## 0.0000 0.0314 0.8902 0.3255 0.0000 0.0000 0.0000 0.4353 0.5647
## 0.0000 0.2118 0.9882 0.1373 0.0000 0.0000 0.0588 0.8314 0.0902
## 0.0000 0.4627 0.8549 0.0118 0.0000 0.0000 0.3412 0.7686 0.0039
## 0.0000 0.7059 0.6314 0.0000 0.0000 0.0039 0.7569 0.4941 0.0000
## 0.0275 0.9176 0.5216 0.0000 0.0000 0.1098 0.9373 0.1294 0.0000
## 0.1725 1.0000 0.3843 0.0000 0.0000 0.5216 0.7608 0.0078 0.0000
## 0.2431 1.0000 0.3098 0.0000 0.0314 0.8784 0.6000 0.0000 0.0000
## 0.4667 0.9882 0.1804 0.0000 0.2745 1.0000 0.2745 0.0000 0.0000
## 0.5922 0.9961 0.2000 0.0000 0.5922 0.9333 0.0431 0.0000 0.0000
## 0.7216 1.0000 0.2039 0.0941 0.9451 0.7137 0.0000 0.0000 0.0000
## 0.7020 0.9765 0.1373 0.4431 1.0000 0.4471 0.0000 0.0000 0.0000
## 0.6745 0.9922 0.2863 0.8314 1.0000 0.2471 0.0000 0.0000 0.0000
## 0.4314 1.0000 0.9725 1.0000 0.9804 0.0745 0.0000 0.0000 0.0000
## 0.0902 0.9451 1.0000 1.0000 0.8745 0.0000 0.0000 0.0000 0.0000
## 0.0000 0.4941 0.9961 1.0000 0.6275 0.0000 0.0000 0.0000 0.0000
## 0.0000 0.0588 0.8784 0.9490 0.2157 0.0000 0.0000 0.0000 0.0000
## 0.0000 0.0000 0.2667 0.4118 0.0000 0.0000 0.0000 0.0000 0.0000
##
## ... [the output was truncated (use n=-1 to disable)]
## [ CPUFloatType{1,28,28} ]
##
## $y
## [1] 3
train_ds[1][[1]]$size()
## [1] 1 28 28
test_ds[1][[1]]$size()
## [1] 1 28 28
7 Data Loader
train_dl <- dataloader(train_ds, batch_size = 32, shuffle = TRUE)
test_dl <- dataloader(test_ds, batch_size = 32)
train_iter <- train_dl$.iter()
train_iter$.next()
## $x
## torch_tensor
## (1,1,.,.) =
## Columns 1 to 9 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0157 0.0000
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.3961
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.2980 0.6314
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0078 0.2549 0.0118
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0627 0.1020
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.2510 0.9804 1.0000
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.1333 0.9804 0.9922
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0196 0.6510 0.4000
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0863 0.4667
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0157 0.7176 0.9882
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.3725 0.9882 0.4745
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.7922 1.0000 0.2471
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.7294 1.0000 0.5098
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.2902 0.9922 0.9294
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0118 0.6941 1.0000
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0667 0.8314
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.2353
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0039
##
## ... [the output was truncated (use n=-1 to disable)]
## [ CPUFloatType{32,1,28,28} ]
##
## $y
## torch_tensor
## 9
## 1
## 10
## 6
## 10
## 6
## 9
## 6
## 1
## 7
## 2
## 6
## 7
## 5
## 8
## 1
## 8
## 5
## 1
## 2
## 8
## 7
## 1
## 10
## 10
## 8
## 4
## 9
## 3
## 5
## ... [the output was truncated (use n=-1 to disable)]
## [ CPULongType{32} ]
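Each batch therefore stacks 32 images into a 32 x 1 x 28 x 28 tensor, and one pass over the 60,000 KMNIST training images takes 60000 / 32 = 1875 batches. A small self-contained sketch of the same batching on hypothetical stand-in data (random "images" and labels, wrapped with `tensor_dataset()`):

```r
library(torch)

torch_manual_seed(1)
set.seed(1)
# Hypothetical stand-in data: 100 "images" of shape 1 x 28 x 28, labels 1..10
xs <- torch_rand(100, 1, 28, 28)
ys <- torch_tensor(sample(1:10, 100, replace = TRUE))
toy_dl <- dataloader(tensor_dataset(xs, ys), batch_size = 32, shuffle = TRUE)

length(toy_dl)  # ceiling(100 / 32) = 4 batches; the last one holds 4 items

batch <- dataloader_next(dataloader_make_iter(toy_dl))
batch[[1]]$shape  # 32 1 28 28
batch[[2]]$shape  # 32
```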
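8 Display Sample Images from the Training Set
par(mfrow = c(4, 8), mar = rep(0, 4))  # 4 x 8 grid for the 32 images of one batch
images <- train_dl$.iter()$.next()[[1]][1:32, 1, , ]  # 32 x 28 x 28 (channel dropped)
as_array(images) %>%  # convert the torch tensor to an R array
  purrr::array_tree(1) %>%  # split into a list of 32 matrices of 28 x 28
  purrr::map(as.raster) %>%
  purrr::iwalk(~{plot(.x)})
Figure 8.1: Samples of Kuzushiji-MNIST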
9 Create Network
net <- nn_module(
"KMNIST-CNN",
initialize = function() {
# in_channels, out_channels, kernel_size, stride = 1, padding = 0
self$conv1 <- nn_conv2d(1, 32, 3)
self$conv2 <- nn_conv2d(32, 64, 3)
self$dropout1 <- nn_dropout2d(0.25)
self$dropout2 <- nn_dropout2d(0.5)
self$fc1 <- nn_linear(9216, 128)
self$fc2 <- nn_linear(128, 10)
},
forward = function(x) {
x %>%
self$conv1() %>%
nnf_relu() %>%
self$conv2() %>%
nnf_relu() %>%
nnf_max_pool2d(2) %>%
self$dropout1() %>%
torch_flatten(start_dim = 2) %>%
self$fc1() %>%
nnf_relu() %>%
self$dropout2() %>%
self$fc2()
}
)
model <- net()
device <- if (cuda_is_available()) "cuda" else "cpu"  # use the GPU when torch can see one
model$to(device = device)
model
## An `nn_module` containing 1,199,882 parameters.
##
## -- Modules ---------------------------------------------------------------------
## * conv1: <nn_conv2d> #320 parameters
## * conv2: <nn_conv2d> #18,496 parameters
## * dropout1: <nn_dropout2d> #0 parameters
## * dropout2: <nn_dropout2d> #0 parameters
## * fc1: <nn_linear> #1,179,776 parameters
## * fc2: <nn_linear> #1,290 parameters
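The input size 9216 of `fc1` and the parameter counts above follow from simple arithmetic: each unpadded, stride-1 3 x 3 convolution shrinks the side length by 2, and the 2 x 2 max-pooling halves it. A short check in R (helper names `conv_out`, `conv_params` and `lin_params` are hypothetical):

```r
library(torch)

# Output side length of an unpadded, stride-1 k x k convolution
conv_out <- function(n, k) n - k + 1

side <- conv_out(conv_out(28, 3), 3) / 2  # 28 -> 26 -> 24, then 2 x 2 max-pool -> 12
flat <- 64 * side^2                       # 64 channels x 12 x 12
flat                                      # 9216

# Parameters of a layer = weights + biases
conv_params <- function(c_in, c_out, k) c_in * c_out * k^2 + c_out
lin_params <- function(n_in, n_out) n_in * n_out + n_out
total <- conv_params(1, 32, 3) + conv_params(32, 64, 3) +
  lin_params(flat, 128) + lin_params(128, 10)
total                                     # 1199882

# Cross-check the flattened size by pushing one dummy image through the layers
x <- torch_randn(1, 1, 28, 28)
y <- nnf_max_pool2d(nn_conv2d(32, 64, 3)(nn_conv2d(1, 32, 3)(x)), 2)
torch_flatten(y, start_dim = 2)$shape     # 1 9216
```

ReLU and dropout do not change tensor shapes, so they are omitted from the cross-check.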
10 Training the Model
optimizer <- optim_adam(model$parameters)
optimizer
## <optim_adam>
## Inherits from: <torch_Optimizer>
## Public:
## add_param_group: function (param_group)
## clone: function (deep = FALSE)
## defaults: list
## initialize: function (params, lr = 0.001, betas = c(0.9, 0.999), eps = 1e-08,
## load_state_dict: function (state_dict)
## param_groups: list
## state: State, R6
## state_dict: function ()
## step: function (closure = NULL)
## zero_grad: function ()
## Private:
## step_helper: function (closure, loop_fun)
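n_epochs <- 5  # assumption: the original number of epochs is not shown
device <- model$parameters[[1]]$device  # train on whichever device holds the model
model$train()
for (epoch in 1:n_epochs) {
  minloss <- c()  # per-batch training losses for this epoch
  coro::loop(for (indexnumber in train_dl) {
    optimizer$zero_grad()
    output <- model(indexnumber[[1]]$to(device = device))
    loss <- nnf_cross_entropy(output, indexnumber[[2]]$to(device = device))
    loss$backward()
    optimizer$step()
    minloss <- c(minloss, loss$item())
  })
  cat(sprintf("Loss at epoch %d: %.3f\n", epoch, mean(minloss)))
}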
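The test loader `test_dl` is created but not used for evaluation here. A minimal sketch of how classification accuracy could be measured on any dataloader, assuming a hypothetical helper `evaluate_accuracy()`; it is demonstrated on toy data so that it runs on its own:

```r
library(torch)

# Hypothetical helper: fraction of correctly classified samples in a dataloader
evaluate_accuracy <- function(model, dl, device = "cpu") {
  model$eval()  # switch off dropout
  correct <- 0
  total <- 0
  with_no_grad({
    coro::loop(for (b in dl) {
      logits <- model(b[[1]]$to(device = device))
      pred <- torch_argmax(logits, dim = 2)$cpu()  # predicted class, 1-based
      correct <- correct + (pred == b[[2]])$sum()$item()
      total <- total + b[[2]]$size(1)
    })
  })
  model$train()  # back to training mode
  correct / total
}

# Usage on toy data: 64 samples with 4 features and labels 1..3
torch_manual_seed(1)
set.seed(1)
toy_ds <- tensor_dataset(torch_randn(64, 4),
                         torch_tensor(sample(1:3, 64, replace = TRUE)))
toy_dl <- dataloader(toy_ds, batch_size = 16)
acc <- evaluate_accuracy(nn_linear(4, 3), toy_dl)
acc
```

With the objects from this tutorial, the same helper could be called as `evaluate_accuracy(model, test_dl, device = "cuda")` (or `"cpu"`) after training.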