import torch
import time

# Python baseline for the R `torch_mm()` benchmarks: time 5000 products of a
# (1000, 784) matrix with a (784, 10) matrix. `elapsed` (in seconds) is read
# from R as `reticulate::py$elapsed`.
x = torch.randn(1000, 784)
w = torch.randn(784, 10)

# perf_counter() is monotonic and high-resolution, which is what a benchmark
# needs; time.time() follows the system wall clock and can jump if the clock
# is adjusted while the loop runs.
t0 = time.perf_counter()
for _ in range(5000):
    o = torch.mm(x, w)
elapsed = time.perf_counter() - t0
library(torch)

# Benchmark 1: 5000 matrix products through the exported `torch_mm()`
# function. Compare with the direct C++ binding variants to gauge the
# overhead added by the R-level wrapper.
x <- torch_randn(1000, 784)
w <- torch_randn(784, 10)

elapsed <- system.time(
  expr = {
    for (i in seq_len(5000)) {
      o <- torch_mm(x, w)
    }
  }
)
# Benchmark 2: call the internal C++ binding for `mm` directly on the raw
# tensor pointers, bypassing the `torch_mm()` wrapper.
# NOTE(review): this binding is unexported (reached via `:::`), so its name
# and signature may change between torch versions.
x <- torch_randn(1000, 784)
w <- torch_randn(784, 10)

mm <- torch:::cpp_torch_namespace_mm_self_Tensor_mat2_Tensor

# Pull the pointers out once, before the timed loop starts.
x_ptr <- x$ptr
w_ptr <- w$ptr

elapsed2 <- system.time(
  expr = {
    for (i in seq_len(5000)) {
      o <- mm(x_ptr, w_ptr)
    }
  }
)
# Benchmark 3: same direct C++ binding as benchmark 2, but `$ptr` is read
# inside the loop on every iteration, so the timing includes the cost of
# retrieving the pointers each time.
# NOTE(review): the binding is an unexported torch internal (via `:::`).
x <- torch_randn(1000, 784)
w <- torch_randn(784, 10)

mm <- torch:::cpp_torch_namespace_mm_self_Tensor_mat2_Tensor

elapsed3 <- system.time(
  expr = {
    for (i in seq_len(5000)) {
      o <- mm(x$ptr, w$ptr)
    }
  }
)
# Summarise the wall-clock time of each benchmark variant in one table.
# `elapsed`, `elapsed2` and `elapsed3` are `system.time()` results; the Python
# timing is read from the Python session through reticulate.
# Every timing is rounded to milliseconds so all rows print at the same
# precision. `system.time()` only resolves to about 1 ms anyway, and without
# rounding, floating-point noise (e.g. 0.542999999999999s) and the Python
# timer's ~15 digits clutter the comparison.
tibble::tibble(
  Language = c("R", "R (ptr)", "R (ptr, but retrieving)", "Python"),
  Time = lubridate::duration(
    round(
      c(
        elapsed[["elapsed"]],
        elapsed2[["elapsed"]],
        elapsed3[["elapsed"]],
        reticulate::py$elapsed
      ),
      digits = 3
    )
  )
)
## # A tibble: 4 x 2
##   Language                Time              
##   <chr>                   <Duration>        
## 1 R                       4.208s            
## 2 R (ptr)                 0.553s            
## 3 R (ptr, but retrieving) 0.542999999999999s
## 4 Python                  0.606740951538086s
# Profile the `torch_mm()` loop to see where the time goes inside the R
# wrapper. Reuses the `x` and `w` tensors that are already defined.
profvis::profvis(
  expr = {
    for (i in seq_len(5000)) {
      o <- torch_mm(x, w)
    }
  }
)