So far, we have calculated predicted values from our models either visually
(by reading them off plotted regression lines) or manually (by adding and
multiplying our estimated parameters). predict() is a
built-in R function that can do this for you. We will work through a short
example here to give you practice with the syntax.
For this example, we will use the gss_cat dataset, which
has a selection of variables from a survey called the General Social
Survey.
library(tidyverse)
## ── Attaching packages ─────────────────────────────────────── tidyverse 1.3.1 ──
## ✔ ggplot2 3.3.6 ✔ purrr 0.3.4
## ✔ tibble 3.1.7 ✔ dplyr 1.0.9
## ✔ tidyr 1.2.0 ✔ stringr 1.4.0
## ✔ readr 2.1.2 ✔ forcats 0.5.1
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
library(modelsummary)
# gss_cat is a dataset from the forcats package, which library(tidyverse) attaches
head(gss_cat)
## # A tibble: 6 × 9
## year marital age race rincome partyid relig denom tvhours
## <int> <fct> <int> <fct> <fct> <fct> <fct> <fct> <int>
## 1 2000 Never married 26 White $8000 to 9999 Ind,near r… Prot… Sout… 12
## 2 2000 Divorced 48 White $8000 to 9999 Not str re… Prot… Bapt… NA
## 3 2000 Widowed 67 White Not applicable Independent Prot… No d… 2
## 4 2000 Never married 39 White Not applicable Ind,near r… Orth… Not … 4
## 5 2000 Divorced 25 White Not applicable Not str de… None Not … 1
## 6 2000 Married 25 White $20000 - 24999 Strong dem… Prot… Sout… NA
Two of the variables record the number of hours per day respondents spend
watching TV (tvhours) and their age (age). Let’s
investigate the relationship between these two:
fit1 <- lm(tvhours ~ age, data = gss_cat)
modelsummary(fit1)
| | Model 1 |
|---|---|
| (Intercept) | 1.992 |
| | (0.070) |
| age | 0.021 |
| | (0.001) |
| Num.Obs. | 11299 |
| R2 | 0.020 |
| R2 Adj. | 0.020 |
| AIC | 53348.0 |
| BIC | 53370.0 |
| Log.Lik. | −26671.018 |
| RMSE | 2.56 |
According to this result, each additional year of age is associated with about 0.021 more predicted hours of TV per day, so older people are predicted to watch slightly more TV.
Let’s first revisit the manual approach. For example, what is the predicted average number of hours of TV watched per day for a 25-year-old? In section, we learned this approach:
\[ \widehat{tv\_hours_i} = \widehat{\beta_0} + \widehat{\beta_1} \cdot age_i \] which you can calculate manually in R with:
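library(broom) # tidy() turns a fitted model into a tibble of estimates
f1 <- tidy(fit1)
# estimate[1] is the intercept, estimate[2] the slope on age
f1$estimate[1] + (f1$estimate[2] * 25)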
## [1] 2.516311
So, our model predicts that the average 25-year-old will watch about 2.5 hours of TV per day.
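If you would rather not depend on tidy(), base R’s coef() returns the same estimates as a named vector, so the calculation can be written by coefficient name instead of by position. A minimal sketch, assuming gss_cat is available through forcats; the names b and manual_25 are just illustrative:

```r
library(tidyverse) # attaches forcats, which provides gss_cat

fit1 <- lm(tvhours ~ age, data = gss_cat)

# coef() returns a named vector with elements "(Intercept)" and "age"
b <- coef(fit1)

# beta_0-hat + beta_1-hat * age, evaluated at age = 25
manual_25 <- b[["(Intercept)"]] + b[["age"]] * 25
manual_25
```

Indexing by name means the code keeps working even if you later add predictors and the coefficient order changes.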
predict()

We can have R do this calculation for us automatically.
predict() generally takes two arguments: first your model
object, and then newdata, a data frame or tibble containing the
covariate values at which you would like predictions. The column names
of this tibble must exactly match the variable names you used to fit
your model. For example, the variable in gss_cat is called
age, so the column must also be called age in
newdata; Age would not work, because R is case-sensitive.
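To see why the names matter, here is a small sketch that deliberately names the column Age instead of age. In a fresh session, predict() cannot find age in newdata and stops with an error; the object names bad_df and res are just illustrative. (If some other object called age happened to exist in your workspace, R could silently use that instead, which is worse than an error.)

```r
library(tidyverse) # attaches forcats (gss_cat) and tibble

fit1 <- lm(tvhours ~ age, data = gss_cat)

# Wrong on purpose: the column is "Age", but the model was fit with "age"
bad_df <- tibble(Age = 25)

# Capture the error as an object instead of halting the script
res <- tryCatch(predict(fit1, newdata = bad_df),
                error = function(e) e)
conditionMessage(res)
```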
new_df <- tibble(age = 25)
predict(fit1, newdata = new_df)
## 1
## 2.516311
Beautiful! Even better, the real strength of predict()
is that it can make these predictions for many covariate profiles at the
same time. For example, we can find the predicted tvhours
for every age between 15 and 30 easily. Here, predict()
will return a vector with one prediction for each row of your provided
newdata dataset.
new_df <- tibble(age = 15:30)
predict(fit1, newdata = new_df)
## 1 2 3 4 5 6 7 8
## 2.306743 2.327700 2.348657 2.369613 2.390570 2.411527 2.432484 2.453441
## 9 10 11 12 13 14 15 16
## 2.474397 2.495354 2.516311 2.537268 2.558224 2.579181 2.600138 2.621095
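Getting one prediction per row follows from what predict() does for a linear model: it is essentially our manual calculation applied to every row at once. It builds a design matrix from newdata (a column of 1s for the intercept plus the covariates) and multiplies it by the vector of estimated coefficients. A rough sketch of that idea, not the exact internals of predict.lm():

```r
library(tidyverse) # attaches forcats (gss_cat) and tibble

fit1 <- lm(tvhours ~ age, data = gss_cat)
new_df <- tibble(age = 15:30)

# Drop the response (tvhours) from the model terms, since new_df has none
tt <- delete.response(terms(fit1))

# Design matrix: an intercept column of 1s plus the age column
X <- model.matrix(tt, data = new_df)

# X %*% beta-hat gives one fitted value per row; drop() makes it a vector
manual <- drop(X %*% coef(fit1))
manual
```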
predict() with multiple variables

You can also predict outcomes from regressions with multiple
variables. Your newdata must then contain a column for every
predictor in your regression. For example, suppose we add a binary
indicator for whether the respondent is married to a new regression:
table(gss_cat$marital)
##
## No answer Never married Separated Divorced Widowed
## 17 5416 743 3383 1807
## Married
## 10117
gss_cat$married <- ifelse(gss_cat$marital == "Married", 1, 0)
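Before fitting, it can be worth checking that the new indicator lines up with the original factor. A quick sketch that cross-tabulates marital against married: every Married respondent should fall in the 1 column and everyone else in the 0 column. The name check is just illustrative.

```r
library(tidyverse) # attaches forcats, which provides gss_cat

gss_cat$married <- ifelse(gss_cat$marital == "Married", 1, 0)

# Rows: original marital categories; columns: the new 0/1 indicator
check <- table(gss_cat$marital, gss_cat$married)
check
```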
fit2 <- lm(tvhours ~ age + married, data = gss_cat)
modelsummary(fit2)
| | Model 1 |
|---|---|
| (Intercept) | 2.220 |
| | (0.071) |
| age | 0.023 |
| | (0.001) |
| married | −0.683 |
| | (0.048) |
| Num.Obs. | 11299 |
| R2 | 0.037 |
| R2 Adj. | 0.037 |
| AIC | 53151.1 |
| BIC | 53180.4 |
| Log.Lik. | −26571.540 |
| RMSE | 2.54 |
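With two predictors, the manual formula simply gains one more term. Plugging the rounded estimates from the table in for an unmarried 19-year-old gives a value close to, but (because of rounding) not exactly equal to, the one predict() reports for that profile:

\[ \widehat{tv\_hours_i} = \widehat{\beta_0} + \widehat{\beta_1} \cdot age_i + \widehat{\beta_2} \cdot married_i \]
\[ 2.220 + 0.023 \cdot 19 + (-0.683) \cdot 0 = 2.220 + 0.437 \approx 2.657 \]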
Great! Now, in order to get predicted values, we need to pass a
tibble() with values for both the age and
married variables. For example, let’s find the predicted
hours of TV watched per day for an unmarried 19-year-old, an unmarried
50-year-old, a married 50-year-old, and a married 75-year-old:
new_df <- tibble(age = c(19, 50, 50, 75),
married = c(0, 0, 1, 1))
predict(fit2, newdata = new_df)
## 1 2 3 4
## 2.651920 3.357066 2.674183 3.242849
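Typing out every covariate profile by hand gets tedious as the number of combinations grows. One option, sketched below, is tidyr’s expand_grid(), which builds every combination of the values you give it; the ages chosen and the names grid_df and pred_tvhours are just illustrative. Because predict() returns values in the same row order as newdata, the predictions can be stored as a new column with mutate().

```r
library(tidyverse) # attaches forcats (gss_cat), tidyr, and dplyr

gss_cat$married <- ifelse(gss_cat$marital == "Married", 1, 0)
fit2 <- lm(tvhours ~ age + married, data = gss_cat)

# Every combination of three ages and both marital statuses (3 x 2 = 6 rows)
grid_df <- expand_grid(age = c(19, 50, 75), married = c(0, 1))

# One prediction per row of grid_df, stored next to the covariates
grid_df <- mutate(grid_df, pred_tvhours = predict(fit2, newdata = grid_df))
grid_df
```

Since the model has no interaction term, the gap between married and unmarried respondents of the same age is the same at every age: it equals the married coefficient.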