library(torch)

# Key/value store backed by a `fastmap` hash map. Any R object may serve as a
# key: it is reduced to a string digest with `rlang::hash()` before every
# store or lookup.
#
# `lock_objects = FALSE` is required because `map` is not declared in the
# `public` list; `initialize()` adds it to the object at construction time.
State <- R6::R6Class(
  classname = "State",
  lock_objects = FALSE,
  public = list(
    # Start with an empty fastmap held in the public `map` field.
    initialize = function() {
      self$map <- fastmap::fastmap()
    },
    # Store `value` under the digest of `key`, replacing any previous entry.
    set = function(key, value) {
      digest <- rlang::hash(key)
      self$map$set(digest, value)
    },
    # Fetch the value stored under the digest of `key`; an absent key yields
    # fastmap's missing default (NULL unless configured otherwise).
    get = function(key) {
      digest <- rlang::hash(key)
      self$map$get(digest)
    }
  )
)

# Key/value store backed by a named base-R list. Any R object may serve as a
# key: it is reduced to a string digest with `rlang::hash()` first.
#
# `lock_objects = FALSE` is required because `map` is not declared in the
# `public` list; `initialize()` adds it at construction time.
StateList <- R6::R6Class(
  "StateList",
  lock_objects = FALSE,
  public = list(
    # Start with an empty list held in the public `map` field.
    initialize = function() {
      self$map <- list()
    },
    # Store `value` under the digest of `key`, replacing any previous entry.
    # `[<-` with a one-element list is used so that a NULL `value` is stored
    # as an entry; `[[<-` with NULL would delete the entry instead, which is
    # inconsistent with the fastmap- and environment-backed variants.
    set = function(key, value) {
      self$map[rlang::hash(key)] <- list(value)
      invisible(value)
    },
    # Fetch the value stored under the digest of `key`; `[[` on a list with
    # an unknown name returns NULL, so absent keys yield NULL.
    get = function(key) {
      self$map[[rlang::hash(key)]]
    }
  )
)

# Key/value store backed by an environment. Any R object may serve as a key:
# it is reduced to a string digest with `rlang::hash()` and used as the
# binding name.
#
# `lock_objects = FALSE` is required because `map` is not declared in the
# `public` list; `initialize()` adds it at construction time.
StateEnv <- R6::R6Class(
  "StateEnv",
  lock_objects = FALSE,
  public = list(
    # Start with a fresh environment held in the public `map` field.
    initialize = function() {
      self$map <- rlang::new_environment()
    },
    # Bind `value` to the digest of `key`, replacing any previous binding.
    set = function(key, value) {
      self$map[[rlang::hash(key)]] <- value
    },
    # Fetch the value bound to the digest of `key`; `[[` on an environment
    # returns NULL for an unbound name, so absent keys yield NULL.
    get = function(key) {
      self$map[[rlang::hash(key)]]
    }
  )
)

# One store per backend. Each must be built from its own class; constructing
# all three from `State` would make every benchmark row time the same
# fastmap-backed implementation.
state_fast <- State$new()
state_list <- StateList$new()
state_env <- StateEnv$new()

# Cost of a single `set()`. Every call constructs two fresh tensors and hashes
# the key, so those costs are part of each timing. If distinct tensors hash to
# distinct digests (NOTE(review): confirm), every iteration also adds a new
# entry, so the stores grow while the benchmark runs.
bench::mark(
  fast = state_fast$set(torch_tensor(1), torch_tensor(1)),
  list = state_list$set(torch_tensor(1), torch_tensor(1)),
  envs = state_env$set(torch_tensor(1), torch_tensor(1))
)
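# NOTE: these timings appear to have been recorded while `state_list` and
# `state_env` were (mistakenly) constructed with `State$new()`, so all three
# rows likely time the same fastmap-backed class; re-run to compare backends.
## # A tibble: 3 x 6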
##   expression      min   median `itr/sec` mem_alloc `gc/sec`
##   <bch:expr> <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl>
## 1 fast          153µs    167µs     5625.    50.8KB     16.9
## 2 list          155µs    172µs     5509.    18.6KB     20.2
## 3 envs          153µs    174µs     5450.    18.6KB     19.2
# Populate every backend with the same key/value tensor pairs so the lookup
# benchmark below reads the same keys from each store.
# NOTE(review): all keys are `torch_tensor(1)`; confirm that `rlang::hash()`
# gives distinct digests for distinct tensor objects, otherwise every key
# collapses into a single entry and the lookups all hit the same slot.
n_keys <- 100
tensors <- lapply(seq_len(n_keys), function(i) torch_tensor(1))
values <- lapply(seq_len(n_keys), function(i) torch_tensor(1))

# Fresh stores, each built from its own backend class.
state_fast <- State$new()
state_list <- StateList$new()
state_env <- StateEnv$new()

# `set()` is called only for its side effect; its return value is not kept.
for (i in seq_along(tensors)) {
  state_fast$set(tensors[[i]], values[[i]])
  state_list$set(tensors[[i]], values[[i]])
  state_env$set(tensors[[i]], values[[i]])
}

# Look up every key in `keys` from `state`.
#
# @param state Any object with a `get(key)` method (e.g. a `State`,
#   `StateList` or `StateEnv` instance).
# @param keys List of keys to look up. Defaults to the global `tensors` list,
#   so existing calls of the form `get_from_state(state)` behave as before.
# @return A list with one looked-up value per key, in the order of `keys`.
get_from_state <- function(state, keys = tensors) {
  lapply(keys, state$get)
}

# Lookup benchmark: fetch every populated key from each backend. `check =
# TRUE` (bench's default, spelled out here) makes `bench::mark()` verify that
# all three backends return equal results.
bench::mark(
  fast = get_from_state(state = state_fast),
  list = get_from_state(state = state_list),
  envs = get_from_state(state = state_env),
  check = TRUE
)
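# NOTE: these timings appear to have been recorded while `state_list` and
# `state_env` were (mistakenly) constructed with `State$new()`, so all three
# rows likely time the same fastmap-backed class; re-run to compare backends.
## # A tibble: 3 x 6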
##   expression      min   median `itr/sec` mem_alloc `gc/sec`
##   <bch:expr> <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl>
## 1 fast          698µs    986µs      980.     864KB     18.7
## 2 list          687µs    878µs     1011.     879KB     20.9
## 3 envs          706µs    878µs      999.     864KB     18.5