import torch
import time
# Benchmark: time 50,000 single-slice indexing operations x[i, :, :] on a
# (50000, 28, 28) tensor. The result (seconds, float) is stored in `indexing`
# and read from R via reticulate.
# time.perf_counter() is used rather than time.time(): it is monotonic and has
# the highest available resolution, which is what is wanted for timing code.
x = torch.randn(50000, 28, 28)
t0 = time.perf_counter()
for i in range(50000):
  a = x[i, :, :]
indexing = time.perf_counter() - t0
# Benchmark (R): time 50,000 single-slice indexing operations tensor[row, , ]
# on a (50000, 28, 28) tensor. The timed function runs in a separate R session
# via callr::r(); the returned proc_time object is stored in `indexing`.
fun <- function() {
  library(torch)
  tensor <- torch_randn(50000, 28, 28)
  system.time(
    for (row in seq_len(50000)) {
      slice <- tensor[row, , ]
    }
  )
}
indexing <- callr::r(fun)
# Benchmark: time 50,000 dense matrix products of a (32, 784) tensor with a
# (784, 10) tensor. The result (seconds, float) is stored in `simple_mul`.
# The range object is built before the timer starts, so only the matmul loop
# is measured. time.perf_counter() is the monotonic, high-resolution clock
# intended for interval timing (time.time() is wall-clock and can jump).
x = torch.randn(32, 784)
y = torch.randn(784, 10)
r = range(50000)
t0 = time.perf_counter()
for i in r:
  a = torch.mm(x, y)
simple_mul = time.perf_counter() - t0
# Benchmark (R): time 50,000 dense matrix products of a (32, 784) tensor with a
# (784, 10) tensor, in a separate R session via callr::r(). The proc_time
# result is stored in `simple_mul`.
# `y` is (784, 10) so the workload is the same as the Python simple-mul
# benchmark it is compared against; a (784, 1) operand would make the R side
# compute a 10x smaller product and skew the comparison.
fun <- function() {
  library(torch)
  x <- torch_randn(32, 784)
  y <- torch_randn(784, 10)
  # Sequence is built outside system.time() so only the loop is measured.
  el <- seq_len(50000)
  system.time({
    for (i in el) {
      a <- torch_mm(x, y)
    }
  })
}
simple_mul <- callr::r(fun)
class ConvMNIST (torch.nn.Module):
  """Small MNIST-style CNN used for the forward-pass benchmark.

  Layout: conv(1->32, 3x3) -> ReLU -> conv(32->64, 3x3) -> ReLU ->
  2x2 max-pool -> dropout(0.25) -> flatten -> linear(9216->128) -> ReLU ->
  dropout(0.5) -> linear(128->10). The 9216 input width of fc1 equals
  64 * 12 * 12, so inputs must be (N, 1, 28, 28). Returns raw (N, 10)
  scores; no softmax is applied.
  """

  def __init__ (self):
    super().__init__()
    # Submodules are registered in this order; the names double as
    # state_dict keys, so keep them stable.
    self.conv1 = torch.nn.Conv2d(1, 32, 3, 1)
    self.conv2 = torch.nn.Conv2d(32, 64, 3, 1)
    self.dropout1 = torch.nn.Dropout(0.25)
    self.dropout2 = torch.nn.Dropout(0.5)
    self.fc1 = torch.nn.Linear(9216, 128)
    self.fc2 = torch.nn.Linear(128, 10)

  def forward (self, x):
    relu = torch.nn.functional.relu
    # Convolutional feature extractor: 28x28 -> 26x26 -> 24x24 -> 12x12.
    features = relu(self.conv2(relu(self.conv1(x))))
    pooled = self.dropout1(torch.nn.functional.max_pool2d(features, 2))
    # Keep the batch dimension, flatten the rest for the classifier head.
    flat = torch.flatten(pooled, start_dim = 1)
    hidden = self.dropout2(relu(self.fc1(flat)))
    return self.fc2(hidden)

# Benchmark: time 10,000 forward passes of ConvMNIST on a fixed
# (32, 1, 28, 28) batch. The result (seconds, float) is stored in
# `simple_model`. No model.eval() call is made, so dropout layers stay in
# PyTorch's default training mode during the timed passes.
# time.perf_counter() is the monotonic, high-resolution clock intended for
# interval timing.
model = ConvMNIST()

x = torch.randn(32, 1, 28, 28)
r = range(10000)
t0 = time.perf_counter()
for i in r:
  z = model(x)
simple_model = time.perf_counter() - t0
# Benchmark (R): time 10,000 forward passes of an MNIST-style CNN on a fixed
# (32, 1, 28, 28) batch, in a separate R session via callr::r(). The network
# is defined to match the Python ConvMNIST module layer for layer; the
# proc_time result is stored in `simple_model`.
fun <- function() {
  library(torch)
  net <- nn_module(
    "Net",
    initialize = function() {
      self$conv1 <- nn_conv2d(1, 32, 3, 1)
      self$conv2 <- nn_conv2d(32, 64, 3, 1)
      # Element-wise dropout, matching torch.nn.Dropout in the Python model.
      # nn_dropout2d drops whole channels of (N, C, H, W) input: that differs
      # from the Python benchmark, and is wrong for dropout2, which is applied
      # to the flattened (N, features) tensor.
      self$dropout1 <- nn_dropout(0.25)
      self$dropout2 <- nn_dropout(0.5)
      self$fc1 <- nn_linear(9216, 128)
      self$fc2 <- nn_linear(128, 10)
    },
    forward = function(x) {
      x <- self$conv1(x)
      x <- nnf_relu(x)
      x <- self$conv2(x)
      x <- nnf_relu(x)
      x <- nnf_max_pool2d(x, 2)
      x <- self$dropout1(x)
      # R dimensions are 1-based: start_dim = 2 keeps the batch dimension,
      # equivalent to start_dim = 1 in the Python model.
      x <- torch_flatten(x, start_dim = 2)
      x <- self$fc1(x)
      x <- nnf_relu(x)
      x <- self$dropout2(x)
      x <- self$fc2(x)
      x
    }
  )

  model <- net()

  x <- torch_randn(32, 1, 28, 28)
  # Sequence is built outside system.time() so only the loop is measured.
  el <- seq_len(10000)
  system.time({
    for (i in el) {
      z <- model(x)
    }
  })
}

simple_model <- callr::r(fun)
# Benchmark (R): time 10,000 SGD training steps (zero_grad, forward, MSE loss,
# backward, step) of a single nn_linear(784, 10) layer on a fixed batch of 32,
# in a separate R session via callr::r(). The proc_time result is stored in
# `opt_step`.
fun <- function() {
  library(torch)
  model <- nn_linear(784, 10)
  opt <- optim_sgd(model$parameters, lr = 0.01)
  x <- torch_randn(32, 784)
  y <- torch_randn(32, 10)
  # Sequence is built outside system.time() so only the loop is measured.
  el <- seq_len(10000)
  system.time({
    for (i in el) {
      opt$zero_grad()
      pred <- model(x)
      # nnf_mse_loss(input, target): the prediction is the input and `y` the
      # target (MSE is symmetric, so the loss value is unchanged).
      loss <- nnf_mse_loss(pred, y)
      loss$backward()
      opt$step()
    }
  })
}
opt_step <- callr::r(fun)
# Benchmark: time 10,000 SGD training steps (zero_grad, forward, MSE loss,
# backward, step) of a single Linear(784, 10) layer on a fixed batch of 32.
# The result (seconds, float) is stored in `opt_step`.
# time.perf_counter() is the monotonic, high-resolution clock intended for
# interval timing.
model = torch.nn.Linear(784, 10)
opt = torch.optim.SGD(model.parameters(), lr = 0.01)
x = torch.randn(32, 784)
y = torch.randn(32, 10)
r = range(10000)
t0 = time.perf_counter()

for i in r:
  opt.zero_grad()
  pred = model(x)
  # mse_loss(input, target): the prediction is the input and `y` the target
  # (MSE is symmetric, so the loss value is unchanged).
  loss = torch.nn.functional.mse_loss(pred, y)
  loss.backward()
  opt.step()

opt_step = time.perf_counter() - t0
# Summary table: each R timing (the "elapsed" entry of the proc_time returned
# by system.time()) next to its Python counterpart (seconds read from the
# Python session through reticulate::py), both shown as lubridate durations.
# Labels use the same wording for R and Python so each pair lines up.
tibble::tribble(
  ~Language,               ~Time,
  "R (indexing)",          lubridate::duration(indexing[["elapsed"]]),
  "Python (indexing)",     lubridate::duration(reticulate::py$indexing),
  "R (simple mul)",        lubridate::duration(simple_mul[["elapsed"]]),
  "Python (simple mul)",   lubridate::duration(reticulate::py$simple_mul),
  "R (simple model)",      lubridate::duration(simple_model[["elapsed"]]),
  "Python (simple model)", lubridate::duration(reticulate::py$simple_model),
  "R (opt)",               lubridate::duration(opt_step[["elapsed"]]),
  "Python (opt)",          lubridate::duration(reticulate::py$opt_step)
)
## # A tibble: 8 x 2
##   Language              Time                            
##   <chr>                 <Duration>                      
## 1 R (indexing)          3.705s                          
## 2 Python (indexing)     0.281637907028198s              
## 3 R (simple mul)        4.321s                          
## 4 Python (simple_mul)   0.699391841888428s              
## 5 R (simple model)      282.549s (~4.71 minutes)        
## 6 Python (simple model) 183.91908288002s (~3.07 minutes)
## 7 R (opt)               24.058s                         
## 8 Python (opt)          2.4779679775238s
# Disabled profiling snippet: profiles the R matmul loop with profvis in a
# separate session (callr), then opens the saved profile.
# fun <- function() {
#   library(torch)
#   x <- torch_randn(32, 784)
#   y <- torch_randn(784, 1)
#   el <- 1:10000
#   p <- profvis::profvis(prof_output = "prof.out", {
#     for (i in el) {
#       a = torch_mm(x, y)
#     }
#   })
#   NULL
# }
# callr::r(fun)
# profvis::profvis(prof_input = "prof.out")