Interaction Plot

TJ Mahr
December 13, 2016

library(ggplot2)
library(modelr)
library(dplyr)
#> 
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union

The quick way (for data exploration)

# Quick exploratory version: because color is mapped to Species, the smoother
# fits an independent linear model within each species.
ggplot(iris, aes(x = Sepal.Length, y = Sepal.Width, color = Species)) +
  geom_point() +
  stat_smooth(method = "lm")

A more correct way that uses a single model

# One model for all species: the `*` expands to both main effects plus the
# Sepal.Length-by-Species interaction.
model <- lm(
  formula = Sepal.Width ~ Sepal.Length * Species,
  data = iris
)
summary(model)
#> 
#> Call:
#> lm(formula = Sepal.Width ~ Sepal.Length * Species, data = iris)
#> 
#> Residuals:
#>      Min       1Q   Median       3Q      Max 
#> -0.72394 -0.16327 -0.00289  0.16457  0.60954 
#> 
#> Coefficients:
#>                                Estimate Std. Error t value Pr(>|t|)    
#> (Intercept)                     -0.5694     0.5539  -1.028 0.305622    
#> Sepal.Length                     0.7985     0.1104   7.235 2.55e-11 ***
#> Speciesversicolor                1.4416     0.7130   2.022 0.045056 *  
#> Speciesvirginica                 2.0157     0.6861   2.938 0.003848 ** 
#> Sepal.Length:Speciesversicolor  -0.4788     0.1337  -3.582 0.000465 ***
#> Sepal.Length:Speciesvirginica   -0.5666     0.1262  -4.490 1.45e-05 ***
#> ---
#> Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
#> 
#> Residual standard error: 0.2723 on 144 degrees of freedom
#> Multiple R-squared:  0.6227, Adjusted R-squared:  0.6096 
#> F-statistic: 47.53 on 5 and 144 DF,  p-value: < 2.2e-16

The main approach is to use predict(model, new_data) to get the model's predictions for new data.

First we have to make new_data, a data frame containing the predictor values we want predictions for. modelr's data_grid() function can do this: it generates every combination of the unique values of the variables we give it.

# Build the prediction grid from the two predictors used in the model.
new_data <- iris %>%
  data_grid(Sepal.Length, Species)
new_data
#> # A tibble: 105 × 2
#>    Sepal.Length    Species
#>           <dbl>     <fctr>
#> 1           4.3     setosa
#> 2           4.3 versicolor
#> 3           4.3  virginica
#> 4           4.4     setosa
#> 5           4.4 versicolor
#> 6           4.4  virginica
#> 7           4.5     setosa
#> 8           4.5 versicolor
#> 9           4.5  virginica
#> 10          4.6     setosa
#> # ... with 95 more rows

Now we do the prediction and combine the predictions with new_data.

# predict() with interval = "confidence" returns a matrix with columns fit,
# lwr and upr. Convert it to a tibble and rename `fit` to the response name so
# the plot's y aesthetic can be reused for the predictions.
# as_tibble() replaces as_data_frame(), which is deprecated.
new_data_fits <- predict(model, new_data, interval = "confidence") %>%
  as_tibble() %>%
  rename(Sepal.Width = fit)
new_data_fits
#> # A tibble: 105 × 3
#>    Sepal.Width      lwr      upr
#>          <dbl>    <dbl>    <dbl>
#> 1     2.864239 2.692433 3.036045
#> 2     2.246939 1.991598 2.502281
#> 3     2.443435 2.156459 2.730410
#> 4     2.944092 2.791537 3.096647
#> 5     2.278911 2.037749 2.520073
#> 6     2.466624 2.191290 2.741958
#> 7     3.023945 2.889852 3.158037
#> 8     2.310883 2.083808 2.537958
#> 9     2.489813 2.226079 2.753546
#> 10    3.103798 2.987006 3.220589
#> # ... with 95 more rows

# Attach the fitted values and interval bounds to the grid they were computed
# from; the rows are already in the same order.
model_predictions <- new_data %>%
  bind_cols(new_data_fits)
model_predictions
#> # A tibble: 105 × 5
#>    Sepal.Length    Species Sepal.Width      lwr      upr
#>           <dbl>     <fctr>       <dbl>    <dbl>    <dbl>
#> 1           4.3     setosa    2.864239 2.692433 3.036045
#> 2           4.3 versicolor    2.246939 1.991598 2.502281
#> 3           4.3  virginica    2.443435 2.156459 2.730410
#> 4           4.4     setosa    2.944092 2.791537 3.096647
#> 5           4.4 versicolor    2.278911 2.037749 2.520073
#> 6           4.4  virginica    2.466624 2.191290 2.741958
#> 7           4.5     setosa    3.023945 2.889852 3.158037
#> 8           4.5 versicolor    2.310883 2.083808 2.537958
#> 9           4.5  virginica    2.489813 2.226079 2.753546
#> 10          4.6     setosa    3.103798 2.987006 3.220589
#> # ... with 95 more rows

We are going to plot both the original data and the model predictions.

# Layer 1: the raw observations, colored by species.
ggplot(iris) +
  aes(x = Sepal.Length, y = Sepal.Width, color = Species) +
  geom_point() +
  # Layer 2: the confidence ribbons. Note the change in data source: these come
  # from the model predictions, not the raw data. `color = NULL` removes the
  # inherited color mapping, so `group = Species` is needed to keep one ribbon
  # per species.
  geom_ribbon(aes(ymin = lwr, ymax = upr, color = NULL, group = Species),
              data = model_predictions, alpha = .4, fill = "grey60") +
  # Layer 3: a line for the predicted means. Line thickness is set with
  # `linewidth`; using `size` for lines is deprecated as of ggplot2 3.4.0.
  geom_line(data = model_predictions, linewidth = 1)

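We can see a problem... Because the grid crosses every Sepal.Length value with every species, we are predicting responses at implausible values for setosa flowers (lengths far beyond anything observed for that species). We fix this by calling data_grid() separately for each species, so each species only gets the Sepal.Length values actually observed for it. data_grid() respects the grouping in group_by(), so we can take advantage of that.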

# Per-species grid: grouping first means data_grid() only uses the
# Sepal.Length values that occur within each species.
new_data2 <- group_by(iris, Species) %>%
  data_grid(Sepal.Length) %>%
  ungroup()

# Same steps as before: predict on the per-species grid, convert the resulting
# matrix (fit, lwr, upr) to a tibble, and rename `fit` to the response name.
# as_tibble() replaces as_data_frame(), which is deprecated.
new_data_fits2 <- predict(model, new_data2, interval = "confidence") %>%
  as_tibble() %>%
  rename(Sepal.Width = fit)

# Pair each per-species grid row with its fitted value and interval bounds.
model_predictions2 <- new_data2 %>%
  bind_cols(new_data_fits2)

# Raw points, then ribbons and mean lines from the per-species predictions.
ggplot(iris) +
  aes(x = Sepal.Length, y = Sepal.Width, color = Species) +
  geom_point() +
  # `color = NULL` removes the inherited color mapping, so `group = Species`
  # is needed to keep one ribbon per species.
  geom_ribbon(aes(ymin = lwr, ymax = upr, color = NULL, group = Species),
              data = model_predictions2, alpha = .4, fill = "grey60") +
  # Line thickness is set with `linewidth`; using `size` for lines is
  # deprecated as of ggplot2 3.4.0.
  geom_line(data = model_predictions2, linewidth = 1)

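This is virtually identical to the original quick interaction plot, but this is not always the case. The fitted lines match here because the model gives each species its own intercept and slope; a model without the interaction term, or one with additional predictors, would not reproduce the separate per-group fits.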