This document shows how to use nested data frames with functional programming ideas to perform modelling at scale. It is based entirely on Hadley Wickham’s talk here and a transcription by Aaron Saunders here.
library(dplyr)
library(tidyr)
library(broom)
library(purrr)
library(ggplot2)
by_country <- gapminder::gapminder %>%
  mutate(year1950 = year - 1950) %>%
  group_by(continent, country) %>%
  nest()
Set up a modelling function that can be mapped over the nested data frames.
country_model <- function(df){
  lm(lifeExp ~ year1950, data = df)
}
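Before mapping the function over every country, it can help to try it on a single one. The snippet below is a minimal sketch: it repeats `country_model` so that it runs on its own, and the choice of New Zealand and the names `nz` and `nz_fit` are purely illustrative.

```r
library(dplyr)

# Same model function as above, repeated so the snippet is self-contained
country_model <- function(df) {
  lm(lifeExp ~ year1950, data = df)
}

# One country's rows, with the shifted year variable added
nz <- gapminder::gapminder %>%
  filter(country == "New Zealand") %>%
  mutate(year1950 = year - 1950)

nz_fit <- country_model(nz)

# Intercept: predicted life expectancy in 1950 (year1950 == 0)
# Slope: estimated change in life expectancy per calendar year
coef(nz_fit)
```

Shifting the year to start at 1950 is what makes the intercept interpretable; with the raw year it would be an extrapolation back to year 0.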
And then create the models using this function, while extracting information from the models using broom.
models <- by_country %>%
  mutate(model = map(data, country_model)) %>%
  # glance gets the model summary variables
  mutate(glance = map(model, broom::glance),
         rsq = glance %>% map_dbl("r.squared"),
         # tidy gets summary variables of the model coefficients
         tidy = map(model, broom::tidy),
         # augment gets per-observation variables (e.g., residuals)
         augment = map(model, broom::augment))
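The `map_dbl("r.squared")` call above relies on purrr's shorthand of passing a string to extract a named element. Here is a small sketch of that shorthand on toy data; the `glances` list is a made-up stand-in for the one-row data frames that `broom::glance()` returns.

```r
library(purrr)

# Toy stand-ins for glance() output: one-row data frames
glances <- list(
  data.frame(r.squared = 0.95, sigma = 1.2),
  data.frame(r.squared = 0.10, sigma = 6.5)
)

# A string means "extract the element with this name from each item"
rsq_short <- map_dbl(glances, "r.squared")

# The same extraction written out as an explicit function
rsq_long <- map_dbl(glances, function(g) g[["r.squared"]])

identical(rsq_short, rsq_long)
```

`map_dbl()` insists on exactly one number per element, so it fails loudly if a summary is missing rather than silently returning a list.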
Plot the statistics you’re interested in relating to the data and models.
models %>%
  ggplot(aes(rsq, reorder(country, rsq))) +
  geom_point(aes(colour = continent)) +
  labs(y = "Country") +
  theme_minimal() +
  theme(axis.text.y = element_blank(),
        axis.ticks.y = element_blank(),
        panel.grid = element_blank(),
        panel.grid.major.x = element_line(colour = "grey"),
        panel.background = element_blank())
You can “unnest” a nested list-column to access its contents.
models %>% unnest(tidy)
## # A tibble: 284 x 8
## continent country rsq term estimate std.error
## <fctr> <fctr> <dbl> <chr> <dbl> <dbl>
## 1 Asia Afghanistan 0.9477123 (Intercept) 29.3566375 0.698981278
## 2 Asia Afghanistan 0.9477123 year1950 0.2753287 0.020450934
## 3 Europe Albania 0.9105778 (Intercept) 58.5597618 1.133575812
## 4 Europe Albania 0.9105778 year1950 0.3346832 0.033166387
## 5 Africa Algeria 0.9851172 (Intercept) 42.2364149 0.756269040
## 6 Africa Algeria 0.9851172 year1950 0.5692797 0.022127070
## 7 Africa Angola 0.8878146 (Intercept) 31.7079741 0.804287463
## 8 Africa Angola 0.8878146 year1950 0.2093399 0.023532003
## 9 Americas Argentina 0.9955681 (Intercept) 62.2250191 0.167091314
## 10 Americas Argentina 0.9955681 year1950 0.2317084 0.004888791
## # ... with 274 more rows, and 2 more variables: statistic <dbl>,
## # p.value <dbl>
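To see what unnesting does without fitting 142 models, here is a minimal sketch on a hand-built tibble. The country names and estimates are invented for illustration.

```r
library(tibble)
library(tidyr)

# Two rows, each holding a small coefficient table in a list-column
nested <- tibble(
  country = c("A", "B"),
  tidy = list(
    tibble(term = c("(Intercept)", "year1950"), estimate = c(30, 0.3)),
    tibble(term = c("(Intercept)", "year1950"), estimate = c(60, 0.2))
  )
)

# Each inner table is expanded in place; `country` is repeated per inner row
flat <- unnest(nested, tidy)
flat
```

This is why the output above has 284 rows: 142 countries times two coefficients each.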
We can unnest and wrangle data into a form that ggplot can use quite easily. Note that the area of each point scales with the R² of its model, so the African countries near the bottom, which are drawn as small points, are probably not being modelled well by a steady increase in life expectancy over time.
models %>%
  unnest(tidy) %>%
  select(continent, country, term, estimate, rsq) %>%
  spread(term, estimate) %>%
  ggplot(aes(`(Intercept)`, year1950)) +
  geom_point(aes(colour = continent, size = rsq)) +
  geom_smooth(se = FALSE) +
  scale_size_area() +
  theme_minimal()
## `geom_smooth()` using method = 'loess'
Can you imagine trying to plot the residuals for these 142 models in a different fashion? Without being careful you could end up with some very complicated and difficult-to-read code. We can do this in a few lines using these nested data frames by again using unnest to “pluck out” the augment data frame and access its variables, which include the residuals.
models %>%
  unnest(augment) %>%
  ggplot(aes(year1950, .resid)) +
  geom_line(aes(group = country), alpha = 0.3) +
  geom_hline(yintercept = 0, colour = "red", size = 1) +
  geom_smooth(se = FALSE, size = 1.5) +
  facet_wrap( ~ continent) +
  theme_minimal()
## `geom_smooth()` using method = 'loess'
So we noticed some African countries that seemed like outliers with regard to R².
models %>% unnest(glance) %>% arrange(r.squared)
## # A tibble: 142 x 18
## continent country data model rsq
## <fctr> <fctr> <list> <list> <dbl>
## 1 Africa Rwanda <tibble [12 x 5]> <S3: lm> 0.01715964
## 2 Africa Botswana <tibble [12 x 5]> <S3: lm> 0.03402340
## 3 Africa Zimbabwe <tibble [12 x 5]> <S3: lm> 0.05623196
## 4 Africa Zambia <tibble [12 x 5]> <S3: lm> 0.05983644
## 5 Africa Swaziland <tibble [12 x 5]> <S3: lm> 0.06821087
## 6 Africa Lesotho <tibble [12 x 5]> <S3: lm> 0.08485635
## 7 Africa Cote d'Ivoire <tibble [12 x 5]> <S3: lm> 0.28337240
## 8 Africa South Africa <tibble [12 x 5]> <S3: lm> 0.31246865
## 9 Africa Uganda <tibble [12 x 5]> <S3: lm> 0.34215382
## 10 Africa Congo, Dem. Rep. <tibble [12 x 5]> <S3: lm> 0.34820278
## # ... with 132 more rows, and 13 more variables: tidy <list>,
## # augment <list>, r.squared <dbl>, adj.r.squared <dbl>, sigma <dbl>,
## # statistic <dbl>, p.value <dbl>, df <int>, logLik <dbl>, AIC <dbl>,
## # BIC <dbl>, deviance <dbl>, df.residual <int>
Let’s get a list of these and then have a look at their original data. We should notice a distinctly non-linear shape to the relationship, with Rwanda seeming to be an outlier even among these countries.
bad_fit <- models %>% unnest(glance) %>% filter(r.squared < 0.25)
by_country %>%
  filter(country %in% bad_fit$country) %>%
  unnest(data) %>%
  ggplot(aes(year, lifeExp, colour = country)) +
  geom_line() +
  theme_minimal()
As suggested by Hadley in his lecture, a quadratic might be a better fit. So here I show how we can generate new models with the same tidied statistics, but using a polynomial of degree 3 (a cubic) in place of the straight-line fit. This is still a linear regression, since the model remains linear in its coefficients.
country_poly_model <- function(df){
  lm(lifeExp ~ poly(year1950, 3), data = df)
}
poly_models <-
  models %>%
  filter(country %in% bad_fit$country) %>%
  mutate(model = map(data, country_poly_model)) %>%
  mutate(glance = map(model, broom::glance),
         rsq = glance %>% map_dbl("r.squared"),
         tidy = map(model, broom::tidy),
         augment = map(model, broom::augment))
models <-
  models %>%
  anti_join(poly_models, by = "country") %>%
  bind_rows(poly_models)
rm(poly_models, bad_fit)
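As a quick check on whether the cubic is worth it for one of the poorly fitted countries, the two fits can be compared directly. This is a minimal sketch using Rwanda; the names `rwanda`, `linear_fit` and `cubic_fit` are illustrative, and the F-test is an extra step that the original analysis does not include.

```r
library(dplyr)

rwanda <- gapminder::gapminder %>%
  filter(country == "Rwanda") %>%
  mutate(year1950 = year - 1950)

linear_fit <- lm(lifeExp ~ year1950, data = rwanda)
cubic_fit  <- lm(lifeExp ~ poly(year1950, 3), data = rwanda)

# The straight line is a special case of the cubic, so R-squared
# cannot go down when the extra terms are added
c(linear = summary(linear_fit)$r.squared,
  cubic  = summary(cubic_fit)$r.squared)

# F-test for whether the extra polynomial terms improve the fit
anova(linear_fit, cubic_fit)
```

Because R² never decreases as terms are added, a higher R² alone does not show the cubic is the better model; with only 12 observations per country, a comparison such as the F-test or adjusted R² is a fairer guide.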
In retrospect, this pattern of mapping a new model to a subset, and then replacing it into the original data frame, should be wrapped up in a function or two.
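One possible shape for those functions is sketched below. `fit_models()` and `refit_subset()` are hypothetical helpers made up for this example (they are not part of any package), and the snippet rebuilds a small version of the pipeline so that it runs on its own.

```r
library(dplyr)
library(tidyr)
library(purrr)

# Hypothetical helper: fit `model_fn` to every nested data frame and
# attach the broom summaries used throughout this document
fit_models <- function(nested, model_fn) {
  nested %>%
    mutate(model   = map(data, model_fn),
           glance  = map(model, broom::glance),
           rsq     = map_dbl(glance, "r.squared"),
           tidy    = map(model, broom::tidy),
           augment = map(model, broom::augment))
}

# Hypothetical helper: refit only the listed countries with `model_fn`
# and put them back alongside the untouched rows
refit_subset <- function(models, countries, model_fn) {
  refitted <- models %>%
    filter(country %in% countries) %>%
    fit_models(model_fn)

  models %>%
    filter(!(country %in% countries)) %>%
    bind_rows(refitted)
}

by_country <- gapminder::gapminder %>%
  mutate(year1950 = year - 1950) %>%
  group_by(continent, country) %>%
  nest() %>%
  ungroup()

linear_models <- fit_models(by_country, function(df) lm(lifeExp ~ year1950, data = df))
bad <- linear_models$country[linear_models$rsq < 0.25]
mixed_models <- refit_subset(linear_models, bad,
                             function(df) lm(lifeExp ~ poly(year1950, 3), data = df))
```

With these in place, the replacement step above would reduce to something like `models <- refit_subset(models, bad_fit$country, country_poly_model)`.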
Re-plotting the new models, we can see the outlier African countries are now closer to the rest of the population. You could probably think of a way to identify which models are polynomial and which are linear in the plot.
models %>%
  ggplot(aes(rsq, reorder(country, rsq))) +
  geom_point(aes(colour = continent)) +
  labs(y = "Country") +
  theme_minimal() +
  theme(axis.text.y = element_blank(),
        axis.ticks.y = element_blank(),
        panel.grid = element_blank(),
        panel.grid.major.x = element_line(colour = "grey"),
        panel.background = element_blank())
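One possible way to tell the two kinds of model apart in the plot is to derive a label from each fitted model's formula and map it to point shape. The sketch below uses a hypothetical helper, `model_type()`, and a two-country stand-in for `models` (Rwanda fitted with the cubic, Norway with the straight line) so that it runs on its own.

```r
library(dplyr)
library(purrr)
library(ggplot2)

# Hypothetical helper: label an lm by whether its formula uses poly()
model_type <- function(fit) {
  term_labels <- attr(terms(fit), "term.labels")
  if (any(grepl("poly(", term_labels, fixed = TRUE))) "polynomial" else "linear"
}

# Small stand-in for `models`: one cubic fit and one straight-line fit
demo <- gapminder::gapminder %>%
  filter(country %in% c("Rwanda", "Norway")) %>%
  mutate(year1950 = year - 1950) %>%
  group_by(continent, country) %>%
  tidyr::nest() %>%
  ungroup() %>%
  mutate(model = map2(data, as.character(country), function(df, name) {
           if (name == "Rwanda") {
             lm(lifeExp ~ poly(year1950, 3), data = df)
           } else {
             lm(lifeExp ~ year1950, data = df)
           }
         }),
         rsq  = map_dbl(model, function(fit) summary(fit)$r.squared),
         type = map_chr(model, model_type))

# Shape now distinguishes the model type; colour still shows continent
ggplot(demo, aes(rsq, reorder(country, rsq))) +
  geom_point(aes(colour = continent, shape = type)) +
  labs(y = "Country") +
  theme_minimal()
```

Deriving the label from the model object, rather than storing it by hand, means it stays correct however the models were refitted.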
NB: There is some sort of bug around trying to unnest the augment nested data frame (where, for example, the model residuals live). I’m not sure whether I’m missing something or the package itself has a problem.
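One possible cause of the unnesting problem noted above, offered as a guess rather than a confirmed diagnosis: by default `broom::augment()` builds its output from the model frame, and the model frame of a `poly()` fit has a single matrix column called `poly(year1950, 3)` instead of `year1950`. Once linear and polynomial models sit in the same data frame, their augment tables no longer share the same columns. The sketch below shows the mismatch and one way around it, passing the original data through the `data` argument.

```r
library(dplyr)

rwanda <- gapminder::gapminder %>%
  filter(country == "Rwanda") %>%
  mutate(year1950 = year - 1950)

linear_fit <- lm(lifeExp ~ year1950, data = rwanda)
cubic_fit  <- lm(lifeExp ~ poly(year1950, 3), data = rwanda)

# The model frames that augment() uses by default have different columns
names(model.frame(linear_fit))
names(model.frame(cubic_fit))

# Supplying the original data gives both results the same layout:
# the original columns followed by .fitted, .resid and so on
aug_linear <- broom::augment(linear_fit, data = rwanda)
aug_cubic  <- broom::augment(cubic_fit, data = rwanda)
identical(names(aug_linear), names(aug_cubic))
```

In the pipeline this would mean replacing `map(model, broom::augment)` with something like `map2(model, data, ~ broom::augment(.x, data = .y))`, assuming the nested `data` column is still present.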