library(glue)
library(tidyverse)
# Base problem size: M positive and N negative examples. The simulation below
# runs problem sizes (M * i, N * i) for i = 1..50, so the positive fraction
# M / (M + N) is the same at every size.
M <- 4
N <- 30
#' Sample the distribution of information retrieval metrics under the null hypothesis
#'
#' Each simulation shuffles `m` positive and `n` negative labels uniformly at
#' random and scores that random labelling against a fixed ranking, giving
#' draws of average precision and precision@R under the null hypothesis.
#'
#' @param m Number of positive examples (= number of replicates - 1)
#' @param n Number of negative examples (= number of controls, or number of non-replicates)
#' @param nn Number of simulations (default = 1000)
#'
#' @return A data frame with one row per simulation and columns `ap`
#'   (average precision), `pr` (precision of the top-`m` ranked items, i.e.
#'   precision@R with R = `m`), `m` and `n`.
#'
retrieval_metrics_null_sampling <- function(m, n, nn = 1000) {
  # Fix the factor levels explicitly so truth and estimate always share the
  # same two levels ("FALSE", "TRUE"), even when m or n is 0.
  label_levels <- c(FALSE, TRUE)
  # Strictly decreasing scores: the first item is ranked highest.
  y_rank <- 1 - (seq_len(m + n) / (m + n))
  # Hard prediction: the top-m ranked items are called positive.
  y_boolean <- factor(c(rep(TRUE, m), rep(FALSE, n)), levels = label_levels)
  # Label pool that is shuffled on every simulation (loop-invariant).
  label_pool <- c(rep(FALSE, n), rep(TRUE, m))
  # seq_len() yields zero simulations for nn = 0 (seq(0) would give c(1, 0)).
  map_df(seq_len(nn), function(i) {
    x <- factor(sample(label_pool), levels = label_levels)
    # event_level = "second" makes TRUE the event of interest.
    ap <- yardstick::average_precision_vec(x, y_rank, event_level = "second")
    pr <- yardstick::precision_vec(x, y_boolean, event_level = "second")
    data.frame(ap, pr)
  }) %>%
    mutate(m = m, n = n)
}
# Simulate the null distribution for problem sizes (M * i, N * i),
# i = 1..50, with 10000 draws per size, in parallel. The worker count is
# capped at the cores actually available so smaller machines are not
# oversubscribed.
future::plan(future::multisession,
             workers = min(14L, future::availableCores()))
retrieval_metrics_null <- seq_len(50) %>%
  furrr::future_map_dfr(
    function(i) {
      retrieval_metrics_null_sampling(M * i, N * i, 10000)
    },
    # seed = TRUE gives each element its own parallel-safe RNG stream.
    .options = furrr::furrr_options(seed = TRUE)
  )
# Save the simulated null distribution to disk, then reload it so the
# analysis below runs from the on-disk copy.
write_csv(retrieval_metrics_null, "retrieval_metrics_null.csv.gz")
retrieval_metrics_null <- read_csv(
  "retrieval_metrics_null.csv.gz",
  show_col_types = FALSE
)
# Plot the densities of average precision (AP) and precision@R (PR)
# Null density of average precision: one curve per problem size (m + n),
# colored by that size.
ggplot(
  retrieval_metrics_null,
  aes(x = ap, group = (m + n), color = (m + n))
) +
  geom_density() +
  labs(title = "Average Precision", subtitle = glue("M = {M} N = {N}"))
# Null density of precision@R: one curve per problem size (m + n),
# colored by that size.
ggplot(
  retrieval_metrics_null,
  aes(x = pr, group = (m + n), color = (m + n))
) +
  geom_density() +
  labs(title = "Precision@R", subtitle = glue("M = {M} N = {N}"))
# Plot summary statistics of average precision (AP) and precision@R (PR)
# Summary statistics of each metric's null distribution per problem size,
# in long format: one row per (m, n, metric_name, statistic), value in
# `value`. Statistics are the 95th/90th/75th percentiles and the mean.
summaries <-
  retrieval_metrics_null %>%
  # all_of() replaces the superseded one_of() selection helper.
  pivot_longer(-all_of(c("m", "n")), names_to = "metric_name") %>%
  group_by(m, n, metric_name) %>%
  summarize(q95 = quantile(value, 0.95, names = FALSE),
            q90 = quantile(value, 0.90, names = FALSE),
            q75 = quantile(value, 0.75, names = FALSE),
            avg = mean(value),
            # Drop grouping here instead of a separate ungroup() step.
            .groups = "drop") %>%
  pivot_longer(-all_of(c("m", "n", "metric_name")), names_to = "statistic")
# Preview the summaries in wide form, one column per statistic.
head(
  pivot_wider(
    summaries,
    id_cols = c(m, n, metric_name),
    names_from = statistic,
    values_from = value
  )
)
# Null summary statistics of AP and precision@R plotted against the number of
# positives m. The thick, translucent horizontal line marks M / (M + N), the
# positive fraction shared by every simulated problem size (m = M * i,
# n = N * i).
summaries %>%
  ggplot(aes(m, value, color = metric_name, linetype = statistic)) +
  geom_line() +
  geom_hline(yintercept = M / (M + N), alpha = 0.2, linewidth = 2) +
  theme_bw() +
  scale_x_continuous(breaks = scales::pretty_breaks(n = 10)) +
  scale_y_continuous(breaks = scales::pretty_breaks(n = 10)) +
  ggtitle("Average Precision and Precision@R summary statistics",
          subtitle = glue("M = {M} N = {N}")) +
  labs(caption = "Horizontal gray line is M / (M+N)")
# Illustrate a nonlinear break sequence: the integers 1, 2, ... up to
# max_value^(1/pow), each raised to the power pow (so no break exceeds
# max_value). The plot shows the gap between consecutive break points,
# which widens as the break point grows.
pow <- 1.3
max_value <- 300
# Use ^ rather than **, which the R parser only accepts as an alias of ^.
break_point <- seq(1, max_value^(1 / pow), by = 1)^pow
data.frame(break_point) %>%
  mutate(break_width = break_point - lag(break_point)) %>%
  # The first break point has no predecessor (width NA), so drop that row.
  na.omit() %>%
  ggplot(aes(break_point, break_width)) +
  geom_line() +
  geom_point() +
  scale_x_continuous(breaks = scales::pretty_breaks(n = 10)) +
  scale_y_continuous(breaks = scales::pretty_breaks(n = 10)) +
  theme_bw() +
  ggtitle("Nonlinear breakpoints", subtitle = glue("seq(1, {max_value}^(1/{pow}), 1)**({pow})"))