Modelling is an essential tool for visualisation. There are two particularly strong connections between modelling and visualisation that I want to explore in this chapter:
Using models as a tool to remove obvious patterns in your plots. This is useful because strong patterns mask subtler effects. Often the strongest effects are already known and expected, and removing them allows you to see surprises more easily.
Other times you have a lot of data, too much to show on a handful of plots. Models can be a powerful tool for summarising data so that you get a higher level view.
In this chapter, I’m going to focus on the use of linear models to achieve these goals. Linear models are a basic, but powerful, tool of statistics, and I recommend that everyone serious about visualisation learns at least the basics of how to use them. To this end, I highly recommend two books by Julian J. Faraway: Linear Models with R and Extending the Linear Model with R.
These books cover some of the theory of linear models, but are pragmatic and focussed on how to actually use linear models (and their extensions) in R.
There are many other modelling tools, which I don’t have the space to show. If you understand how linear models can help improve your visualisations, you should be able to translate the basic idea to other families of models. This chapter just scratches the surface of what you can do. But hopefully it reinforces how visualisation can combine with modelling to help you build a powerful data analysis toolbox. For more ideas, check out Wickham, Cook, and Hofmann (2015).
In my opinion, mastering the combination of visualisation and modelling is key to being an effective data scientist. Unfortunately most books (like this one!) focus on either visualisation or modelling, but not both. There’s a lot of interesting work to be done.
So far our analysis of the diamonds data has been plagued by the powerful relationship between size and price. It makes it very difficult to see the impact of cut, colour and clarity because higher quality diamonds tend to be smaller, and hence cheaper. This challenge is often called confounding. We can use a linear model to remove the effect of size on price. Instead of looking at the raw price, we can look at the relative price: how valuable is this diamond relative to the average diamond of the same size.
To get started, we’ll focus on diamonds of size two carats or less (96% of the dataset). This avoids some incidental problems that you can explore in the exercises if you’re interested. We’ll also create two new variables: log price and log carat. These variables are useful because they produce a plot with a strong linear trend.
# diamonds comes with ggplot2; dplyr provides the pipe and data-manipulation verbs
library(ggplot2)
library(dplyr)

diamonds2 <- diamonds %>%
  filter(carat <= 2) %>%
  mutate(
    lcarat = log2(carat),
    lprice = log2(price)
  )
diamonds2
#> # A tibble: 52,051 × 12
#> carat cut color clarity depth table price x y z lcarat
#> <dbl> <ord> <ord> <ord> <dbl> <dbl> <int> <dbl> <dbl> <dbl> <dbl>
#> 1 0.23 Ideal E SI2 61.5 55 326 3.95 3.98 2.43 -2.12
#> 2 0.21 Premium E SI1 59.8 61 326 3.89 3.84 2.31 -2.25
#> 3 0.23 Good E VS1 56.9 65 327 4.05 4.07 2.31 -2.12
#> 4 0.29 Premium I VS2 62.4 58 334 4.20 4.23 2.63 -1.79
#> 5 0.31 Good J SI2 63.3 58 335 4.34 4.35 2.75 -1.69
#> 6 0.24 Very Good J VVS2 62.8 57 336 3.94 3.96 2.48 -2.06
#> # ... with 5.204e+04 more rows, and 1 more variables: lprice <dbl>
ggplot(diamonds2, aes(lcarat, lprice)) +
  geom_bin2d() +
  geom_smooth(method = "lm", se = FALSE, size = 2, colour = "yellow")
In the graphic we used geom_smooth() to overlay the line of best fit to the data. We can replicate this outside of ggplot2 by fitting a linear model with lm(). This allows us to find out the slope and intercept of the line:
mod <- lm(lprice ~ lcarat, data = diamonds2)
coef(summary(mod))
#> Estimate Std. Error t value Pr(>|t|)
#> (Intercept) 12.2 0.00211 5789 0
#> lcarat 1.7 0.00208 816 0
If you’re familiar with linear models, you might want to interpret those coefficients: \(\log_2(price) = 12.2 + 1.7 \cdot \log_2(carat)\), which implies \(price = 4900 \cdot carat ^ {1.7}\). Interpreting those coefficients certainly is useful, but even if you don’t understand them, the model can still be useful. We can use it to subtract the trend away by looking at the residuals: the price of each diamond minus its predicted price, based on weight alone. Geometrically, the residuals are the vertical distance between each point and the line of best fit. They tell us the price relative to the “average” diamond of that size.
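If you want to recompute those numbers yourself, a quick check like this works (a sketch, using the mod fitted above; the exact values depend on the unrounded coefficients):
2 ^ coef(mod)[["(Intercept)"]]  # roughly 4900: the predicted price of a 1 carat diamond
coef(mod)[["lcarat"]]           # roughly 1.7: the exponent on carat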
diamonds2 <- diamonds2 %>% mutate(rel_price = resid(mod))
ggplot(diamonds2, aes(carat, rel_price)) +
  geom_bin2d()
A relative price of zero means that the diamond was at the average price; positive means that it’s more expensive than expected (based on its size), and negative means that it’s cheaper than expected.
Interpreting the values precisely is a little tricky here because we’ve log-transformed price. The residuals give the absolute difference (\(x - expected\)), but here we have \(\log_2(price) - \log_2(expected price)\), or equivalently \(\log_2(price / expected price)\). If we “back-transform” to the original scale by applying the opposite transformation (\(2 ^ x\)) we get \(price / expected price\). This makes the values more interpretable, at the cost of the nice symmetry property of the logged values (i.e. both relatively cheaper and relatively more expensive diamonds have the same range). We can make a little table to help interpret the values:
xgrid <- seq(-2, 1, by = 1/3)
data.frame(logx = xgrid, x = round(2 ^ xgrid, 2))
#> logx x
#> 1 -2.000 0.25
#> 2 -1.667 0.31
#> 3 -1.333 0.40
#> 4 -1.000 0.50
#> 5 -0.667 0.63
#> 6 -0.333 0.79
#> 7 0.000 1.00
#> 8 0.333 1.26
#> 9 0.667 1.59
#> 10 1.000 2.00
This table illustrates why we used log2() rather than log(): a change of 1 unit on the logged scale corresponds to a doubling on the original scale. For example, a rel_price of -1 means that the diamond is half of the expected price; a relative price of 1 means that it’s twice the expected price.
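If you prefer to think on the ratio scale, you can carry it along as an extra column (a sketch; the price_ratio name is my own and isn’t used elsewhere in the chapter):
diamonds2 %>%
  mutate(price_ratio = 2 ^ rel_price) %>%
  select(carat, price, rel_price, price_ratio)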
Let’s use both price and relative price to see how colour and cut affect the value of a diamond. We’ll compute the average price and average relative price for each combination of colour and cut:
color_cut <- diamonds2 %>%
  group_by(color, cut) %>%
  summarise(
    price = mean(price),
    rel_price = mean(rel_price)
  )
color_cut
#> Source: local data frame [35 x 4]
#> Groups: color [?]
#>
#> color cut price rel_price
#> <ord> <ord> <dbl> <dbl>
#> 1 D Fair 3939 -0.0755
#> 2 D Good 3309 -0.0472
#> 3 D Very Good 3368 0.1038
#> 4 D Premium 3513 0.1093
#> 5 D Ideal 2595 0.2173
#> 6 E Fair 3516 -0.1720
#> # ... with 29 more rows
If we look at price, it’s hard to see how the quality of the diamond affects the price. The lowest quality diamonds (fair cut with colour J) have the highest average value! This is because those diamonds also tend to be larger: size and quality are confounded.
ggplot(color_cut, aes(color, price)) +
  geom_line(aes(group = cut), colour = "grey80") +
  geom_point(aes(colour = cut))
If, however, we plot the relative price, you see the pattern that you expect: as the quality of the diamonds decreases, the relative price decreases. The worst quality diamond is 0.61x (\(2 ^ {-0.7}\)) the price of an “average” diamond.
ggplot(color_cut, aes(color, rel_price)) +
  geom_line(aes(group = cut), colour = "grey80") +
  geom_point(aes(colour = cut))
This technique can be employed in a wide range of situations. Wherever you can explicitly model a strong pattern that you see in a plot, it’s worthwhile to use a model to remove that strong pattern so that you can see what interesting trends remain.
What happens if you repeat the above analysis with all diamonds, not just those of two carats or less? What does the strange geometry of log(carat) vs relative price represent? What does the diagonal line without any points represent?
I made an unsupported assertion that lower-quality diamonds tend to be larger. Support my claim with a plot.
Can you create a plot that simultaneously shows the effect of colour, cut, and clarity on relative price? If there’s too much information to show on one plot, think about how you might create a sequence of plots to convey the same message.
How do depth and table relate to the relative price?
We’ll continue to explore the connection between modelling and visualisation with the txhousing dataset:
txhousing
#> # A tibble: 8,034 × 9
#> city year month sales volume median listings inventory date
#> <chr> <int> <int> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 Abilene 2000 1 72 5380000 71400 701 6.3 2000
#> 2 Abilene 2000 2 98 6505000 58700 746 6.6 2000
#> 3 Abilene 2000 3 130 9285000 58100 784 6.8 2000
#> 4 Abilene 2000 4 98 9730000 68600 785 6.9 2000
#> 5 Abilene 2000 5 141 10590000 67300 794 6.8 2000
#> 6 Abilene 2000 6 156 13910000 66900 780 6.6 2000
#> # ... with 8,028 more rows
This data was collected by the Real Estate Center at Texas A&M University, http://recenter.tamu.edu/Data/hs/. The data contains information about 46 Texas cities, recording the number of house sales (sales), the total volume of sales (volume), the average and median sale prices, the number of houses listed for sale (listings) and the number of months inventory (inventory). Data is recorded monthly from Jan 2000 to Apr 2015, 187 entries for each city.
We’re going to explore how sales have varied over time for each city as it shows some interesting trends and poses some interesting challenges. Let’s start with an overview: a time series of sales for each city:
ggplot(txhousing, aes(date, sales)) +
  geom_line(aes(group = city), alpha = 1/2)
Two factors make it hard to see the long-term trend in this plot:
The range of sales varies over multiple orders of magnitude. The biggest city, Houston, averages over 4000 sales per month; the smallest city, San Marcos, averages only around 20 sales per month.
There is a strong seasonal trend: sales are much higher in the summer than in the winter.
We can fix the first problem by plotting log sales:
ggplot(txhousing, aes(date, log(sales))) +
  geom_line(aes(group = city), alpha = 1/2)
We can fix the second problem using the same technique we used for removing the trend in the diamonds data: we’ll fit a linear model and look at the residuals. This time we’ll use a categorical predictor to remove the month effect. First we check that the technique works by applying it to a single city. It’s always a good idea to start simple so that if something goes wrong you can more easily pinpoint the problem.
abilene <- txhousing %>% filter(city == "Abilene")
ggplot(abilene, aes(date, log(sales))) +
  geom_line()
mod <- lm(log(sales) ~ factor(month), data = abilene)
abilene$rel_sales <- resid(mod)
ggplot(abilene, aes(date, rel_sales)) +
  geom_line()
We can apply this transformation to every city with group_by() and mutate(). Note the use of the na.action = na.exclude argument to lm(). Counterintuitively, this ensures that missing values in the input are matched with missing values in the output predictions and residuals. Without this argument, missing values are just dropped, and the residuals don’t line up with the inputs.
deseas <- function(x, month) {
  resid(lm(x ~ factor(month), na.action = na.exclude))
}

txhousing <- txhousing %>%
  group_by(city) %>%
  mutate(rel_sales = deseas(log(sales), month))
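As a quick sanity check (a sketch): if na.exclude is doing its job, the residuals should be missing exactly where sales are missing, so the rows still line up.
# This should return zero rows: rel_sales is NA if and only if sales is NA.
txhousing %>%
  filter(is.na(sales) != is.na(rel_sales))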
With this data in hand, we can re-plot the data. Now that we have log-transformed the data and removed the strong seasonal effects we can see there is a strong common pattern: a consistent increase from 2000-2007, a drop until 2010 (with quite some noise), and then a gradual rebound. To make that more clear, I included a summary line that shows the mean relative sales across all cities.
ggplot(txhousing, aes(date, rel_sales)) +
  geom_line(aes(group = city), alpha = 1/5) +
  geom_line(stat = "summary", fun.y = "mean", colour = "red")
(Note that removing the seasonal effect also removed the intercept - we see the trend for each city relative to its average number of sales.)
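You can verify that directly (a sketch): because every model includes an intercept, each city’s residuals average to essentially zero.
# txhousing is still grouped by city, so this returns one (near-zero) mean per city.
txhousing %>%
  summarise(mean_rel_sales = mean(rel_sales, na.rm = TRUE))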
The final plot shows a lot of short-term noise in the overall trend. How could you smooth this further to focus on long-term changes?
If you look closely (e.g. + xlim(2008, 2012)) at the long-term trend you’ll notice a weird pattern in 2009–2011. It looks like there was a big dip in 2010. Is this dip “real”? (i.e. can you spot it in the original data?)
What other variables in the TX housing data show strong seasonal effects? Does this technique help to remove them?
Not all the cities in this data set have complete time series. Use your dplyr skills to figure out how much data each city is missing. Display the results with a visualisation.
Replicate the computation that stat_summary() did with dplyr so you can plot the data “by hand”.
The previous examples used the linear model just as a tool for removing trend: we fit the model and immediately threw it away. We didn’t care about the model itself, just what it could do for us. But the models themselves contain useful information and if we keep them around, there are many new problems that we can solve:
We might be interested in cities where the model didn’t fit well: a poorly fitting model suggests that there isn’t much of a seasonal pattern, which contradicts our implicit hypothesis that all cities share a similar pattern.
The coefficients themselves might be interesting. In this case, looking at the coefficients will show us how the seasonal pattern varies between cities.
We may want to dive into the details of the model itself, and see exactly what it says about each observation. For this data, it might help us find suspicious data points that might reflect data entry errors.
To take advantage of this data, we need to store the models. We can do this using a new dplyr verb: do(). It allows us to store the result of arbitrary computation in a column. Here we’ll use it to store that linear model:
models <- txhousing %>%
  group_by(city) %>%
  do(mod = lm(
    log2(sales) ~ factor(month),
    data = .,
    na.action = na.exclude
  ))
models
#> Source: local data frame [46 x 2]
#> Groups: <by row>
#>
#> # A tibble: 46 × 2
#> city mod
#> * <chr> <list>
#> 1 Abilene <S3: lm>
#> 2 Amarillo <S3: lm>
#> 3 Arlington <S3: lm>
#> 4 Austin <S3: lm>
#> 5 Bay Area <S3: lm>
#> 6 Beaumont <S3: lm>
#> # ... with 40 more rows
There are two important things to note in this code:
do() creates a new column called mod. This is a special type of column: instead of containing an atomic vector (a logical, integer, numeric, or character) like usual, it’s a list. Lists are R’s most flexible data structure and can hold anything, including linear models.
The . is a special pronoun used by do(): it refers to the “current” data frame. In this example, do() fits the model 46 times (once for each city), each time replacing . with the data for one city.
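For example, you can pull an individual model back out of the list-column and work with it like any other lm object (a sketch):
# The first row of models is Abilene, so this is its seasonal model.
summary(models$mod[[1]])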
If you’re an experienced modeller, you might wonder why I didn’t fit one model to all cities simultaneously. That’s a great next step, but it’s often useful to start off simple. Once we have a model that works for each city individually, you can figure out how to generalise it to fit all cities simultaneously.
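(For the curious, one way to fit a single model to all cities at once is an interaction between city and month, sketched below; it isn’t the approach we take in this chapter.)
# One global model: a separate seasonal pattern for every city.
mod_all <- lm(
  log2(sales) ~ city * factor(month),
  data = txhousing,
  na.action = na.exclude
)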
To visualise these models, we’ll turn them into tidy data frames. We’ll do that with the broom package by David Robinson.
library(broom)
Broom provides three key verbs, each corresponding to one of the challenges outlined above:
glance() extracts model-level summaries with one row of data for each model. It contains summary statistics like the \(R^2\) and degrees of freedom.
tidy() extracts coefficient-level summaries with one row of data for each coefficient in each model. It contains information about individual coefficients like their estimate and standard error.
augment() extracts observation-level summaries with one row of data for each observation in each model. It includes variables like the residual and influence metrics useful for diagnosing outliers.
We’ll learn more about each of these functions in the following three sections.
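As a quick preview (a sketch), here are the three verbs applied to a single model: the Abilene seasonal model, refitted on the log2 scale used above.
m <- lm(log2(sales) ~ factor(month), data = abilene, na.action = na.exclude)
glance(m)   # one row: model-level summaries (r.squared, df, AIC, ...)
tidy(m)     # one row per coefficient: estimate, std.error, ...
augment(m)  # one row per observation: fitted values, residuals, ...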
We’ll begin by looking at how well the model fits each city, using glance():
model_sum <- models %>% glance(mod)
model_sum
#> Source: local data frame [46 x 12]
#> Groups: city [46]
#>
#> city r.squared adj.r.squared sigma statistic p.value df logLik
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <int> <dbl>
#> 1 Abilene 0.530 0.500 0.282 17.9 1.50e-23 12 -22.2
#> 2 Amarillo 0.449 0.415 0.302 13.0 7.41e-18 12 -35.0
#> 3 Arlington 0.513 0.483 0.267 16.8 2.75e-22 12 -12.5
#> 4 Austin 0.487 0.455 0.310 15.1 2.04e-20 12 -40.3
#> 5 Bay Area 0.555 0.527 0.265 19.9 1.45e-25 12 -10.5
#> 6 Beaumont 0.430 0.395 0.275 12.0 1.18e-16 12 -18.0
#> # ... with 40 more rows, and 4 more variables: AIC <dbl>, BIC <dbl>,
#> # deviance <dbl>, df.residual <int>
This creates a data frame with one row for each city, and variables that either summarise complexity (e.g. df) or fit (e.g. r.squared, p.value, AIC). Since all the models we fit have the same complexity (12 terms: one for each month), we’ll focus on the model fit summaries. \(R^2\) is a reasonable place to start because it’s well known. We can use a dot plot to see the variation across cities:
ggplot(model_sum, aes(r.squared, reorder(city, r.squared))) +
  geom_point()
It’s hard to picture exactly what those values of \(R^2\) mean, so it’s helpful to pick out a few exemplars. The following code extracts and plots the three cities with the highest and the three with the lowest \(R^2\):
top3 <- c("Bryan-College Station", "Lubbock", "NE Tarrant County")
bottom3 <- c("McAllen", "Brownsville", "Harlingen")
extreme <- txhousing %>% ungroup() %>%
  filter(city %in% c(top3, bottom3), !is.na(sales)) %>%
  mutate(city = factor(city, c(top3, bottom3)))

ggplot(extreme, aes(month, log(sales))) +
  geom_line(aes(group = year)) +
  facet_wrap(~city)
The cities with low \(R^2\) have weaker seasonal patterns and more variation between years. The data for Harlingen seems particularly noisy.
Do your conclusions change if you use a different measurement of model fit like AIC or deviance? Why/why not?
One possible hypothesis that explains why McAllen, Harlingen and Brownsville have lower \(R^2\) is that they’re smaller towns so there are fewer sales and more noise. Confirm or refute this hypothesis.
McAllen, Harlingen and Brownsville seem to have much more year-to-year variation than Bryan-College Station, Lubbock, and NE Tarrant County. How does the model change if you also include a linear trend for year (i.e. log(sales) ~ factor(month) + year)?
Create a faceted plot that shows the seasonal patterns for all cities.
Order the facets by the \(R^2\) for the city.
The model fit summaries suggest that there are some important differences in seasonality between the different cities. Let’s dive into those differences by using tidy() to extract detail about each individual coefficient:
coefs <- models %>% tidy(mod)
coefs
#> Source: local data frame [552 x 6]
#> Groups: city [46]
#>
#> city term estimate std.error statistic p.value
#> <chr> <chr> <dbl> <dbl> <dbl> <dbl>
#> 1 Abilene (Intercept) 6.542 0.0704 92.88 7.90e-151
#> 2 Abilene factor(month)2 0.354 0.0996 3.55 4.91e-04
#> 3 Abilene factor(month)3 0.675 0.0996 6.77 1.83e-10
#> 4 Abilene factor(month)4 0.749 0.0996 7.52 2.76e-12
#> 5 Abilene factor(month)5 0.916 0.0996 9.20 1.06e-16
#> 6 Abilene factor(month)6 1.002 0.0996 10.06 4.37e-19
#> # ... with 546 more rows
We’re more interested in the month effect, so we’ll do a little extra tidying to only look at the month coefficients, and then to extract the month value into a numeric variable:
months <- coefs %>%
  filter(grepl("factor", term)) %>%
  tidyr::extract(term, "month", "(\\d+)", convert = TRUE)
months
#> Source: local data frame [506 x 6]
#> Groups: city [46]
#>
#> city month estimate std.error statistic p.value
#> * <chr> <int> <dbl> <dbl> <dbl> <dbl>
#> 1 Abilene 2 0.354 0.0996 3.55 4.91e-04
#> 2 Abilene 3 0.675 0.0996 6.77 1.83e-10
#> 3 Abilene 4 0.749 0.0996 7.52 2.76e-12
#> 4 Abilene 5 0.916 0.0996 9.20 1.06e-16
#> 5 Abilene 6 1.002 0.0996 10.06 4.37e-19
#> 6 Abilene 7 0.954 0.0996 9.58 9.81e-18
#> # ... with 500 more rows
This is a common pattern. You need to use your data tidying skills at many points in an analysis. Once you have the correct tidy dataset, creating the plot is usually easy. Here we’ll put month on the x-axis, estimate on the y-axis, and draw one line for each city. I’ve back-transformed to make the coefficients more interpretable: these are now ratios of sales compared to January.
ggplot(months, aes(month, 2 ^ estimate)) +
  geom_line(aes(group = city))
The pattern seems similar across the cities. The main difference is the strength of the seasonal effect. Let’s pull that out and plot it:
coef_sum <- months %>%
  group_by(city) %>%
  summarise(max = max(estimate))

ggplot(coef_sum, aes(2 ^ max, reorder(city, max))) +
  geom_point()
The cities with the strongest seasonal effect are College Station and San Marcos (both college towns) and Galveston and South Padre Island (beach cities). It makes sense that these cities would have very strong seasonal effects.
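You can check this claim directly by sorting the coefficient summary (a sketch):
# The cities with the largest peak month coefficient.
coef_sum %>%
  arrange(desc(max)) %>%
  head(4)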
Pull out the three cities with highest and lowest seasonal effect. Plot their coefficients.
How does strength of seasonal effect relate to the \(R^2\) for the model? Answer with a plot.
You should be extra cautious when your results agree with your prior beliefs. How can you confirm or refute my hypothesis about the causes of strong seasonal patterns?
Group the diamonds data by cut, clarity and colour. Fit a linear model log(price) ~ log(carat). What does the intercept tell you? What does the slope tell you? How do the slope and intercept vary across the groups? Answer with a plot.
Observation-level data, which includes residual diagnostics, is most useful in the traditional model-fitting scenario, because it can help you find “high-leverage” points: points that have a big influence on the final model. It’s also useful in conjunction with visualisation, particularly because it provides an alternative way to access the residuals.
Extracting observation-level data is the job of the augment() function. This adds one row for each observation. It includes the variables used in the original model, the residuals, and a number of common influence statistics (see ?augment.lm for more details):
obs_sum <- models %>% augment(mod)
obs_sum
#> Source: local data frame [8,034 x 10]
#> Groups: city [46]
#>
#> city log2.sales. factor.month. .fitted .se.fit .resid .hat .sigma
#> <chr> <dbl> <fctr> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 Abilene 6.17 1 6.54 0.0704 -0.372 0.0625 0.281
#> 2 Abilene 6.61 2 6.90 0.0704 -0.281 0.0625 0.282
#> 3 Abilene 7.02 3 7.22 0.0704 -0.194 0.0625 0.282
#> 4 Abilene 6.61 4 7.29 0.0704 -0.676 0.0625 0.278
#> 5 Abilene 7.14 5 7.46 0.0704 -0.319 0.0625 0.281
#> 6 Abilene 7.29 6 7.54 0.0704 -0.259 0.0625 0.282
#> # ... with 8,028 more rows, and 2 more variables: .cooksd <dbl>,
#> # .std.resid <dbl>
For example, it might be interesting to look at the distribution of standardised residuals. (These are residuals standardised to have a variance of one in each model, making them more comparable). We’re looking for unusual values that might need deeper exploration:
ggplot(obs_sum, aes(.std.resid)) +
  geom_histogram(binwidth = 0.1)

ggplot(obs_sum, aes(abs(.std.resid))) +
  geom_histogram(binwidth = 0.1)
An absolute standardised residual of 2 seems like a reasonable threshold for picking out observations to explore individually:
obs_sum %>%
  filter(abs(.std.resid) > 2) %>%
  group_by(city) %>%
  summarise(n = n(), avg = mean(abs(.std.resid))) %>%
  arrange(desc(n))
#> # A tibble: 43 × 3
#> city n avg
#> <chr> <int> <dbl>
#> 1 Texarkana 12 2.43
#> 2 Harlingen 11 2.73
#> 3 Waco 11 2.96
#> 4 Victoria 10 2.49
#> 5 Brazoria County 9 2.31
#> 6 Brownsville 9 2.48
#> # ... with 37 more rows
In a real analysis, you’d want to look into these cities in more detail.
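For example, you might start with the most frequently flagged city (a sketch, using Texarkana from the table above):
ggplot(filter(txhousing, city == "Texarkana"), aes(date, log(sales))) +
  geom_line()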
A common diagnostic plot is fitted values (.fitted) vs. residuals (.resid). Do you see any patterns? What if you include the city or month on the same plot?
Create a time series of log(sales) for each city. Highlight points that have a standardised residual of greater than 2.
Wickham, Hadley, Dianne Cook, and Heike Hofmann. 2015. “Visualizing Statistical Models: Removing the Blindfold.” Statistical Analysis and Data Mining: The ASA Data Science Journal 8 (4): 203–25.