什么是模型?
数据集 是一个 复杂系统 某些特征的量化体现, 模型 是对一个数据集的低维概括性表示。
建模过程包含哪些步骤?
提取数据,形成数据集;
根据业务领域知识,选择一个(或多个)模型族, 模型族可以是参数化的(例如线性回归),也可以是非参数化的(例如树方法);
用模型族中的每个模型拟合数据集,找到拟合效果最好的那个模型,作为系统建模的最终结果。
这个过程决定了模型具有以下特点:
模型没有对与错之分,只有质量高低之分;
模型族的选择往往比模型拟合过程对模型质量的影响更大。
下面以一个 tidyverse
内置数据集 sim1
为例说明建模的基本过程。
## ── Attaching packages ───────────────────────────────────────────────────── tidyverse 1.3.0 ──
## ✓ ggplot2 3.2.1 ✓ purrr 0.3.4
## ✓ tibble 3.0.1 ✓ dplyr 0.8.5
## ✓ tidyr 1.0.0 ✓ stringr 1.4.0
## ✓ readr 1.3.1 ✓ forcats 0.5.0
## ── Conflicts ──────────────────────────────────────────────────────── tidyverse_conflicts() ──
## x dplyr::filter() masks stats::filter()
## x dplyr::lag() masks stats::lag()
## # A tibble: 30 x 2
## x y
## <int> <dbl>
## 1 1 4.20
## 2 1 7.51
## 3 1 2.13
## 4 2 8.99
## 5 2 10.2
## 6 2 11.3
## 7 3 7.36
## 8 3 10.5
## 9 3 10.5
## 10 4 12.4
## # … with 20 more rows
从分布形式上看用一阶线性模型比较好,也就是形如 \(y = a_0 + a_1 x\) 的模型, 于是使用均匀分布生成函数 runif
随机生成250个一阶线性模型:
models <- tibble(
a1 = runif(250, -20, 40),
a2 = runif(250, -5, 5)
)
ggplot(sim1, aes(x, y)) +
geom_abline(aes(intercept = a1, slope = a2), data = models, alpha = 1/4) +
geom_point()
其中大多数效果很差(很正常),但其中也有看上去比较靠谱的。
创建模型生成器 model1
,并使用它创建一个 \(a_0 = 7, \; a_1 = 1.5\) 的模型:
## [1] 8.5 8.5 8.5 10.0 10.0 10.0 11.5 11.5 11.5 13.0 13.0 13.0 14.5 14.5
## [15] 14.5 16.0 16.0 16.0 17.5 17.5 17.5 19.0 19.0 19.0 20.5 20.5 20.5 22.0
## [29] 22.0 22.0
计算此模型的 root-mean-squared deviation (RMSD):
measure_distance <- function(mod, data) {
diff <- data$y - model1(mod, data)
sqrt(mean(diff ^ 2))
}
measure_distance(c(7, 1.5), sim1)
## [1] 2.665212
计算所有 250 个模型的 RMSD:
sim1_dist <- function(x1, x2) {
measure_distance(c(x1, x2), sim1)
}
models <- models %>%
mutate(dist = purrr::map2_dbl(a1, a2, sim1_dist))
models
## # A tibble: 250 x 3
## a1 a2 dist
## <dbl> <dbl> <dbl>
## 1 33.0 1.95 28.3
## 2 3.26 -4.98 44.5
## 3 27.2 0.840 16.8
## 4 23.2 -0.809 9.10
## 5 -5.32 -4.91 51.9
## 6 32.9 4.12 40.5
## 7 10.2 -4.87 37.8
## 8 10.1 4.33 19.6
## 9 32.2 2.85 32.6
## 10 2.89 -2.35 28.5
## # … with 240 more rows
这里 purrr::map2_dbl()
是 map2()
函数族的一员,表示并行版本的 map()
函数, 这里的意思是:将sim1_dist(x1, x2)
映射到 models
的 dist
列上, 其中第1个参数 x1
来自 models
的 a1
列,第2个参数 x2
来自 a2
列。
为了便于区别,修改了原代码里 sim1_dist
的名义参数列表。
画出 RMSD 最小的前10个模型,RMSD 值越小拟合效果越好, 对应的线颜色越浅 (用 dist
的相反数表征,相反数越小颜色越深):
ggplot(sim1, aes(x, y)) +
geom_point(size = 2, colour = "grey30") +
geom_abline(
aes(intercept = a1, slope = a2, colour = -dist),
data = filter(models, rank(dist) <= 10)
)
用散点图表示最优10个模型和所有模型:
ggplot(models, aes(a1, a2)) +
geom_point(data = filter(models, rank(dist) <= 10), size = 4, colour = "red") +
geom_point(aes(colour = -dist))
使用均匀分布的二维阵列代替上面的随机数,重新绘制模型质量散点图:
grid <- expand.grid(
a1 = seq(-5, 20, length = 25),
a2 = seq(1, 3, length = 25)
) %>%
mutate(dist = purrr::map2_dbl(a1, a2, sim1_dist))
grid %>%
ggplot(aes(a1, a2)) +
geom_point(data = filter(grid, rank(dist) <= 10), size = 4, colour = "red") +
geom_point(aes(colour = -dist))
二维阵列保存在 grid
中,模型质量保存在 grid$dist
中。
用筛选出来的最好的10个模型绘制拟合图:
ggplot(sim1, aes(x, y)) +
geom_point(size = 2, colour = "grey30") +
geom_abline(
aes(intercept = a1, slope = a2, colour = -dist),
data = filter(grid, rank(dist) <= 10)
)
使用 Newton-Raphson 搜索(由 optim
函数实现)寻找 measure_distance
函数在 sim1
数据集上的最小值:
## [1] 4.222248 2.051204
#> [1] 4.22 2.05
ggplot(sim1, aes(x, y)) +
geom_point(size = 2, colour = "grey30") +
geom_abline(intercept = best$par[1], slope = best$par[2])
这样就完成了系统建模的最后一步:用模型拟合数据,找到最好的那一个。 但这里使用 Newton-Raphson 方法 可能 找到的是局部最优解, 为了确保找到全局最优解,使用 R 提供的 lm()
函数:
## (Intercept) x
## 4.220822 2.051533
全局最优解与使用 optim()
得到的最优解一致。
使用预测值-残差方法分析模型,首先用 data_grid
函数得到数据集的所有值(无重复):
## # A tibble: 10 x 1
## x
## <int>
## 1 1
## 2 2
## 3 3
## 4 4
## 5 5
## 6 6
## 7 7
## 8 8
## 9 9
## 10 10
用 add_predictions()
函数计算模型的预测值:
## # A tibble: 10 x 2
## x pred
## <int> <dbl>
## 1 1 6.27
## 2 2 8.32
## 3 3 10.4
## 4 4 12.4
## 5 5 14.5
## 6 6 16.5
## 7 7 18.6
## 8 8 20.6
## 9 9 22.7
## 10 10 24.7
用 grid$pred
绘制拟合曲线:
ggplot(sim1, aes(x)) +
geom_point(aes(y = y)) +
geom_line(aes(y = pred), data = grid, colour = "red", size = 1)
与上面使用 geom_abline()
绘制拟合曲线相比,这里的计算方法更通用,适用于所有模型。
我们知道模型是对数据的简化概括,所以它反映了数据集的某些特征,忽略了另一些特征, 这些被忽略的特征,就体现在残差里。
使用 add_residuals()
## # A tibble: 30 x 3
## x y resid
## <int> <dbl> <dbl>
## 1 1 4.20 -2.07
## 2 1 7.51 1.24
## 3 1 2.13 -4.15
## 4 2 8.99 0.665
## 5 2 10.2 1.92
## 6 2 11.3 2.97
## 7 3 7.36 -3.02
## 8 3 10.5 0.130
## 9 3 10.5 0.136
## 10 4 12.4 0.00763
## # … with 20 more rows
绘制残差的折线图:
## [1] 1.563194e-13
可以看到残差在 \(X\) 轴两侧出现频率覆盖的面积基本一致,表明残差的和为0, 使用 sum()
函数求残差和验证了上面的假设。
绘制残差分布图:
## # A tibble: 2 x 2
## `(Intercept)` x1
## <dbl> <dbl>
## 1 1 2
## 2 1 1
自制数据集的 model_matrix:
df <- tribble(
~ sex, ~ response,
"male", 1,
"female", 2,
"male", 1
)
model_matrix(df, response ~ sex)
## # A tibble: 3 x 2
## `(Intercept)` sexmale
## <dbl> <dbl>
## 1 1 1
## 2 1 0
## 3 1 1
参考 Design matrix.
sim2 数据集及其线性拟合:
## # A tibble: 40 x 2
## x y
## <chr> <dbl>
## 1 a 1.94
## 2 a 1.18
## 3 a 1.24
## 4 a 2.62
## 5 a 1.11
## 6 a 0.866
## 7 a -0.910
## 8 a 0.721
## 9 a 0.687
## 10 a 2.07
## # … with 30 more rows
## # A tibble: 4 x 2
## x pred
## <chr> <dbl>
## 1 a 1.15
## 2 b 8.12
## 3 c 6.13
## 4 d 1.91
拟合值是每个类别所有 \(Y\) 值的平均值(最小化了 root-mean-squared distance):
ggplot(sim2, aes(x)) +
geom_point(aes(y = y)) +
geom_point(data = grid, aes(y = pred), colour = "red", size = 4)
## # A tibble: 120 x 5
## x1 x2 rep y sd
## <int> <fct> <int> <dbl> <dbl>
## 1 1 a 1 -0.571 2
## 2 1 a 2 1.18 2
## 3 1 a 3 2.24 2
## 4 1 b 1 7.44 2
## 5 1 b 2 8.52 2
## 6 1 b 3 7.72 2
## 7 1 c 1 6.51 2
## 8 1 c 2 5.79 2
## 9 1 c 3 6.07 2
## 10 1 d 1 2.11 2
## # … with 110 more rows
两个模型族:
为两个特征添加预测值:
## # A tibble: 80 x 4
## model x1 x2 pred
## <chr> <int> <fct> <dbl>
## 1 mod1 1 a 1.67
## 2 mod1 1 b 4.56
## 3 mod1 1 c 6.48
## 4 mod1 1 d 4.03
## 5 mod1 2 a 1.48
## 6 mod1 2 b 4.37
## 7 mod1 2 c 6.28
## 8 mod1 2 d 3.84
## 9 mod1 3 a 1.28
## 10 mod1 3 b 4.17
## # … with 70 more rows
线性拟合模型:
ggplot(sim3, aes(x1, y, colour = x2)) +
geom_point() +
geom_line(data = grid, aes(y = pred)) +
facet_wrap(~ model)
对比两个模型的残差:
sim3 <- sim3 %>%
gather_residuals(mod1, mod2)
ggplot(sim3, aes(x1, resid, colour = x2)) +
geom_point() +
facet_grid(model ~ x2)
很明显 mod1 的 b, c, d 项的残差包含了某种模式,说明模型没有包含数据中的所有信息。 所以 mod2 比 mod1 更好。
基于 sim4 生成两个模型,以及各自的预测值:
## # A tibble: 300 x 4
## x1 x2 rep y
## <dbl> <dbl> <int> <dbl>
## 1 -1 -1 1 4.25
## 2 -1 -1 2 1.21
## 3 -1 -1 3 0.353
## 4 -1 -0.778 1 -0.0467
## 5 -1 -0.778 2 4.64
## 6 -1 -0.778 3 1.38
## 7 -1 -0.556 1 0.975
## 8 -1 -0.556 2 2.50
## 9 -1 -0.556 3 2.70
## 10 -1 -0.333 1 0.558
## # … with 290 more rows
mod1 <- lm(y ~ x1 + x2, data = sim4)
mod2 <- lm(y ~ x1 * x2, data = sim4)
grid <- sim4 %>%
data_grid(
x1 = seq_range(x1, 5),
x2 = seq_range(x2, 5)
) %>%
gather_predictions(mod1, mod2)
grid
## # A tibble: 50 x 4
## model x1 x2 pred
## <chr> <dbl> <dbl> <dbl>
## 1 mod1 -1 -1 0.996
## 2 mod1 -1 -0.5 -0.395
## 3 mod1 -1 0 -1.79
## 4 mod1 -1 0.5 -3.18
## 5 mod1 -1 1 -4.57
## 6 mod1 -0.5 -1 1.91
## 7 mod1 -0.5 -0.5 0.516
## 8 mod1 -0.5 0 -0.875
## 9 mod1 -0.5 0.5 -2.27
## 10 mod1 -0.5 1 -3.66
## # … with 40 more rows
绘制拟合图:
似乎差别不大,改为绘制等高线,x2
处于不同区间时 x1
的预测值趋势, 以及 x1
处于不同区间时 x2
的预测值趋势:
由于 mod1
是线性模型,所以各条线之间是平行的, mod2
由于增加了交互项,各条线之间不再平行。
这里为什么要用 group
参数? 如果去掉这个参数,效果如下:
可以看到,由于二维绘图无法展示包含两个特征的数据集, 如果不使用 group
参数,geom_line()
会把 x1
作为唯一自变量, 将所有点连在一起,这显然是不合理的(上面第一张图)。 第二张图展示了数据的实际形态,解释了第一张图形成的原因。
解决方法是使用 group
参数将隐藏的连续型特征 x2
转为类别变量(使用 binage 方法), 然后分组 (group) 绘制(上面第三张图)。 使用分组虽然能绘制 x2
各种情况下 x1
和 pred
之间的关系, 但多条线之间没有视觉上的区分,所以一般与 color
或者 shape
联合使用。
通过在线性模型上施加变换 (transformation),可以方便地将上面的建模技术扩展到非线性领域。
我们知道通过泰勒级数展开可以用多项式拟合任何连续函数,而多项式又是线性模型的一种, 下面是一个自然样条拟合的例子:
## # A tibble: 3 x 3
## `(Intercept)` `ns(x, 2)1` `ns(x, 2)2`
## <dbl> <dbl> <dbl>
## 1 1 0 0
## 2 1 0.566 -0.211
## 3 1 0.344 0.771
最后通过一个自然样条函数拟合三角函数的例子说明非线性建模技术,首先准备好原始数据:
sim5 <- tibble(
x = seq(0, 3.5 * pi, length = 50),
y = 4 * sin(x) + rnorm(length(x))
)
ggplot(sim5, aes(x, y)) +
geom_point()
用不同次数的模型拟合:
mod1 <- lm(y ~ ns(x, 1), data = sim5)
mod2 <- lm(y ~ ns(x, 2), data = sim5)
mod3 <- lm(y ~ ns(x, 3), data = sim5)
mod4 <- lm(y ~ ns(x, 4), data = sim5)
mod5 <- lm(y ~ ns(x, 5), data = sim5)
grid <- sim5 %>%
data_grid(x = seq_range(x, n = 50, expand = 0.1)) %>%
gather_predictions(mod1, mod2, mod3, mod4, mod5, .pred = "y")
ggplot(sim5, aes(x, y)) +
geom_point() +
geom_line(data = grid, colour = "red") +
facet_wrap(~ model)
空值不能传递任何变量间有价值的信息,所以 R 默认剔除数据中的空值。 如果需要遇到空值后提出警告,而不是直接删除,可以通过设置 options(na.action = na.warn)
实现, 本文第一节中设置了这一特征,效果如下所示:
## Warning: Dropping 2 rows with missing values
如果需要在某次建模时关闭警告,可以通过 na.action = na.exclude
实现:
然后用 nobs()
函数查看这个模型使用了多少有效的观测:
## [1] 3
以上以线性模型为例说明了建模过程,线性模型虽然应用十分广泛,但也不是唯一的选择, 下面列出了数据建模中常用的其他几种模型族以及 R 中常用的实现方法:
Generalised linear models: stats::glm()
Generalised additive models: mgcv::gam()
Penalised linear models: glmnet::glmnet()
Robust linear models: MASS:rlm()
Trees: rpart::rpart()
面对包含许多陌生信息的数据集,如何渐进地构造出满足业务要求的数据模型? 按照下面的流程,不断循环迭代,得到一个足够好的模型:
观察现有数据,提出假设;
根据已有假设,通过可视化方法构建初始模型;
从原始数据集中去掉模型可以解释的部分,得到残差;
将残差作为模型,重复前面两个步骤,直到最终残差符合终止条件。
下面通过两个实例说明上述方法在实际数据集上的应用过程。
##
## Attaching package: 'lubridate'
## The following object is masked from 'package:base':
##
## date
三张图提供了类似的信息:
切割工艺 (Fair)、颜色 (J) 和纯度 (I1) 最差的钻石反而价格最高。
为了找到这里“反常”价格的原因,先对数据集随机抽取50个样本:
## # A tibble: 50 x 5
## price carat cut color clarity
## <int> <dbl> <ord> <ord> <ord>
## 1 4472 1.01 Very Good D SI2
## 2 1752 0.61 Premium H VVS2
## 3 2723 0.77 Premium E SI2
## 4 953 0.41 Ideal I IF
## 5 5864 1.01 Premium G VS2
## 6 1025 0.4 Ideal E VS2
## 7 3936 1.06 Ideal G SI2
## 8 526 0.3 Premium E SI1
## 9 3322 0.72 Ideal D SI1
## 10 6788 1.65 Premium G I1
## # … with 40 more rows
通过观察我们发现,价格受到重量 (carat) 的影响比较大,而且似乎比其他因素的影响还要大, 所以当撇开重量谈价格,是没有任何意义的。
现在问题变成了,如果通过建模过程,将上面的假设转换为具体的模型,最终证实或者证否基于感觉得到的结论? 例如:如何量化各因素对价格的影响?
第1步:看一下重量和价格的关系:
二者之间似乎存在非线性关系,为了更好的揭示它们之间的关系, 不妨先去掉特别重的钻石,这类钻石只占总体的0.3%,且容易扭曲整体关系, 然后求二者的对数,看看效果如何:
diamonds2 <- diamonds %>%
filter(carat <= 2.5) %>%
mutate(lprice = log2(price), lcarat = log2(carat))
head(diamonds2)
## # A tibble: 6 x 12
## carat cut color clarity depth table price x y z lprice
## <dbl> <ord> <ord> <ord> <dbl> <dbl> <int> <dbl> <dbl> <dbl> <dbl>
## 1 0.23 Ideal E SI2 61.5 55 326 3.95 3.98 2.43 8.35
## 2 0.21 Prem… E SI1 59.8 61 326 3.89 3.84 2.31 8.35
## 3 0.23 Good E VS1 56.9 65 327 4.05 4.07 2.31 8.35
## 4 0.290 Prem… I VS2 62.4 58 334 4.2 4.23 2.63 8.38
## 5 0.31 Good J SI2 63.3 58 335 4.34 4.35 2.75 8.39
## 6 0.24 Very… J VVS2 62.8 57 336 3.94 3.96 2.48 8.39
## # … with 1 more variable: lcarat <dbl>
现在可以确定,重量确实与价格之间存在直接联系。
第2步:用线性模型体现二者之间的现有关系:
##
## Call:
## lm(formula = lprice ~ lcarat, data = diamonds2)
##
## Residuals:
## Min 1Q Median 3Q Max
## -1.96407 -0.24549 -0.00844 0.23930 1.93486
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 12.193863 0.001969 6194.5 <2e-16 ***
## lcarat 1.681371 0.001936 868.5 <2e-16 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 0.3767 on 53812 degrees of freedom
## Multiple R-squared: 0.9334, Adjusted R-squared: 0.9334
## F-statistic: 7.542e+05 on 1 and 53812 DF, p-value: < 2.2e-16
把这个模型绘制出来,并叠加到原始数据上:
## # A tibble: 6 x 12
## carat cut color clarity depth table price x y z lprice
## <dbl> <ord> <ord> <ord> <dbl> <dbl> <int> <dbl> <dbl> <dbl> <dbl>
## 1 0.23 Ideal E SI2 61.5 55 326 3.95 3.98 2.43 8.35
## 2 0.21 Prem… E SI1 59.8 61 326 3.89 3.84 2.31 8.35
## 3 0.23 Good E VS1 56.9 65 327 4.05 4.07 2.31 8.35
## 4 0.290 Prem… I VS2 62.4 58 334 4.2 4.23 2.63 8.38
## 5 0.31 Good J SI2 63.3 58 335 4.34 4.35 2.75 8.39
## 6 0.24 Very… J VVS2 62.8 57 336 3.94 3.96 2.48 8.39
## # … with 1 more variable: lcarat <dbl>
grid <- diamonds2 %>%
data_grid(carat = seq_range(carat, 20)) %>%
mutate(lcarat = log2(carat)) %>%
add_predictions(mod_diamond, "lprice") %>%
mutate(price = 2 ^ lprice)
head(grid, 10)
## # A tibble: 10 x 4
## carat lcarat lprice price
## <dbl> <dbl> <dbl> <dbl>
## 1 0.2 -2.32 8.29 313.
## 2 0.321 -1.64 9.44 694.
## 3 0.442 -1.18 10.2 1188.
## 4 0.563 -0.828 10.8 1784.
## 5 0.684 -0.547 11.3 2475.
## 6 0.805 -0.312 11.7 3255.
## 7 0.926 -0.110 12.0 4119.
## 8 1.05 0.0668 12.3 5064.
## 9 1.17 0.225 12.6 6087.
## 10 1.29 0.367 12.8 7184.
ggplot(diamonds2, aes(carat, price)) +
geom_hex(bins = 50) +
geom_line(data = grid, colour = "red", size = 1)
这里首先通过 seq_range(carat, 20)
将重量(carat 列)转换为一个长度为20的等间距向量, 它的最大、最小值等于原始 carat 向量的最大、最小值。 再通过 data_grid
函数生成一个长度为20的 tibble:
## # A tibble: 20 x 1
## carat
## <dbl>
## 1 0.2
## 2 0.321
## 3 0.442
## 4 0.563
## 5 0.684
## 6 0.805
## 7 0.926
## 8 1.05
## 9 1.17
## 10 1.29
## 11 1.41
## 12 1.53
## 13 1.65
## 14 1.77
## 15 1.89
## 16 2.02
## 17 2.14
## 18 2.26
## 19 2.38
## 20 2.5
然后基于这个 carat 生成新的 lcarat (重量的对数):
## # A tibble: 20 x 2
## carat lcarat
## <dbl> <dbl>
## 1 0.2 -2.32
## 2 0.321 -1.64
## 3 0.442 -1.18
## 4 0.563 -0.828
## 5 0.684 -0.547
## 6 0.805 -0.312
## 7 0.926 -0.110
## 8 1.05 0.0668
## 9 1.17 0.225
## 10 1.29 0.367
## 11 1.41 0.496
## 12 1.53 0.615
## 13 1.65 0.725
## 14 1.77 0.827
## 15 1.89 0.922
## 16 2.02 1.01
## 17 2.14 1.10
## 18 2.26 1.17
## 19 2.38 1.25
## 20 2.5 1.32
结合上面的线性模型 mod_diamond,将现有 lcarat 列对应的 lprice 添加进来:
diamonds2 %>%
data_grid(carat = seq_range(carat, 20)) %>%
mutate(lcarat = log2(carat)) %>%
add_predictions(mod_diamond, "lprice")
## # A tibble: 20 x 3
## carat lcarat lprice
## <dbl> <dbl> <dbl>
## 1 0.2 -2.32 8.29
## 2 0.321 -1.64 9.44
## 3 0.442 -1.18 10.2
## 4 0.563 -0.828 10.8
## 5 0.684 -0.547 11.3
## 6 0.805 -0.312 11.7
## 7 0.926 -0.110 12.0
## 8 1.05 0.0668 12.3
## 9 1.17 0.225 12.6
## 10 1.29 0.367 12.8
## 11 1.41 0.496 13.0
## 12 1.53 0.615 13.2
## 13 1.65 0.725 13.4
## 14 1.77 0.827 13.6
## 15 1.89 0.922 13.7
## 16 2.02 1.01 13.9
## 17 2.14 1.10 14.0
## 18 2.26 1.17 14.2
## 19 2.38 1.25 14.3
## 20 2.5 1.32 14.4
最后基于预测值 lprice 生成真实的预测价格 price。
也可以不用 data_grid
简化重量,在原数据集上生成价格预测值:
grid2 <- diamonds2 %>%
add_predictions(mod_diamond, "lprice") %>%
mutate(price = 2 ^ lprice)
ggplot(diamonds2, aes(carat, price)) +
geom_hex(bins = 50) +
geom_line(data = grid2, colour = "red", size = 1)
效果与简化版本完全一致,只是计算量比前者大了很多。
第3步:从现有数据中去除模型可以解释的部分:
diamonds2 <- diamonds2 %>%
add_residuals(mod_diamond, "lresid")
ggplot(diamonds2, aes(lcarat, lresid)) +
geom_hex(bins = 50)
第4步:将残差作为新的模型进行分析:
可以看到,去掉重量影响后,切割工艺、颜色和纯度与价格的关系正常了。
从上面的图形可以进一步量化各个因素对价格的影响, 这里 lresid
的含义是:重量以外的因素对价格造成影响的以2为底的对数,例如: 纯度为 VS2 钻石的 lresid
中位数接近于0,表明可以用 VS2 作为纯度评价标准, 也就是只考虑重量不考虑其他因素时,钻石的平均价格就是 VS2 钻石价格。
纯度为 I1 的钻石的 lresid
中位数接近于 -1,表明由于纯度不佳,相同重量下, I1 钻石的价格只有 VS2 基准钻石价格的二分之一(\(2^{-1}\))。 如果某钻石的 lresid
值为1,则说明高纯度使得其价格是相同重量 VS2 钻石的2倍(\(2^1\))。
现在把颜色、切割工艺和纯度也纳入模型:
加上原来的重量,现在共包含4个特征,为了通过图形展示这个模型:
## # A tibble: 5 x 5
## cut lcarat color clarity pred
## <ord> <dbl> <chr> <chr> <dbl>
## 1 Fair -0.515 G VS2 11.2
## 2 Good -0.515 G VS2 11.3
## 3 Very Good -0.515 G VS2 11.4
## 4 Premium -0.515 G VS2 11.4
## 5 Ideal -0.515 G VS2 11.4
这里 .model
的意思是如果模型 mod_diamond2
需要没有明确提供的特征,data_grid
自动填充一个 标准值, 对于数值型特征,取中位数,对于类别型特征,取最大成分(出现最多的那个类别)。
图示切割工艺和价格(对数化处理后)之间的关系:
去除所有4个特征的影响后,残差是这样的:
diamonds2 <- diamonds2 %>%
add_residuals(mod_diamond2, "lresid2")
ggplot(diamonds2, aes(lcarat, lresid2)) +
geom_hex(bins = 50)
主体接近白噪声,说明模型的解释程度令人满意,但少部分数据的 lcarat2
值超过了2, 意味着这些钻石的价格偏差是模型解释正常值的4倍(\(2^2\))。 对于这种情况,常用的方法是把它们筛选出来看一看:
diamonds2 %>%
filter(abs(lresid2) > 1) %>%
add_predictions(mod_diamond2) %>%
mutate(pred = round(2 ^ pred)) %>%
select(price, pred, carat:table, x:z) %>%
arrange(price)
## # A tibble: 16 x 11
## price pred carat cut color clarity depth table x y z
## <int> <dbl> <dbl> <ord> <ord> <ord> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 1013 264 0.25 Fair F SI2 54.4 64 4.3 4.23 2.32
## 2 1186 284 0.25 Premium G SI2 59 60 5.33 5.28 3.12
## 3 1186 284 0.25 Premium G SI2 58.8 60 5.33 5.28 3.12
## 4 1262 2644 1.03 Fair E I1 78.2 54 5.72 5.59 4.42
## 5 1415 639 0.35 Fair G VS2 65.9 54 5.57 5.53 3.66
## 6 1415 639 0.35 Fair G VS2 65.9 54 5.57 5.53 3.66
## 7 1715 576 0.32 Fair F VS2 59.6 60 4.42 4.34 2.61
## 8 1776 412 0.290 Fair F SI1 55.8 60 4.48 4.41 2.48
## 9 2160 314 0.34 Fair F I1 55.8 62 4.72 4.6 2.6
## 10 2366 774 0.3 Very Good D VVS2 60.6 58 4.33 4.35 2.63
## 11 3360 1373 0.51 Premium F SI1 62.7 62 5.09 4.96 3.15
## 12 3807 1540 0.61 Good F SI2 62.5 65 5.36 5.29 3.33
## 13 3920 1705 0.51 Fair F VVS2 65.4 60 4.98 4.9 3.23
## 14 4368 1705 0.51 Fair F VVS2 60.7 66 5.21 5.11 3.13
## 15 10011 4048 1.01 Fair D SI2 64.6 58 6.25 6.2 4.02
## 16 10470 23622 2.46 Premium E SI2 59.7 59 8.82 8.76 5.25
包括如下步骤:
选出价格偏差超过 \([\frac12, 2]\) 倍正常值的钻石;
添加模型预测价格列;
将对数价格转换为原始价格并取整;
只保留价格、预测价格、重量、切割工艺、颜色、纯度等特征;
按实际价格排序;
这些价格异常的钻石,既有被高估也有被低估的,如果我们的模型没有问题,就要检查数据是不是有问题, 如果数据也没有错误,赶紧买进那些物美价廉的钻石吧。
这个数据集记录了 2013 年纽约机场每次航班信息, 观察每天航班数量变化规律:
## # A tibble: 336,776 x 19
## year month day dep_time sched_dep_time dep_delay arr_time
## <int> <int> <int> <int> <int> <dbl> <int>
## 1 2013 1 1 517 515 2 830
## 2 2013 1 1 533 529 4 850
## 3 2013 1 1 542 540 2 923
## 4 2013 1 1 544 545 -1 1004
## 5 2013 1 1 554 600 -6 812
## 6 2013 1 1 554 558 -4 740
## 7 2013 1 1 555 600 -5 913
## 8 2013 1 1 557 600 -3 709
## 9 2013 1 1 557 600 -3 838
## 10 2013 1 1 558 600 -2 753
## # … with 336,766 more rows, and 12 more variables: sched_arr_time <int>,
## # arr_delay <dbl>, carrier <chr>, flight <int>, tailnum <chr>,
## # origin <chr>, dest <chr>, air_time <dbl>, distance <dbl>, hour <dbl>,
## # minute <dbl>, time_hour <dttm>
daily <- flights %>%
mutate(date = make_date(year, month, day)) %>%
group_by(date) %>%
summarise(n = n())
daily
## # A tibble: 365 x 2
## date n
## <date> <int>
## 1 2013-01-01 842
## 2 2013-01-02 943
## 3 2013-01-03 914
## 4 2013-01-04 915
## 5 2013-01-05 720
## 6 2013-01-06 832
## 7 2013-01-07 933
## 8 2013-01-08 899
## 9 2013-01-09 902
## 10 2013-01-10 932
## # … with 355 more rows
由于商务人士是航班顾客的主体,周末的航班数应该比较少, 为了证实这一点,将某天按在一星期中的位置(day of week)分类:
daily <- daily %>%
mutate(wday = wday(date, label = TRUE))
ggplot(daily, aes(wday, n)) +
geom_boxplot()
将线性预测结果叠加到原始数据上:
mod <- lm(n ~ wday, data = daily)
grid <- daily %>%
data_grid(wday) %>%
add_predictions(mod, "n")
ggplot(daily, aes(wday, n)) +
geom_boxplot() +
geom_point(data = grid, colour = "red", size = 4)
去掉周末影响后的偏差分布情况:
## # A tibble: 365 x 4
## date n wday resid
## <date> <int> <ord> <dbl>
## 1 2013-01-01 842 Tue -109.
## 2 2013-01-02 943 Wed -19.7
## 3 2013-01-03 914 Thu -51.8
## 4 2013-01-04 915 Fri -52.5
## 5 2013-01-05 720 Sat -24.6
## 6 2013-01-06 832 Sun -59.5
## 7 2013-01-07 933 Mon -41.8
## 8 2013-01-08 899 Tue -52.4
## 9 2013-01-09 902 Wed -60.7
## 10 2013-01-10 932 Thu -33.8
## # … with 355 more rows
按 day of week 绘制残差曲线:
可以看到如果不考虑特殊日期航班数特别低造成的异常值, 主要是周六的残差不太理想。
筛选出航班明显偏少的日子:
## # A tibble: 11 x 4
## date n wday resid
## <date> <int> <ord> <dbl>
## 1 2013-01-01 842 Tue -109.
## 2 2013-01-20 786 Sun -105.
## 3 2013-05-26 729 Sun -162.
## 4 2013-07-04 737 Thu -229.
## 5 2013-07-05 822 Fri -145.
## 6 2013-09-01 718 Sun -173.
## 7 2013-11-28 634 Thu -332.
## 8 2013-11-29 661 Fri -306.
## 9 2013-12-24 761 Tue -190.
## 10 2013-12-25 719 Wed -244.
## 11 2013-12-31 776 Tue -175.
不难看出主要是节假日期间航班数明显减少,这就给上述异常值了一个比较令人满意的解释。
下面通过拟合平滑曲线观察航班残差的长期变化规律:
daily %>%
ggplot(aes(date, resid)) +
geom_ref_line(h = 0) +
geom_line(colour = "grey50") +
geom_smooth(se = FALSE, span = 0.20)
## `geom_smooth()` using method = 'loess' and formula 'y ~ x'
总体来看,仅仅使用 day of week 预测航班数,存在1月和12月偏低,而5~9月偏高的问题, 我们需要找到更多因素来解释航班数的变化。
只绘制周六航班的变化情况:
daily %>%
filter(wday == "Sat") %>%
ggplot(aes(date, n)) +
geom_point() +
geom_line() +
scale_x_date(NULL, date_breaks = "1 month", date_labels = "%b")
周六航班的阶段性增长可能与季节有关, 将它写出函数的形式:
term <- function(date) {
cut(date,
breaks = ymd(20130101, 20130605, 20130825, 20140101),
labels = c("spring", "summer", "fall")
)
}
daily <- daily %>%
mutate(term = term(date))
daily %>%
filter(wday == "Sat") %>%
ggplot(aes(date, n, colour = term)) +
geom_point(alpha = 1/3) +
geom_line() +
scale_x_date(NULL, date_breaks = "1 month", date_labels = "%b")
以它为标准分析航班数随 day of week 的变化情况:
按不同的季节分类拟合:
mod1 <- lm(n ~ wday, data = daily)
mod2 <- lm(n ~ wday * term, data = daily)
daily %>%
gather_residuals(without_term = mod1, with_term = mod2) %>%
ggplot(aes(date, resid, colour = model)) +
geom_line(alpha = 0.75)
将季节和 day of week 综合考虑后,残差有减小趋势,但不明显。
将季节和 day of week 综合考虑下的预测航班数与实际航班数叠加展示:
## # A tibble: 21 x 3
## wday term n
## <ord> <fct> <dbl>
## 1 Sun spring 872.
## 2 Sun summer 924.
## 3 Sun fall 895
## 4 Mon spring 961.
## 5 Mon summer 994.
## 6 Mon fall 979.
## 7 Tue spring 940.
## 8 Tue summer 988
## 9 Tue fall 944.
## 10 Wed spring 952.
## # … with 11 more rows
ggplot(daily, aes(wday, n)) +
geom_boxplot() +
geom_point(data = grid, colour = "red") +
facet_wrap(~ term)
预测值类似于某一组合下的平均数,对照原始数据的 box plot 图不难发现,由于异常点的存在,扭曲了整体分布情况,平均值缺乏代表性, 下面我们用对异常值容忍度比较高的 MASS::rlm()
函数再来拟合一次:
mod3 <- MASS::rlm(n ~ wday * term, data = daily)
daily %>%
add_residuals(mod3, "resid") %>%
ggplot(aes(date, resid)) +
geom_hline(yintercept = 0, size = 2, colour = "white") +
geom_line()
与上面的 mod2
模型相比,本模型的残差更贴近于0值,说明此模型比较好的描述了季节和 day-of-week 对航班数的影响。 且由于某些未知因素影响,1、2月份的实际值比预测值偏低。
上面我们用线性模型结合领域知识分析了航班数的变化规律, 如果给模型更多的灵活度,可以直观地从数据中提取更多的规律, 下面我们用自然样条取代线性模型,看看有什么效果:
library(splines)
mod_ns <- MASS::rlm(n ~ wday * ns(date, 5), data = daily)
daily %>%
data_grid(wday, date = seq_range(date, n = 13)) %>%
add_predictions(mod_ns) %>%
ggplot(aes(date, pred, colour = wday)) +
geom_line() +
geom_point()
图中表达的信息与上面的线性模型一致:
工作日航班数显著多于周末航班数,表明航班主体是是商务飞行;
周日的航班显著多于周六,说明很多人需要为周一的工作而提前在周末赶赴工作地点;
从季节上看,秋冬季节的航班少于春夏季节的航班数量。
周六的航班数随季节起伏较大,一个原因是春秋两季学校假期让很多家庭选择坐飞机度假, 另一个原因是圣诞节和元旦周六坐飞机出行的人数很多;
机器学习使用数学工具分析数据集,选择算法拟合模型,最后得出结论,是数据科学家的工作; 数据分析则关注大规模数据的获取、清洗、计算(相当于增强型的 SQL)和展示,是数据工程师的工作。 打个不太恰当的比方,有点像 IT 领域的开发和运维,我们知道这两个领域的融合出现了 devops, 那么机器学习和数据分析是否也能融合在一起,让算法助力数据分析呢?
答案是可以:通过多模型方法实现。 本章介绍了通过多模型方法,结合各种机器学习算法分析大数据集的方法, 具体内容包括:
如何综合运用多个简单模型解释复杂数据集;
使用列表特征(list-column)技术将任何数据存储在 data frame 中,例如用一列保存一个线性模型;
使用 broom 包将模型转换为 tidy data,从而使用各种数据处理技术分析 tidy data;
## # A tibble: 1,704 x 6
## country continent year lifeExp pop gdpPercap
## <fct> <fct> <int> <dbl> <int> <dbl>
## 1 Afghanistan Asia 1952 28.8 8425333 779.
## 2 Afghanistan Asia 1957 30.3 9240934 821.
## 3 Afghanistan Asia 1962 32.0 10267083 853.
## 4 Afghanistan Asia 1967 34.0 11537966 836.
## 5 Afghanistan Asia 1972 36.1 13079460 740.
## 6 Afghanistan Asia 1977 38.4 14880372 786.
## 7 Afghanistan Asia 1982 39.9 12881816 978.
## 8 Afghanistan Asia 1987 40.8 13867957 852.
## 9 Afghanistan Asia 1992 41.7 16317921 649.
## 10 Afghanistan Asia 1997 41.8 22227415 635.
## # … with 1,694 more rows
为每个国家绘制一条 年——预期寿命 关系曲线:
以新西兰为例,使用前面介绍的 模型-残差 的方法分析年代和预期寿命间的关系:
nz <- filter(gapminder, country == "New Zealand")
nz %>%
ggplot(aes(year, lifeExp)) +
geom_line() +
ggtitle("Full data = ")
nz_mod <- lm(lifeExp ~ year, data = nz)
nz %>%
add_predictions(nz_mod) %>%
ggplot(aes(year, pred)) +
geom_line() +
ggtitle("Linear trend + ")
nz %>%
add_residuals(nz_mod) %>%
ggplot(aes(year, resid)) +
geom_hline(yintercept = 0, colour = "white", size = 3) +
geom_line() +
ggtitle("Remaining pattern")
效果不错,现在的问题是,如何为每个国家创建分析模型?
为每个国家创建分析模型,需要从总体数据集中按 country
特征拆分出不同的子数据集, tidyr::nest
函数是个合适的工具:
## # A tibble: 142 x 3
## # Groups: country, continent [710]
## country continent data
## <fct> <fct> <list>
## 1 Afghanistan Asia <tibble [12 × 4]>
## 2 Albania Europe <tibble [12 × 4]>
## 3 Algeria Africa <tibble [12 × 4]>
## 4 Angola Africa <tibble [12 × 4]>
## 5 Argentina Americas <tibble [12 × 4]>
## 6 Australia Oceania <tibble [12 × 4]>
## 7 Austria Europe <tibble [12 × 4]>
## 8 Bahrain Asia <tibble [12 × 4]>
## 9 Bangladesh Asia <tibble [12 × 4]>
## 10 Belgium Europe <tibble [12 × 4]>
## # … with 132 more rows
新增加的特征 data
的每一行都是一个完整的 data frame(更准确地说是 tibble), 例如我们要查看亚洲国家阿富汗的数据:
## # A tibble: 12 x 4
## year lifeExp pop gdpPercap
## <int> <dbl> <int> <dbl>
## 1 1952 28.8 8425333 779.
## 2 1957 30.3 9240934 821.
## 3 1962 32.0 10267083 853.
## 4 1967 34.0 11537966 836.
## 5 1972 36.1 13079460 740.
## 6 1977 38.4 14880372 786.
## 7 1982 39.9 12881816 978.
## 8 1987 40.8 13867957 852.
## 9 1992 41.7 16317921 649.
## 10 1997 41.8 22227415 635.
## 11 2002 42.1 25268405 727.
## 12 2007 43.8 31889923 975.
所以 nest()
的作用是将一个每行是一个观测(一个国家在某一年份的预期寿命)的 data frame 转换成了每行是一个 data frame(某个国家在所有年份中的预期寿命)的 data frame。
要为每个国家创建模型,首先将模型包装在一个函数里:
然后用 purrr::map()
将函数应用到每个列表元素上:
Data frame 最大的优点是能够将相关的信息放在一起, 如果能够将分析模型放到 by_country
中,就实现了将数据和模型整合到了一个 data frame里, 这正好可以通过 dplyr::mutate
函数实现:
## # A tibble: 142 x 4
## # Groups: country, continent [710]
## country continent data model
## <fct> <fct> <list> <list>
## 1 Afghanistan Asia <tibble [12 × 4]> <lm>
## 2 Albania Europe <tibble [12 × 4]> <lm>
## 3 Algeria Africa <tibble [12 × 4]> <lm>
## 4 Angola Africa <tibble [12 × 4]> <lm>
## 5 Argentina Americas <tibble [12 × 4]> <lm>
## 6 Australia Oceania <tibble [12 × 4]> <lm>
## 7 Austria Europe <tibble [12 × 4]> <lm>
## 8 Bahrain Asia <tibble [12 × 4]> <lm>
## 9 Bangladesh Asia <tibble [12 × 4]> <lm>
## 10 Belgium Europe <tibble [12 × 4]> <lm>
## # … with 132 more rows
这里 map
函数的 data
参数表示 by_country$data
。
这样就可以方便的对数据做筛选和排序了:
## # A tibble: 30 x 4
## # Groups: country, continent [710]
## country continent data model
## <fct> <fct> <list> <list>
## 1 Albania Europe <tibble [12 × 4]> <lm>
## 2 Austria Europe <tibble [12 × 4]> <lm>
## 3 Belgium Europe <tibble [12 × 4]> <lm>
## 4 Bosnia and Herzegovina Europe <tibble [12 × 4]> <lm>
## 5 Bulgaria Europe <tibble [12 × 4]> <lm>
## 6 Croatia Europe <tibble [12 × 4]> <lm>
## 7 Czech Republic Europe <tibble [12 × 4]> <lm>
## 8 Denmark Europe <tibble [12 × 4]> <lm>
## 9 Finland Europe <tibble [12 × 4]> <lm>
## 10 France Europe <tibble [12 × 4]> <lm>
## # … with 20 more rows
## # A tibble: 142 x 4
## # Groups: country, continent [710]
## country continent data model
## <fct> <fct> <list> <list>
## 1 Algeria Africa <tibble [12 × 4]> <lm>
## 2 Angola Africa <tibble [12 × 4]> <lm>
## 3 Benin Africa <tibble [12 × 4]> <lm>
## 4 Botswana Africa <tibble [12 × 4]> <lm>
## 5 Burkina Faso Africa <tibble [12 × 4]> <lm>
## 6 Burundi Africa <tibble [12 × 4]> <lm>
## 7 Cameroon Africa <tibble [12 × 4]> <lm>
## 8 Central African Republic Africa <tibble [12 × 4]> <lm>
## 9 Chad Africa <tibble [12 × 4]> <lm>
## 10 Comoros Africa <tibble [12 × 4]> <lm>
## # … with 132 more rows
为每个模型添加残差:
## # A tibble: 142 x 5
## # Groups: country, continent [710]
## country continent data model resids
## <fct> <fct> <list> <list> <list>
## 1 Afghanistan Asia <tibble [12 × 4]> <lm> <tibble [12 × 5]>
## 2 Albania Europe <tibble [12 × 4]> <lm> <tibble [12 × 5]>
## 3 Algeria Africa <tibble [12 × 4]> <lm> <tibble [12 × 5]>
## 4 Angola Africa <tibble [12 × 4]> <lm> <tibble [12 × 5]>
## 5 Argentina Americas <tibble [12 × 4]> <lm> <tibble [12 × 5]>
## 6 Australia Oceania <tibble [12 × 4]> <lm> <tibble [12 × 5]>
## 7 Austria Europe <tibble [12 × 4]> <lm> <tibble [12 × 5]>
## 8 Bahrain Asia <tibble [12 × 4]> <lm> <tibble [12 × 5]>
## 9 Bangladesh Asia <tibble [12 × 4]> <lm> <tibble [12 × 5]>
## 10 Belgium Europe <tibble [12 × 4]> <lm> <tibble [12 × 5]>
## # … with 132 more rows
要绘制每个模型的残差图,首先将嵌套数据集展开成普通数据集:
## # A tibble: 1,704 x 9
## # Groups: country, continent [710]
## country continent data model year lifeExp pop gdpPercap resid
## <fct> <fct> <list> <lis> <int> <dbl> <int> <dbl> <dbl>
## 1 Afghani… Asia <tibble… <lm> 1952 28.8 8.43e6 779. -1.11
## 2 Afghani… Asia <tibble… <lm> 1957 30.3 9.24e6 821. -0.952
## 3 Afghani… Asia <tibble… <lm> 1962 32.0 1.03e7 853. -0.664
## 4 Afghani… Asia <tibble… <lm> 1967 34.0 1.15e7 836. -0.0172
## 5 Afghani… Asia <tibble… <lm> 1972 36.1 1.31e7 740. 0.674
## 6 Afghani… Asia <tibble… <lm> 1977 38.4 1.49e7 786. 1.65
## 7 Afghani… Asia <tibble… <lm> 1982 39.9 1.29e7 978. 1.69
## 8 Afghani… Asia <tibble… <lm> 1987 40.8 1.39e7 852. 1.28
## 9 Afghani… Asia <tibble… <lm> 1992 41.7 1.63e7 649. 0.754
## 10 Afghani… Asia <tibble… <lm> 1997 41.8 2.22e7 635. -0.534
## # … with 1,694 more rows
为这个普通数据集绘制残差图:
resids %>%
ggplot(aes(year, resid)) +
geom_line(aes(group = country), alpha = 1 / 3) +
geom_smooth(se = FALSE)
## `geom_smooth()` using method = 'gam' and formula 'y ~ s(x, bs = "cs")'
按洲分组绘制残差图:
resids %>%
ggplot(aes(year, resid, group = country)) +
geom_line(alpha = 1 / 3) +
facet_wrap(~continent)
不难发现非洲的残差比较高,说明现有的线性模型并不能完美解释这个大洲的预期寿命变化趋势。
##
## Attaching package: 'broom'
## The following object is masked from 'package:modelr':
##
## bootstrap
## # A tibble: 1 x 11
## r.squared adj.r.squared sigma statistic p.value df logLik AIC BIC
## <dbl> <dbl> <dbl> <dbl> <dbl> <int> <dbl> <dbl> <dbl>
## 1 0.954 0.949 0.804 205. 5.41e-8 2 -13.3 32.6 34.1
## # … with 2 more variables: deviance <dbl>, df.residual <int>
采用 mutata() + unnest()
可以将上面的方法扩展到整个数据集上:
## # A tibble: 142 x 16
## # Groups: country, continent [710]
## country continent data model resids r.squared adj.r.squared sigma
## <fct> <fct> <lis> <lis> <list> <dbl> <dbl> <dbl>
## 1 Afghan… Asia <tib… <lm> <tibb… 0.948 0.942 1.22
## 2 Albania Europe <tib… <lm> <tibb… 0.911 0.902 1.98
## 3 Algeria Africa <tib… <lm> <tibb… 0.985 0.984 1.32
## 4 Angola Africa <tib… <lm> <tibb… 0.888 0.877 1.41
## 5 Argent… Americas <tib… <lm> <tibb… 0.996 0.995 0.292
## 6 Austra… Oceania <tib… <lm> <tibb… 0.980 0.978 0.621
## 7 Austria Europe <tib… <lm> <tibb… 0.992 0.991 0.407
## 8 Bahrain Asia <tib… <lm> <tibb… 0.967 0.963 1.64
## 9 Bangla… Asia <tib… <lm> <tibb… 0.989 0.988 0.977
## 10 Belgium Europe <tib… <lm> <tibb… 0.995 0.994 0.293
## # … with 132 more rows, and 8 more variables: statistic <dbl>,
## # p.value <dbl>, df <int>, logLik <dbl>, AIC <dbl>, BIC <dbl>,
## # deviance <dbl>, df.residual <int>
去掉其中的列表特征列:
glance <- by_country %>%
mutate(glance = map(model, broom::glance)) %>%
unnest(glance, .drop = TRUE)
## Warning: The `.drop` argument of `unnest()` is deprecated as of tidyr 1.0.0.
## All list-columns are now preserved.
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_warnings()` to see where this warning was generated.
## # A tibble: 142 x 16
## # Groups: country, continent [710]
## country continent data model resids r.squared adj.r.squared sigma
## <fct> <fct> <lis> <lis> <list> <dbl> <dbl> <dbl>
## 1 Afghan… Asia <tib… <lm> <tibb… 0.948 0.942 1.22
## 2 Albania Europe <tib… <lm> <tibb… 0.911 0.902 1.98
## 3 Algeria Africa <tib… <lm> <tibb… 0.985 0.984 1.32
## 4 Angola Africa <tib… <lm> <tibb… 0.888 0.877 1.41
## 5 Argent… Americas <tib… <lm> <tibb… 0.996 0.995 0.292
## 6 Austra… Oceania <tib… <lm> <tibb… 0.980 0.978 0.621
## 7 Austria Europe <tib… <lm> <tibb… 0.992 0.991 0.407
## 8 Bahrain Asia <tib… <lm> <tibb… 0.967 0.963 1.64
## 9 Bangla… Asia <tib… <lm> <tibb… 0.989 0.988 0.977
## 10 Belgium Europe <tib… <lm> <tibb… 0.995 0.994 0.293
## # … with 132 more rows, and 8 more variables: statistic <dbl>,
## # p.value <dbl>, df <int>, logLik <dbl>, AIC <dbl>, BIC <dbl>,
## # deviance <dbl>, df.residual <int>
对所有模型按质量排序:
## # A tibble: 142 x 16
## # Groups: country, continent [710]
## country continent data model resids r.squared adj.r.squared sigma
## <fct> <fct> <lis> <lis> <list> <dbl> <dbl> <dbl>
## 1 Rwanda Africa <tib… <lm> <tibb… 0.0172 -0.0811 6.56
## 2 Botswa… Africa <tib… <lm> <tibb… 0.0340 -0.0626 6.11
## 3 Zimbab… Africa <tib… <lm> <tibb… 0.0562 -0.0381 7.21
## 4 Zambia Africa <tib… <lm> <tibb… 0.0598 -0.0342 4.53
## 5 Swazil… Africa <tib… <lm> <tibb… 0.0682 -0.0250 6.64
## 6 Lesotho Africa <tib… <lm> <tibb… 0.0849 -0.00666 5.93
## 7 Cote d… Africa <tib… <lm> <tibb… 0.283 0.212 3.93
## 8 South … Africa <tib… <lm> <tibb… 0.312 0.244 4.74
## 9 Uganda Africa <tib… <lm> <tibb… 0.342 0.276 3.19
## 10 Congo,… Africa <tib… <lm> <tibb… 0.348 0.283 2.43
## # … with 132 more rows, and 8 more variables: statistic <dbl>,
## # p.value <dbl>, df <int>, logLik <dbl>, AIC <dbl>, BIC <dbl>,
## # deviance <dbl>, df.residual <int>
似乎所有模型质量差的国家都在非洲,用散点图验证一下:
重点研究模型质量最差(\(R^2 \lt 0.25\))的几个国家:
bad_fit <- filter(glance, r.squared < 0.25)
gapminder %>%
semi_join(bad_fit, by = "country") %>%
ggplot(aes(year, lifeExp, colour = country)) +
geom_line()
不难推测1994年卢旺达种族大屠杀和近几十年艾滋病肆虐可能是造成这些国家人口预期寿命反常下降的重要原因。
R data frame 对 列表特征支持不够好:
## x.1.3 x.3.5
## 1 1 3
## 2 2 4
## 3 3 5
通过 I()
可以实现列表特征,但打印效果不好:
## x y
## 1 1, 2, 3 1, 2
## 2 3, 4, 5 3, 4, 5
tibble
对列表特征的支持比较好,不会自动展开 list:
## # A tibble: 2 x 2
## x y
## <list> <chr>
## 1 <int [3]> 1, 2
## 2 <int [3]> 3, 4, 5
或者使用 tribble
也能达到相同的效果:
## # A tibble: 2 x 2
## x y
## <list> <chr>
## 1 <int [3]> 1, 2
## 2 <int [3]> 3, 4, 5
列表特征一般作为数据处理流程的中间结果,将相关的数据组织在一起,而不是作为 R 函数的直接处理对象。 基于列表特征的工作流程主要由以下3部分组成:
有3种方法创建列表特征列:
tidyr::nest()
:
mutata()
:
summarise
:
注意函数返回结果中,所有元素的类型应该是一致的,虽然数据集本身不会检查元素类型的一致性, 但为了后续使用映射函数不会出现异常,满足这一点是很有必要的。
nest()
函数nest()
可以通过两种方法生成特征列表,第一种是与 group_by
配合使用: 参考 嵌套数据集 中的 by_country
:
## # A tibble: 142 x 5
## # Groups: country, continent [710]
## country continent data model resids
## <fct> <fct> <list> <list> <list>
## 1 Afghanistan Asia <tibble [12 × 4]> <lm> <tibble [12 × 5]>
## 2 Albania Europe <tibble [12 × 4]> <lm> <tibble [12 × 5]>
## 3 Algeria Africa <tibble [12 × 4]> <lm> <tibble [12 × 5]>
## 4 Angola Africa <tibble [12 × 4]> <lm> <tibble [12 × 5]>
## 5 Argentina Americas <tibble [12 × 4]> <lm> <tibble [12 × 5]>
## 6 Australia Oceania <tibble [12 × 4]> <lm> <tibble [12 × 5]>
## 7 Austria Europe <tibble [12 × 4]> <lm> <tibble [12 × 5]>
## 8 Bahrain Asia <tibble [12 × 4]> <lm> <tibble [12 × 5]>
## 9 Bangladesh Asia <tibble [12 × 4]> <lm> <tibble [12 × 5]>
## 10 Belgium Europe <tibble [12 × 4]> <lm> <tibble [12 × 5]>
## # … with 132 more rows
它的元素包含出 group_by
index 之外的所有特征:
## # A tibble: 12 x 4
## year lifeExp pop gdpPercap
## <int> <dbl> <int> <dbl>
## 1 1952 28.8 8425333 779.
## 2 1957 30.3 9240934 821.
## 3 1962 32.0 10267083 853.
## 4 1967 34.0 11537966 836.
## 5 1972 36.1 13079460 740.
## 6 1977 38.4 14880372 786.
## 7 1982 39.9 12881816 978.
## 8 1987 40.8 13867957 852.
## 9 1992 41.7 16317921 649.
## 10 1997 41.8 22227415 635.
## 11 2002 42.1 25268405 727.
## 12 2007 43.8 31889923 975.
第二种方法是单独使用,将需要嵌套的特征作为参数:
## Warning: All elements of `...` must be named.
## Did you want `data = c(year, lifeExp, pop, gdpPercap)`?
## # A tibble: 12 x 4
## year lifeExp pop gdpPercap
## <int> <dbl> <int> <dbl>
## 1 1952 28.8 8425333 779.
## 2 1957 30.3 9240934 821.
## 3 1962 32.0 10267083 853.
## 4 1967 34.0 11537966 836.
## 5 1972 36.1 13079460 740.
## 6 1977 38.4 14880372 786.
## 7 1982 39.9 12881816 978.
## 8 1987 40.8 13867957 852.
## 9 1992 41.7 16317921 649.
## 10 1997 41.8 22227415 635.
## 11 2002 42.1 25268405 727.
## 12 2007 43.8 31889923 975.
可以看到要得到相同的结果,nest()
参数与前面方法 group_by()
的参数应该是互补的。
使用 mutate()
添加/修改特征时,如果返回的是一个向量(而非标量),就会生成向量特征:
## # A tibble: 2 x 2
## x1 x2
## <chr> <list>
## 1 a,b,c <chr [3]>
## 2 d,e,f,g <chr [4]>
使用 unnest()
展开向量特征,注意展开的方向是竖向的,也就是保持特征數不变,增加观测数:
## Warning: `cols` is now required.
## Please use `cols = c(x2)`
## # A tibble: 7 x 2
## x1 x2
## <chr> <chr>
## 1 a,b,c a
## 2 a,b,c b
## 3 a,b,c c
## 4 d,e,f,g d
## 5 d,e,f,g e
## 6 d,e,f,g f
## 7 d,e,f,g g
最后可是使用 purrr::invoke_map()
函数生成列表特征:
sim <- tribble(
~f, ~params,
"runif", list(min = -1, max = 1),
"rnorm", list(sd = 5),
"rpois", list(lambda = 10)
)
sim %>%
mutate(sims = invoke_map(f, params, n = 10))
## # A tibble: 3 x 3
## f params sims
## <chr> <list> <list>
## 1 runif <named list [2]> <dbl [10]>
## 2 rnorm <named list [1]> <dbl [10]>
## 3 rpois <named list [1]> <int [10]>
注意 sim$sims
不完全是类型一致的,包含了实数向量和整数向量, 但由于实数运算完全覆盖整数运算,所以这样处理是合理的。
summarise()
函数summarise()
函数的经典用法是针对每一个分组生成一个标量形式的汇总值, 例如要获得不同汽缸数各种车型的平均和最大燃油消耗率 (mpg):
## # A tibble: 3 x 2
## cyl mpg_mean
## <dbl> <dbl>
## 1 4 26.7
## 2 6 19.7
## 3 8 15.1
## # A tibble: 3 x 2
## cyl mpg_max
## <dbl> <dbl>
## 1 4 33.9
## 2 6 21.4
## 3 8 19.2
如果我们不仅关系燃油效率,还想知道它的分布情况呢? 使用 quantile()
函数是个好方法,但它返回的是一个向量,不能直接作为 summarise()
函数的参数。 要解决这个问题,向量特征是个不错的工具:
## [[1]]
## 0% 25% 50% 75% 100%
## 21.4 22.8 26.0 30.4 33.9
##
## [[2]]
## 0% 25% 50% 75% 100%
## 17.80 18.65 19.70 21.00 21.40
##
## [[3]]
## 0% 25% 50% 75% 100%
## 10.40 14.40 15.20 16.25 19.20
展开这个数据集:
## Warning: `cols` is now required.
## Please use `cols = c(q)`
## # A tibble: 15 x 2
## cyl q
## <dbl> <dbl>
## 1 4 21.4
## 2 4 22.8
## 3 4 26
## 4 4 30.4
## 5 4 33.9
## 6 6 17.8
## 7 6 18.6
## 8 6 19.7
## 9 6 21
## 10 6 21.4
## 11 8 10.4
## 12 8 14.4
## 13 8 15.2
## 14 8 16.2
## 15 8 19.2
注意只有分组依据 (cyl
) 和 向量特征 (q
)。 quantile()
函数默认采用四分位点,即 0%, 25%, 50%, 75% 和 100%, 但也可以指定分位点位置,然后展开:
probs <- c(0.01, 0.25, 0.5, 0.75, 0.99)
mtcars %>%
group_by(cyl) %>%
summarise(p = list(probs), q = list(quantile(mpg, probs))) %>%
unnest()
## Warning: `cols` is now required.
## Please use `cols = c(p, q)`
## # A tibble: 15 x 3
## cyl p q
## <dbl> <dbl> <dbl>
## 1 4 0.01 21.4
## 2 4 0.25 22.8
## 3 4 0.5 26
## 4 4 0.75 30.4
## 5 4 0.99 33.8
## 6 6 0.01 17.8
## 7 6 0.25 18.6
## 8 6 0.5 19.7
## 9 6 0.75 21
## 10 6 0.99 21.4
## 11 8 0.01 10.4
## 12 8 0.25 14.4
## 13 8 0.5 15.2
## 14 8 0.75 16.2
## 15 8 0.99 19.1
增加的 p
列指明了分位数,提升了数据集的可读性和可操作性。
普通 data frame 的结构相当于一个二维表格,特征名称作为一种 元数据, 不能直接作为普通数据使用,在某些情况下很不方便。 列表特征使我们突破了 data frame 的维数限制:data frame 只体现最高一维, 所有 \(n-1\) 维 打包 在列表特征的元素里。 由于打包隐藏了数据结构的某些特征(主要是向量长度), 使得包含向量特征的数据集比普通 data frame 具有更高的灵活性, 例如下面的 pack_data
数据集,由于每个特征包含长度不同的向量, 无法作为普通的 data frame 处理,通过打包过程变成了 data frame:
## # A tibble: 3 x 2
## name value
## <chr> <list>
## 1 a <int [5]>
## 2 b <int [2]>
## 3 c <int [12]>
而且可以将特征名称 a,b,c
作为函数参数参与计算, 例如使用 str_c()
函数将将特征名称和数组第一个元素连接在一起:
## # A tibble: 3 x 3
## name value smry
## <chr> <list> <chr>
## 1 a <int [5]> a: 1
## 2 b <int [2]> b: 3
## 3 c <int [12]> c: 8
处理完毕包含向量特征的数据集后,需要将结果收集到普通 data frame 中, 根据每个向量最终计算结果形式的不同,存在两种情况:
如果每个向量最终计算结果是一个标量,使用 mutate()
配合 map_lgl()
, map_int()
, map_dbl()
, map_chr()
等函数形成最终 data frame;
如果每个向量最终计算结果仍然是一个向量,使用 unnest()
函数通过重复行的方法得到最终的 data frame。
下面的代码演示了通过 map_chr()
和 map_int()
函数获取向量特征每个元素的类型和长度两个标量, 并分别保存到两个特征中的过程:
df <- tribble(
~x,
letters[1:5],
3:9,
runif(8)
)
df %>% mutate(
type = map_chr(x, typeof),
length = map_int(x, length)
)
## # A tibble: 3 x 3
## x type length
## <list> <chr> <int>
## 1 <chr [5]> character 5
## 2 <int [7]> integer 7
## 3 <dbl [8]> double 8
通过新生成的 类型 特征,可以方便地对多类型列表做按类型筛选。
map_*()
族函数不仅可以应用函数到特征上,例如上面的 map_chr(x, typeof)
, 还可以用于从数据集中取出特定的特征, 例如下面的代码演示了从 df
数据集中取特征 a
和 b
形成新的数据集的方法:
df <- tribble(
~raw,
list(a = 1, b = 2),
list(a = 2, c = 4, d = 5)
)
df %>% mutate(
x = map_dbl(raw, "a"),
y = map_dbl(raw, "b", .null = NA_real_)
)
## # A tibble: 2 x 3
## raw x y
## <list> <dbl> <dbl>
## 1 <named list [2]> 1 2
## 2 <named list [3]> 2 NA
unnest()
展开向量特征的方法是重复普通特征(不是向量特征的列), 每个向量特征的元素成为新的一个观测,例如下面的代码中, 第一个观测 x = 1, y = 1:4
被展开为4个观测:
## # A tibble: 2 x 2
## x y
## <int> <list>
## 1 1 <int [4]>
## 2 2 <dbl [1]>
## Warning: `cols` is now required.
## Please use `cols = c(y)`
## # A tibble: 5 x 2
## x y
## <int> <dbl>
## 1 1 1
## 2 1 2
## 3 1 3
## 4 1 4
## 5 2 1
## # A tibble: 5 x 2
## x y
## <int> <dbl>
## 1 1 1
## 2 1 2
## 3 1 3
## 4 1 4
## 5 2 1
如果没有参数指定要展开的列,unnest()
展开所有的向量特征列。
如果要展开多个向量特征列,要保证每个元素的长度是一样的,否则将导致展开失败, 例如下面的例子中观测 x=1
中,y 和 z 长度不一致:
## # A tibble: 2 x 3
## x y z
## <dbl> <list> <list>
## 1 1 <chr [1]> <int [2]>
## 2 2 <chr [2]> <dbl [1]>
如果长度一致就能展开成功:
## # A tibble: 2 x 3
## x y z
## <dbl> <list> <list>
## 1 1 <chr [2]> <int [2]>
## 2 2 <chr [1]> <dbl [1]>
## Warning: unnest() has a new interface. See ?unnest for details.
## Try `df %>% unnest(c(y, z))`, with `mutate()` if needed
## # A tibble: 3 x 3
## x y z
## <dbl> <chr> <dbl>
## 1 1 a 1
## 2 1 b 2
## 3 2 c 3
broom
包归整数据集broom
包主要提供了下列3种方法将包含向量特征的数据集转换为普通数据集:
glance()
方法见 模型质量评估。
下面的代码调用元函数 tidy()
将模型的计算结果转换为数据集,底层是调用了 tidy.lm()
:
## # A tibble: 2 x 5
## term estimate std.error statistic p.value
## <chr> <dbl> <dbl> <dbl> <dbl>
## 1 (Intercept) -308. 26.6 -11.6 0.000000417
## 2 year 0.193 0.0135 14.3 0.0000000541
下面的代码调用元函数 augment()
将模型参数与原有特征整合在一起,底层调用了 augment.lm()
:
## # A tibble: 12 x 9
## lifeExp year .fitted .se.fit .resid .hat .sigma .cooksd .std.resid
## <dbl> <int> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 69.4 1952 68.7 0.437 0.703 0.295 0.801 0.227 1.04
## 2 70.3 1957 69.7 0.381 0.609 0.225 0.816 0.107 0.860
## 3 71.2 1962 70.6 0.331 0.625 0.169 0.816 0.0738 0.852
## 4 71.5 1967 71.6 0.287 -0.0592 0.127 0.848 0.000452 -0.0788
## 5 71.9 1972 72.5 0.253 -0.653 0.0991 0.816 0.0403 -0.856
## 6 72.2 1977 73.5 0.235 -1.29 0.0851 0.719 0.130 -1.67
## 7 73.8 1982 74.5 0.235 -0.632 0.0851 0.819 0.0313 -0.821
## 8 74.3 1987 75.4 0.253 -1.12 0.0991 0.752 0.117 -1.46
## 9 76.3 1992 76.4 0.287 -0.0698 0.127 0.847 0.000627 -0.0928
## 10 77.6 1997 77.4 0.331 0.186 0.169 0.845 0.00655 0.254
## 11 79.1 2002 78.3 0.381 0.782 0.225 0.794 0.177 1.10
## 12 80.2 2007 79.3 0.437 0.912 0.295 0.767 0.381 1.35