When using stat_summary() with a transformed y-axis, my expectation is that summary statistics are calculated on the original values, and then this result is transformed appropriately to appear on the y-axis.
I have discovered that this is not true, and that stat_summary() transforms all the values first and then calculates the summary!
When plotted, this produces different results from what I expect/want. Is this a bug or a feature of ggplot2?
Below I'll try to illustrate the problem and propose a solution.
ggplot2::stat_summary

ggplot2 has the ability to summarise data with stat_summary. This particular Stat calculates a summary of your data at each unique x value.
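As a rough mental model (a base-R sketch, not ggplot2's actual internals), the per-x summary that `stat_summary(fun.y = mean)` draws on an untransformed axis is the same as grouping `y` by `x` and applying the function to each group. The data is the same `plot_df` listed in the appendix.

```r
# Same data as plot_df in the appendix
plot_df <- data.frame(
  x = rep(1:3, each = 7),
  y = c(0, 0, 0, 0, 1, 1, 5, 1, 1, 1, 1, 2, 2, 2, 2, 3, 4, 5, 6, 7, 8)
)

# One summary value (here the mean) per unique x value
per_x_mean <- aggregate(y ~ x, data = plot_df, FUN = mean)

# x = 1, 2, 3 give means of 1, 1.428571 and 5 (the mean_df table in the appendix)
per_x_mean
```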
The following creates a scatter plot of some points with a mean calculated at each x and connected by a line.
Note: the mean at x=1 is 1. The plot_df data and the manually calculated mean_df summary are included in the appendix at the end of this post.

```r
library(ggplot2)

# fun.y was renamed to fun in ggplot2 3.3.0; fun.y still works there but is deprecated
p <- ggplot(plot_df, aes(x, y)) +
  geom_point(alpha = 0.4) +
  stat_summary(fun.y = mean, geom = 'line') +
  scale_y_continuous(breaks = 1:9) +
  theme_bw() +
  ggtitle("Using stat_summary to draw a mean line\nNote that mean is 1 at x=1")
p

# Swap in a square-root y scale (ggplot2 reports that it is replacing the existing y scale)
p + scale_y_sqrt(breaks = 1:9) +
  ggtitle("Using stat_summary to draw a mean line - `scale_y_sqrt()`\nNote that mean at x=1 is no longer 1 !!")
```
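One way to confirm the order of operations is ggplot2's `layer_data()`, which returns the data ggplot2 has computed for a layer, including the stat's output. With `scale_y_sqrt()`, the `y` values of the summary layer are already in square-root units and match `mean(sqrt(y))` at each x, not `sqrt(mean(y))`. This sketch assumes ggplot2 >= 3.3.0, where `fun` replaces `fun.y`.

```r
library(ggplot2)

plot_df <- data.frame(
  x = rep(1:3, each = 7),
  y = c(0, 0, 0, 0, 1, 1, 5, 1, 1, 1, 1, 2, 2, 2, 2, 3, 4, 5, 6, 7, 8)
)

p_sqrt <- ggplot(plot_df, aes(x, y)) +
  geom_point() +                              # layer 1
  stat_summary(fun = mean, geom = "line") +   # layer 2
  scale_y_sqrt()

# Layer 2's data after the stat has run; y is on the sqrt scale
summary_layer <- layer_data(p_sqrt, i = 2)

summary_after_transform <- tapply(sqrt(plot_df$y), plot_df$x, mean)
transform_after_summary <- sqrt(tapply(plot_df$y, plot_df$x, mean))

all.equal(summary_layer$y, as.vector(summary_after_transform))  # TRUE
all.equal(summary_layer$y, as.vector(transform_after_summary))  # not TRUE
```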
My issue with stat_summary is that I expect the summary values to be calculated before the transform is performed, but stat_summary calculates them after the data has been transformed.
The problem with this change in order of operations is that the summary of a transformed value isn't, in general, the same as the transform of a summary value, e.g. mean(sqrt(x)) != sqrt(mean(x)).
summary-after-transform and transform-after-summary?

If we consider the y values at x=1, we can calculate the transformed mean in the two ways and see clearly that the results are different.
```r
y <- c(0, 0, 0, 0, 1, 1, 5)
sqrt(mean(y))  # transform-after-summary (what I expected)
#> [1] 1
mean(sqrt(y))  # summary-after-transform (what stat_summary does)
#> [1] 0.6051526
```
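Extending the same check to every x value shows the gap is not a one-off. Because `sqrt()` is concave, Jensen's inequality guarantees `mean(sqrt(y)) <= sqrt(mean(y))`, so on a square-root axis the stat_summary mean can only sit at or below the expected one. Squaring the stat_summary value converts it back into the units printed on the axis labels. A base-R sketch (the column names are just illustrative):

```r
plot_df <- data.frame(
  x = rep(1:3, each = 7),
  y = c(0, 0, 0, 0, 1, 1, 5, 1, 1, 1, 1, 2, 2, 2, 2, 3, 4, 5, 6, 7, 8)
)

expected <- sqrt(tapply(plot_df$y, plot_df$x, mean))  # transform-after-summary
actual   <- tapply(sqrt(plot_df$y), plot_df$x, mean)  # summary-after-transform

comparison <- data.frame(
  x = as.integer(names(expected)),
  expected_sqrt_units = as.vector(expected),
  stat_summary_sqrt_units = as.vector(actual),
  # Squaring undoes sqrt(), giving the value you would read off the axis labels
  stat_summary_axis_units = as.vector(actual)^2
)
comparison

# Jensen's inequality for the concave sqrt(): never above the expected mean
all(actual <= expected)  # TRUE
```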
So how does stat_summary's summary-after-transform differ on the plot from the expected transform-after-summary?
The following plot shows the correct mean (dashed blue line) and the stat_summary mean (solid black line). The correct mean has been calculated manually on the original data, prior to any transformation.
Note: at x=1 the correct mean (blue) passes through 1, whereas the stat_summary() mean (black) passes through about 0.6 in square-root units (roughly 0.37 when read off the axis labels).

```r
ggplot(plot_df, aes(x, y)) +
  geom_point(alpha = 0.6) +
  stat_summary(fun.y = mean, geom = 'line') +
  geom_point(data = mean_df, size = 3, alpha = 0.3, colour = 'blue') +
  geom_line(data = mean_df, linetype = 2, colour = 'blue') +
  scale_y_sqrt(breaks = 0:9) +
  ggtitle("Blue dashed line is true mean line (transform-after-summary)\nBlack line is stat_summary result (summary-after-transform)") +
  theme_bw()
```
stat_summary_two - do transform-after-summary

Below is my adapted version of stat_summary. The key user-facing change is the addition of a transform.after.summary option. This defaults to TRUE, but can be set to FALSE to mimic the behaviour of the original stat_summary.
By the time ggplot processing gets to compute_panel, it has already transformed the data. To work around this, we manually apply the inverse transform to get back the original data. This original data is then summarised, and the results are transformed back into the requested scale.
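The inverse, summarise, transform sequence can be tried outside ggproto with a transformation object from the scales package. Such objects carry `transform` and `inverse` functions, which is what `scales$y$trans` exposes inside `compute_panel`. This sketch uses `scales::sqrt_trans()` (newer scales releases also provide it as `transform_sqrt()`); the variable names are illustrative only.

```r
library(scales)

trans <- sqrt_trans()

y_raw    <- c(0, 0, 0, 0, 1, 1, 5)  # y values at x = 1
y_scaled <- trans$transform(y_raw)  # what compute_panel receives

# 1. Undo the scale transformation to recover the raw values
y_back <- trans$inverse(y_scaled)
# 2. Summarise on the original scale
y_mean <- mean(y_back)
# 3. Re-apply the transformation so the result lands on the sqrt axis
y_mean_scaled <- trans$transform(y_mean)

y_mean_scaled  # approximately 1, i.e. sqrt(mean(y_raw))
```

The round trip is only exact up to floating-point error, which is why the result is "approximately" 1 rather than guaranteed to be bit-for-bit identical.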
```r
library(ggplot2)

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#' Summarise y values at unique/binned x
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
stat_summary_two <- function(mapping = NULL, data = NULL,
                             geom = "pointrange", position = "identity",
                             ...,
                             fun.data = NULL,
                             fun.y = NULL,
                             fun.ymax = NULL,
                             fun.ymin = NULL,
                             fun.args = list(),
                             na.rm = FALSE,
                             transform.after.summary = TRUE,
                             show.legend = NA,
                             inherit.aes = TRUE) {
  layer(
    data = data,
    mapping = mapping,
    stat = StatSummaryTwo,
    geom = geom,
    position = position,
    show.legend = show.legend,
    inherit.aes = inherit.aes,
    params = list(
      fun.data = fun.data,
      fun.y = fun.y,
      fun.ymax = fun.ymax,
      fun.ymin = fun.ymin,
      fun.args = fun.args,
      na.rm = na.rm,
      transform.after.summary = transform.after.summary,
      ...
    )
  )
}

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ggproto Stat
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
StatSummaryTwo <- ggproto(
  "StatSummaryTwo", Stat,
  required_aes = c("x", "y"),
  compute_panel = function(data, scales, fun.data = NULL, fun.y = NULL,
                           fun.ymax = NULL, fun.ymin = NULL, fun.args = list(),
                           na.rm = FALSE, transform.after.summary = TRUE) {
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # The `data` we have in this function has already been transformed, so
    # let's untransform it
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if (transform.after.summary) {
      data$y <- scales$y$trans$inverse(data$y)
    }

    # make_summary_fun() and summarise_by_x() are unexported ggplot2
    # internals, so they may change between ggplot2 versions
    fun <- ggplot2:::make_summary_fun(fun.data, fun.y, fun.ymax, fun.ymin, fun.args)
    res <- ggplot2:::summarise_by_x(data, fun)

    if (transform.after.summary) {
      #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
      # Transform the summary of the raw data into the final scale
      #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
      # Only transform the summary columns that are present, since a custom
      # fun.data may not return all of y, ymin and ymax
      for (col in intersect(c("y", "ymin", "ymax"), names(res))) {
        res[[col]] <- scales$y$trans$transform(res[[col]])
      }
    }
    res
  }
)
```
```r
ggplot(plot_df, aes(x, y)) +
  geom_point(alpha = 0.6) +
  stat_summary(fun.y = mean, geom = 'line') +
  stat_summary_two(fun.y = mean, geom = 'line', colour = 'darkgreen', linetype = 2) +
  scale_y_sqrt(breaks = 0:9) +
  ggtitle("Green dashed line: correct mean result - `stat_summary_two` (transform-after-summary)\nBlack line: incorrect `stat_summary` result (summary-after-transform)") +
  theme_bw()
```
Appendix

```r
library(dplyr)

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Create the plotting data.frame
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
plot_df <- data.frame(
  x = rep(1:3, each = 7),
  y = c(0, 0, 0, 0, 1, 1, 5, 1, 1, 1, 1, 2, 2, 2, 2, 3, 4, 5, 6, 7, 8)
)

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Calculate the mean at each x value
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
mean_df <- plot_df %>%
  group_by(x) %>%
  summarise(y = mean(y)) %>%
  ungroup()
```
plot_df

| x | y |
|---|---|
| 1 | 0 |
| 1 | 0 |
| 1 | 0 |
| 1 | 0 |
| 1 | 1 |
| 1 | 1 |
| 1 | 5 |
| 2 | 1 |
| 2 | 1 |
| 2 | 1 |
| 2 | 1 |
| 2 | 2 |
| 2 | 2 |
| 2 | 2 |
| 3 | 2 |
| 3 | 3 |
| 3 | 4 |
| 3 | 5 |
| 3 | 6 |
| 3 | 7 |
| 3 | 8 |

mean_df

| x | y |
|---|---|
| 1 | 1.000000 |
| 2 | 1.428571 |
| 3 | 5.000000 |