# Setup
# Attach the tidyverse and tidymodels meta-packages without their start-up
# banners; the map/pivot/summarise/ggplot calls and the `*_vec()` metric
# functions used below come from these.
suppressPackageStartupMessages(library(tidyverse))
suppressPackageStartupMessages(library(tidymodels)) # for the recipes package, along with the rest of tidymodels
# Resolve function-name conflicts between attached packages in favour of the
# tidymodels versions.
tidymodels_prefer()
# Precision / Recall / F1 for multiclass
# Simulate the chance-level distribution of recall, precision and F1 for a
# balanced multiclass problem. The "prediction" in every iteration is a random
# permutation of the true labels, so each class is predicted exactly as often
# as it occurs. No `estimator` is passed, so yardstick's default averaging for
# multiclass factors applies.
# NOTE(review): no RNG seed is set here, so results differ between runs.
n_permutations <- 1000
n_samples_per_class <- 2
n_classes <- 160

# Every class label 1..n_classes appears `n_samples_per_class` times.
# `seq_len()` (rather than `seq(1, n)`) stays empty when `n_classes` is 0.
truth <- rep(seq_len(n_classes), n_samples_per_class)
truth_fac <- as.factor(truth)

# One row per permutation, with columns `recall`, `precision` and `f_meas`.
metrics_random <-
  map_df(
    seq_len(n_permutations),
    function(i) {
      # Permute the factor itself: the prediction is then guaranteed to carry
      # exactly the same levels as the truth.
      estimate_fac <- sample(truth_fac)

      metric_fns <- list(
        recall = recall_vec,
        precision = precision_vec,
        f_meas = f_meas_vec
      )

      # Evaluate every metric on the same truth/prediction pair; the list
      # names become the column names of the one-row result.
      # NOTE(review): yardstick documents `event_level` as applying only to
      # the binary estimator, so it is presumably ignored here -- confirm
      # before removing.
      as_tibble(
        map(
          metric_fns,
          function(metric_fn) {
            metric_fn(
              truth = truth_fac,
              estimate = estimate_fac,
              event_level = "second"
            )
          }
        )
      )
    }
  )
# Histogram of the simulated values, one panel per metric.
metrics_random %>%
  pivot_longer(cols = everything(), names_to = "metric") %>%
  ggplot(mapping = aes(x = value)) +
  geom_histogram(bins = 30) +
  facet_wrap(facets = vars(metric)) +
  theme_bw()

# The same distributions side by side as box plots.
metrics_random %>%
  pivot_longer(cols = everything(), names_to = "metric") %>%
  ggplot(mapping = aes(x = metric, y = value)) +
  geom_boxplot() +
  theme_bw()

# Mean and standard deviation of every metric, reshaped to one row per
# "<metric>_<statistic>" combination.
metrics_random %>%
  summarise(
    across(
      .cols = everything(),
      .fns = list(mean = mean, sd = sd)
    )
  ) %>%
  pivot_longer(cols = everything())
# Average Precision for two-class
# Simulate chance-level average precision and "precision at R" for a
# two-class problem with `n_samples` observations, varying the number of
# positive samples from `n_samples_class_min` to `n_samples_class_max`.
# Scores are uniform random numbers drawn independently of the truth.
# NOTE(review): no RNG seed is set here, so results differ between runs.
n_permutations <- 1000
n_samples <- 100
n_samples_class_min <- 3
n_samples_class_max <- 30

# Every simulated class size must leave at least one positive and one
# negative sample, and the range must be ascending (`:` would otherwise
# silently count downwards).
stopifnot(
  n_samples_class_min >= 1,
  n_samples_class_min <= n_samples_class_max,
  n_samples_class_max < n_samples
)

# One row per (class size, permutation) with columns `average_precision`,
# `precision_at_r` and `n_samples_class`.
metrics_random <-
  map_df(
    n_samples_class_min:n_samples_class_max,
    function(n_samples_class) {
      # `n_samples_class` positives followed by the remaining negatives.
      truth <- rep(
        c(TRUE, FALSE),
        times = c(n_samples_class, n_samples - n_samples_class)
      )
      # Explicit levels: "TRUE" is always the second level, i.e. the event.
      truth_fac <- factor(truth, levels = c(FALSE, TRUE))

      # Ascending rank of the smallest score that is still called positive,
      # chosen so that exactly `n_samples_class` scores are at or above it.
      threshold_rank <- n_samples - n_samples_class + 1

      map_df(
        seq_len(n_permutations),
        function(i) {
          estimate <- runif(length(truth))

          # A partial sort is enough to find the score at `threshold_rank`.
          threshold <- sort(estimate, partial = threshold_rank)[threshold_rank]

          predicted_positive <- estimate >= threshold

          # Guards against ties producing a different number of positives.
          stopifnot(sum(predicted_positive) == n_samples_class)

          estimate_fac <- factor(predicted_positive, levels = c(FALSE, TRUE))

          tibble(
            # Computed from the raw scores.
            average_precision = average_precision_vec(
              truth = truth_fac,
              estimate = estimate,
              event_level = "second"
            ),
            # Precision of the hard calls when as many samples are predicted
            # positive as there are actual positives.
            precision_at_r = precision_vec(
              truth = truth_fac,
              estimate = estimate_fac,
              event_level = "second"
            )
          )
        }
      ) %>%
        mutate(n_samples_class = n_samples_class)
    }
  )
# Box plots of each metric against the number of positive samples, one panel
# per metric.
metrics_random %>%
  pivot_longer(cols = -n_samples_class, names_to = "metric") %>%
  ggplot(mapping = aes(x = as.factor(n_samples_class), y = value)) +
  geom_boxplot() +
  facet_wrap(facets = vars(metric)) +
  theme_bw()

# Mean and standard deviation of every metric for each number of positives.
metrics_random %>%
  group_by(n_samples_class) %>%
  summarise(
    across(
      .cols = everything(),
      .fns = list(mean = mean, sd = sd)
    )
  )
LS0tCnRpdGxlOiAiTWV0cmljIHNpbXVsYXRpb25zIgpvdXRwdXQ6IGh0bWxfbm90ZWJvb2sKLS0tCgojIFNldHVwCgpgYGB7cn0Kc3VwcHJlc3NQYWNrYWdlU3RhcnR1cE1lc3NhZ2VzKGxpYnJhcnkodGlkeXZlcnNlKSkKc3VwcHJlc3NQYWNrYWdlU3RhcnR1cE1lc3NhZ2VzKGxpYnJhcnkodGlkeW1vZGVscykpICMgZm9yIHRoZSByZWNpcGVzIHBhY2thZ2UsIGFsb25nIHdpdGggdGhlIHJlc3Qgb2YgdGlkeW1vZGVscwp0aWR5bW9kZWxzX3ByZWZlcigpCmBgYAoKIyBQcmVjaXNpb24gLyBSZWNhbGwgLyBGMSBmb3IgbXVsdGljbGFzcwoKYGBge3J9Cm5fcGVybXV0YXRpb25zIDwtIDEwMDAKbl9zYW1wbGVzX3Blcl9jbGFzcyA8LSAyCm5fY2xhc3NlcyA8LSAxNjAKdHJ1dGggPC0gcmVwKHNlcSgxLCBuX2NsYXNzZXMpLCBuX3NhbXBsZXNfcGVyX2NsYXNzKQp0cnV0aF9mYWMgPC0gYXMuZmFjdG9yKHRydXRoKQoKbWV0cmljc19yYW5kb20gPC0KICBtYXBfZGYoMTpuX3Blcm11dGF0aW9ucywKICAgICAgICAgZnVuY3Rpb24oaSkgewogICAgICAgICAgIGVzdGltYXRlIDwtIHNhbXBsZSh0cnV0aCkKICAgICAgICAgICAKICAgICAgICAgICBlc3RpbWF0ZV9mYWMgPC0gYXMuZmFjdG9yKGVzdGltYXRlKQogICAgICAgICAgIAogICAgICAgICAgIHRpYmJsZSgKICAgICAgICAgICAgIHJlY2FsbCA9ICByZWNhbGxfdmVjKAogICAgICAgICAgICAgICB0cnV0aCA9IHRydXRoX2ZhYywKICAgICAgICAgICAgICAgZXN0aW1hdGUgPSBlc3RpbWF0ZV9mYWMsCiAgICAgICAgICAgICAgIGV2ZW50X2xldmVsID0gInNlY29uZCIKICAgICAgICAgICAgICksCiAgICAgICAgICAgICAKICAgICAgICAgICAgIHByZWNpc2lvbiA9ICBwcmVjaXNpb25fdmVjKAogICAgICAgICAgICAgICB0cnV0aCA9IHRydXRoX2ZhYywKICAgICAgICAgICAgICAgZXN0aW1hdGUgPSBlc3RpbWF0ZV9mYWMsCiAgICAgICAgICAgICAgIGV2ZW50X2xldmVsID0gInNlY29uZCIKICAgICAgICAgICAgICksCiAgICAgICAgICAgICAKICAgICAgICAgICAgIGZfbWVhcyA9ICBmX21lYXNfdmVjKAogICAgICAgICAgICAgICB0cnV0aCA9IHRydXRoX2ZhYywKICAgICAgICAgICAgICAgZXN0aW1hdGUgPSBlc3RpbWF0ZV9mYWMsCiAgICAgICAgICAgICAgIGV2ZW50X2xldmVsID0gInNlY29uZCIKICAgICAgICAgICAgICkKICAgICAgICAgICAgIAogICAgICAgICAgICkKICAgICAgICAgfSkKYGBgCgoKYGBge3J9Cm1ldHJpY3NfcmFuZG9tICU+JQogIHBpdm90X2xvbmdlcihldmVyeXRoaW5nKCksIG5hbWVzX3RvID0gIm1ldHJpYyIpICU+JQogIGdncGxvdChhZXModmFsdWUpKSArCiAgZ2VvbV9oaXN0b2dyYW0oYmlucyA9IDMwKSArCiAgZmFjZXRfd3JhcCh+IG1ldHJpYykgKwogIHRoZW1lX2J3KCkKCm1ldHJpY3NfcmFuZG9tICU+JQogIHBpdm90X2xvbmdlcihldmVyeXRoaW5nKCksIG5hbWVzX3RvID0gIm1ldHJpYyIpICU+JQogIGdncGxvdChhZXMobWV0cmljLCB2YWx1ZSkpICsK
ICBnZW9tX2JveHBsb3QoKSArCiAgdGhlbWVfYncoKQoKbWV0cmljc19yYW5kb20gJT4lCiAgc3VtbWFyaXplKGFjcm9zcyhldmVyeXRoaW5nKCksCiAgICAgICAgICAgICAgICAgICBsaXN0KG1lYW4gPSBtZWFuLCBzZCA9IHNkKSkpICU+JQogIHBpdm90X2xvbmdlcihldmVyeXRoaW5nKCkpCmBgYAoKIyBBdmVyYWdlIFByZWNpc2lvbiBmb3IgdHdvLWNsYXNzCgpgYGB7cn0Kbl9wZXJtdXRhdGlvbnMgPC0gMTAwMApuX3NhbXBsZXMgPC0gMTAwCm5fc2FtcGxlc19jbGFzc19taW4gPC0gMwpuX3NhbXBsZXNfY2xhc3NfbWF4IDwtIDMwCgptZXRyaWNzX3JhbmRvbSA8LQogIG1hcF9kZihuX3NhbXBsZXNfY2xhc3NfbWluOm5fc2FtcGxlc19jbGFzc19tYXgsCiAgICAgICAgIAogICAgICAgICBmdW5jdGlvbihuX3NhbXBsZXNfY2xhc3MpIHsKICAgICAgICAgICB0cnV0aCA8LQogICAgICAgICAgICAgYyhyZXAoVFJVRSwgbl9zYW1wbGVzX2NsYXNzKSwKICAgICAgICAgICAgICAgcmVwKEZBTFNFLCBuX3NhbXBsZXMgLSBuX3NhbXBsZXNfY2xhc3MpKQogICAgICAgICAgIAogICAgICAgICAgIHRydXRoX2ZhYyA8LSBhcy5mYWN0b3IodHJ1dGgpCiAgICAgICAgICAgCiAgICAgICAgICAgbWFwX2RmKDE6bl9wZXJtdXRhdGlvbnMsCiAgICAgICAgICAgICAgICAgIGZ1bmN0aW9uKGkpIHsKICAgICAgICAgICAgICAgICAgICBlc3RpbWF0ZSA8LSBydW5pZihsZW5ndGgodHJ1dGgpKQogICAgICAgICAgICAgICAgICAgIAogICAgICAgICAgICAgICAgICAgIHRocmVzaG9sZCA8LQogICAgICAgICAgICAgICAgICAgICAgc29ydChlc3RpbWF0ZSwKICAgICAgICAgICAgICAgICAgICAgICAgICAgcGFydGlhbCA9IG5fc2FtcGxlcyAtIG5fc2FtcGxlc19jbGFzcyArIDEpW25fc2FtcGxlcyAtIG5fc2FtcGxlc19jbGFzcyArIDFdCiAgICAgICAgICAgICAgICAgICAgCiAgICAgICAgICAgICAgICAgICAgZXN0aW1hdGVfZmFjIDwtIGVzdGltYXRlID49IHRocmVzaG9sZAogICAgICAgICAgICAgICAgICAgIAogICAgICAgICAgICAgICAgICAgIHN0b3BpZm5vdChzdW0oZXN0aW1hdGVfZmFjKSA9PSBuX3NhbXBsZXNfY2xhc3MpCiAgICAgICAgICAgICAgICAgICAgCiAgICAgICAgICAgICAgICAgICAgZXN0aW1hdGVfZmFjIDwtIGFzLmZhY3Rvcihlc3RpbWF0ZV9mYWMpCiAgICAgICAgICAgICAgICAgICAgCiAgICAgICAgICAgICAgICAgICAgdGliYmxlKAogICAgICAgICAgICAgICAgICAgICAgYXZlcmFnZV9wcmVjaXNpb24gPSAgYXZlcmFnZV9wcmVjaXNpb25fdmVjKAogICAgICAgICAgICAgICAgICAgICAgICB0cnV0aCA9IHRydXRoX2ZhYywKICAgICAgICAgICAgICAgICAgICAgICAgZXN0aW1hdGUgPSBlc3RpbWF0ZSwKICAgICAgICAgICAgICAgICAgICAgICAgZXZlbnRfbGV2ZWwgPSAic2Vjb25kIgogICAgICAgICAgICAgICAgICAgICAgKSwKICAgICAgICAgICAgICAgICAgICAgIHByZWNpc2lvbl9hdF9yID0gIHByZWNpc2lvbl92ZWMoCiAg
ICAgICAgICAgICAgICAgICAgICAgIHRydXRoID0gdHJ1dGhfZmFjLAogICAgICAgICAgICAgICAgICAgICAgICBlc3RpbWF0ZSA9IGVzdGltYXRlX2ZhYywKICAgICAgICAgICAgICAgICAgICAgICAgZXZlbnRfbGV2ZWwgPSAic2Vjb25kIgogICAgICAgICAgICAgICAgICAgICAgKSwKICAgICAgICAgICAgICAgICAgICAgIAogICAgICAgICAgICAgICAgICAgICkKICAgICAgICAgICAgICAgICAgfSkgJT4lCiAgICAgICAgICAgICBtdXRhdGUobl9zYW1wbGVzX2NsYXNzID0gbl9zYW1wbGVzX2NsYXNzKQogICAgICAgICB9KQpgYGAKCgpgYGB7cn0KbWV0cmljc19yYW5kb20gJT4lCiAgcGl2b3RfbG9uZ2VyKC1uX3NhbXBsZXNfY2xhc3MsIG5hbWVzX3RvID0gIm1ldHJpYyIpICU+JQogIGdncGxvdChhZXMoYXMuZmFjdG9yKG5fc2FtcGxlc19jbGFzcyksIHZhbHVlKSkgKwogIGdlb21fYm94cGxvdCgpICsKICBmYWNldF93cmFwKH4gbWV0cmljKSArCiAgdGhlbWVfYncoKQoKbWV0cmljc19yYW5kb20gJT4lCiAgZ3JvdXBfYnkobl9zYW1wbGVzX2NsYXNzKSAlPiUKICBzdW1tYXJpemUoYWNyb3NzKGV2ZXJ5dGhpbmcoKSwKICAgICAgICAgICAgICAgICAgIGxpc3QobWVhbiA9IG1lYW4sIHNkID0gc2QpKSkKYGBgCg==