# setup 
import torch
import torchvision
import time
dir = "~/Downloads/mnist-py"
# Build the MNIST training set, converting each image to a tensor and
# normalizing it. Normalize(mean, std): the values are presumably the commonly
# quoted MNIST mean/std -- confirm if they matter for the comparison.
ds = torchvision.datasets.MNIST(
    root=dir,
    download=True,
    transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ]),
)
dl = torch.utils.data.DataLoader(dataset=ds, batch_size=32, shuffle=False)

# Time one full pass over the loader. perf_counter() is monotonic and
# high-resolution; time.time() follows the wall clock and can jump if the
# system clock is adjusted, which corrupts interval measurements.
t0 = time.perf_counter()
for x, y in dl:  # each batch is (images, labels); bound only to force loading
    pass
# Elapsed seconds as a float; the R side reads it via reticulate::py$elapsed.
elapsed = time.perf_counter() - t0
# setup
library(torch)
# Time one full pass over the R dataloader. Batch size and shuffling match
# the Python benchmark (batch_size = 32, shuffle = FALSE).
dir <- "~/Downloads/mnist"
# NOTE(review): the Python chunk passes download = True plus a
# ToTensor/Normalize transform; this call passes neither and relies on
# mnist_dataset()'s defaults. The two pipelines may therefore not do identical
# per-item work -- confirm before comparing the timings directly.
ds <- mnist_dataset(dir)
dl <- dataloader(ds, batch_size = 32, shuffle = FALSE)

# `elapsed` is the proc_time object returned by system.time(); the comparison
# table reads its "elapsed" (wall-clock) component.
# NOTE(review): newer torch releases iterate dataloaders with coro::loop();
# whether enumerate() is still supported depends on the installed version.
elapsed <- system.time({
  for (b in enumerate(dl)) {
    x <- b[[1]]
    y <- b[[2]]
  }
})
# Side-by-side wall-clock time for one pass over each language's loader.
# local() keeps the helper names out of the global environment and returns
# the tibble visibly, so it is still printed.
local({
  r_secs <- elapsed[["elapsed"]]     # wall-clock part of the R system.time()
  py_secs <- reticulate::py$elapsed  # `elapsed` from the Python chunk
  tibble::tribble(
    ~Language, ~Time,
    "R",       lubridate::duration(r_secs),
    "Python",  lubridate::duration(py_secs)
  )
})
## # A tibble: 2 x 2
##   Language Time             
##   <chr>    <Duration>       
## 1 R        19.026s          
## 2 Python   7.30162620544434s
# Profile one full pass over `dl` (the same loop as the system.time()
# benchmark) to see where the R dataloader spends its time.
profvis::profvis({
  for (b in enumerate(dl)) {
    x <- b[[1]]
    y <- b[[2]]
  }
})