tidymodels

tidymodels 是用于数据科学建模的一系列的包。安装方法install.packages("tidymodels") .tidymodels包含了一系列的R包,这些包括。

  1. rsample 为高效的数据拆分和重采样提供了基础设施
  2. parsnip 是一个整洁、统一的模型接口,可用于尝试一系列模型,而不会陷入底层包的语法细节
  3. recipes 是用于特征工程的数据预处理工具的整洁接口
  4. workflow 将预处理、建模和后处理捆绑在一起
  5. tune 可帮助您优化模型的超参数和预处理步骤
  6. yardstick 使用性能指标来衡量模型的有效性
  7. broom 将常见统计 R 对象中的信息转换为用户友好、可预测的格式。
  8. dials 创建和管理调整参数和参数网格

以上这些包在加载tidymodels 的时候会一起加载,还有很多其他有用的工具需要手动进行加载。

  1. infer是一个高级 API,用于 tidyverse 友好的统计推断。
  2. corrr包具有用于处理相关矩阵的整洁接口
  3. spatialsample包提供重采样函数和类,如 rsample,但专门用于空间数据
  4. tidypredict和modeldb可以将预测方程转换为不同的语言(例如 SQL)并在数据库中拟合一些模型。
  5. stacks包提供了用于堆叠集成建模的工具,可以整合来自许多模型的预测
  6. Finetune包通过更多方法扩展了tune包
  7. usemodels包创建模板并自动生成代码以适应和调整模型。
  8. tidyposterior软件包使用户能够使用重采样和贝叶斯方法在模型之间进行正式的统计比较。
  9. shinymodels可让您通过 Shiny 应用程序探索调整或重新采样结果
  10. hardhat是一个以开发人员为中心的包,可帮助初学者创建用于建模的高质量 R 包
  11. butcher 一些 R 对象在保存到磁盘时变得非常大。butcher包可以通过删除子组件来减小这些对象的大小。

使用tidymodels建模的流程

首先准备相关的包

library(tidymodels)  # for the parsnip package, along with the rest of tidymodels
## Warning: package 'tidymodels' was built under R version 4.1.2
## ── Attaching packages ────────────────────────────────────── tidymodels 0.2.0 ──
## ✓ broom        0.8.0     ✓ recipes      0.2.0
## ✓ dials        0.1.1     ✓ rsample      0.1.1
## ✓ dplyr        1.0.8     ✓ tibble       3.1.6
## ✓ ggplot2      3.3.5     ✓ tidyr        1.2.0
## ✓ infer        1.0.0     ✓ tune         0.2.0
## ✓ modeldata    0.1.1     ✓ workflows    0.2.6
## ✓ parsnip      0.2.1     ✓ workflowsets 0.2.1
## ✓ purrr        0.3.4     ✓ yardstick    0.0.9
## Warning: package 'broom' was built under R version 4.1.2
## Warning: package 'dials' was built under R version 4.1.2
## Warning: package 'scales' was built under R version 4.1.2
## Warning: package 'dplyr' was built under R version 4.1.2
## Warning: package 'parsnip' was built under R version 4.1.2
## Warning: package 'recipes' was built under R version 4.1.2
## Warning: package 'tidyr' was built under R version 4.1.2
## Warning: package 'tune' was built under R version 4.1.2
## Warning: package 'workflows' was built under R version 4.1.2
## Warning: package 'workflowsets' was built under R version 4.1.2
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## x purrr::discard() masks scales::discard()
## x dplyr::filter()  masks stats::filter()
## x dplyr::lag()     masks stats::lag()
## x recipes::step()  masks stats::step()
## • Learn how to get started at https://www.tidymodels.org/start/
# Helper packages
library(readr)       # for importing data
## Warning: package 'readr' was built under R version 4.1.2
## 
## Attaching package: 'readr'
## The following object is masked from 'package:yardstick':
## 
##     spec
## The following object is masked from 'package:scales':
## 
##     col_factor
library(broom.mixed) # for converting bayesian models to tidy tibbles
## Warning: package 'broom.mixed' was built under R version 4.1.2
library(dotwhisker)  # for visualizing regression results
## Registered S3 method overwritten by 'parameters':
##   method                         from      
##   format.parameters_distribution datawizard

在这个例子中,我们使用的使用的是海胆数据。数据描述了三种不同的喂养方式如何随着时间的推移影响海胆的大小。

urchins <-
  # Data were assembled for a tutorial 
  # at https://www.flutterbys.com.au/stats/tut/tut7.5a.html
  read_csv("https://tidymodels.org/start/models/urchins.csv") %>% 
  # Change the names to be a little more verbose
  setNames(c("food_regime", "initial_volume", "width")) %>% 
  # Factors are very helpful for modeling, so we convert one column
  mutate(food_regime = factor(food_regime, levels = c("Initial", "Low", "High")))
## Rows: 72 Columns: 3
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr (1): TREAT
## dbl (2): IV, SUTW
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.

urchins 数据集有三个变量,food_regime,initial_volume,width 。其中food_regime表示实验方案,有三种方案,每种方案24个样本;initial_volume 表示实验开始时以毫升为单位的大小;width表示实验结束时候的宽度。

summary(urchins)
##   food_regime initial_volume      width        
##  Initial:24   Min.   : 3.50   Min.   :0.01000  
##  Low    :24   1st Qu.:13.00   1st Qu.:0.05100  
##  High   :24   Median :18.00   Median :0.07100  
##               Mean   :20.88   Mean   :0.07237  
##               3rd Qu.:26.00   3rd Qu.:0.08450  
##               Max.   :47.50   Max.   :0.16300

对数据进行可视化,分组绘制initial_volume 和width 的散点图。

ggplot(urchins,
       aes(x = initial_volume, 
           y = width, 
           group = food_regime, 
           col = food_regime)) + 
  geom_point() + 
  geom_smooth(method = lm, se = FALSE) +
  scale_color_viridis_d(option = "plasma", end = .7) + theme_classic()
## `geom_smooth()` using formula 'y ~ x'

initial_volume 和width 存在正相关关系,但是不同组的相关性大小是不一样的。

这里尝试构建回归模型,使用parsnip包指定我们需要的模型。

linear_reg()
## Linear Regression Model Specification (regression)
## 
## Computational engine: lm

第二步是设置引擎,对于线性回归而言,默认的引擎是lm

linear_reg() %>% 
  set_engine("spark")
## Linear Regression Model Specification (regression)
## 
## Computational engine: spark

可选的引擎包括:

  1. lm
  2. brulee
  3. gee
  4. glm
  5. glmnet
  6. gls
  7. keras
  8. lme
  9. lmer
  10. spark
  11. stan
  12. stan_glmer

引擎制定了训练模型的方式。

然后将引擎进行保存。

lm_mod <- linear_reg()

接着就可以进行训练模型,代码如下所示。

lm_fit <- 
  lm_mod %>% 
  fit(width ~ initial_volume * food_regime, data = urchins)
lm_fit
## parsnip model object
## 
## 
## Call:
## stats::lm(formula = width ~ initial_volume * food_regime, data = data)
## 
## Coefficients:
##                    (Intercept)                  initial_volume  
##                      0.0331216                       0.0015546  
##                 food_regimeLow                 food_regimeHigh  
##                      0.0197824                       0.0214111  
##  initial_volume:food_regimeLow  initial_volume:food_regimeHigh  
##                     -0.0012594                       0.0005254

使用tidy函数查看模型结果。

tidy(lm_fit)
## # A tibble: 6 × 5
##   term                            estimate std.error statistic  p.value
##   <chr>                              <dbl>     <dbl>     <dbl>    <dbl>
## 1 (Intercept)                     0.0331    0.00962      3.44  0.00100 
## 2 initial_volume                  0.00155   0.000398     3.91  0.000222
## 3 food_regimeLow                  0.0198    0.0130       1.52  0.133   
## 4 food_regimeHigh                 0.0214    0.0145       1.47  0.145   
## 5 initial_volume:food_regimeLow  -0.00126   0.000510    -2.47  0.0162  
## 6 initial_volume:food_regimeHigh  0.000525  0.000702     0.748 0.457

使用如下代码可以查看训练数据集的预测结果。

lm_fit$fit$fitted.values
##          1          2          3          4          5          6          7 
## 0.03856287 0.04089483 0.04555875 0.04866803 0.05333195 0.05333195 0.05644123 
##          8          9         10         11         12         13         14 
## 0.05644123 0.05799587 0.05955051 0.06265979 0.06421443 0.06576907 0.06576907 
##         15         16         17         18         19         20         21 
## 0.07043299 0.07043299 0.07043299 0.07665155 0.07820619 0.08753404 0.08908868 
##         22         23         24         25         26         27         28 
## 0.09375260 0.09375260 0.10152580 0.05438034 0.05526615 0.05541379 0.05629960 
##         29         30         31         32         33         34         35 
## 0.05585670 0.05703778 0.05733305 0.05748068 0.05821886 0.05821886 0.05821886 
##         36         37         38         39         40         41         42 
## 0.05939994 0.05925231 0.05895704 0.06028575 0.06058103 0.06412427 0.06501009 
##         43         44         45         46         47         48         49 
## 0.06501009 0.06560063 0.06619117 0.07013299 0.07429306 0.07325304 0.07325304 
##         50         51         52         53         54         55         56 
## 0.07949315 0.08157318 0.08469324 0.08573325 0.08157318 0.07845313 0.08365322 
##         57         58         59         60         61         62         63 
## 0.08365322 0.08677327 0.08573325 0.09093334 0.09405339 0.09405339 0.09509341 
##         64         65         66         67         68         69         70 
## 0.09613343 0.10341355 0.11693378 0.10861364 0.11485374 0.13565409 0.06692934 
##         71         72 
## 0.06663407 0.05629960

模型构建好之后,可以对模型进行预测,代码如下所示。

new_points <- expand.grid(initial_volume = 20, 
                          food_regime = c("Initial", "Low", "High"))

conf_int_pred <- predict(lm_fit, 
                         new_data = new_points, 
                         type = "conf_int")
conf_int_pred
## # A tibble: 3 × 2
##   .pred_lower .pred_upper
##         <dbl>       <dbl>
## 1      0.0555      0.0729
## 2      0.0499      0.0678
## 3      0.0870      0.105

使用recipes 处理数据

首先准备需要使用到的包。

library(tidymodels)      # for the recipes package, along with the rest of tidymodels

# Helper packages
library(nycflights13)    # for flight data
library(skimr)           # for variable summaries
## Warning: package 'skimr' was built under R version 4.1.2

我们使用nycflights13 数据来预测飞机是否晚点超过 30 分钟。该数据集包含 2013 年从纽约市附近起飞的 325,819 次航班的信息。让我们首先加载数据并对变量进行一些更改:

set.seed(123)

flight_data <- 
  flights %>% 
  mutate(
    # Convert the arrival delay to a factor
    arr_delay = ifelse(arr_delay >= 30, "late", "on_time"),
    arr_delay = factor(arr_delay),
    # We will use the date (not date-time) in the recipe below
    date = lubridate::as_date(time_hour)
  ) %>% 
  # Include the weather data
  inner_join(weather, by = c("origin", "time_hour")) %>% 
  # Only retain the specific columns we will use
  select(dep_time, flight, origin, dest, air_time, distance, 
         carrier, date, arr_delay, time_hour) %>% 
  # Exclude missing data
  na.omit() %>% 
  # For creating models, it is better to have qualitative columns
  # encoded as factors (instead of character strings)
  mutate_if(is.character, as.factor)

我们简单查看数据集的情况

glimpse(flight_data)
## Rows: 325,819
## Columns: 10
## $ dep_time  <int> 517, 533, 542, 544, 554, 554, 555, 557, 557, 558, 558, 558, …
## $ flight    <int> 1545, 1714, 1141, 725, 461, 1696, 507, 5708, 79, 301, 49, 71…
## $ origin    <fct> EWR, LGA, JFK, JFK, LGA, EWR, EWR, LGA, JFK, LGA, JFK, JFK, …
## $ dest      <fct> IAH, IAH, MIA, BQN, ATL, ORD, FLL, IAD, MCO, ORD, PBI, TPA, …
## $ air_time  <dbl> 227, 227, 160, 183, 116, 150, 158, 53, 140, 138, 149, 158, 3…
## $ distance  <dbl> 1400, 1416, 1089, 1576, 762, 719, 1065, 229, 944, 733, 1028,…
## $ carrier   <fct> UA, UA, AA, B6, DL, UA, B6, EV, B6, AA, B6, B6, UA, UA, AA, …
## $ date      <date> 2013-01-01, 2013-01-01, 2013-01-01, 2013-01-01, 2013-01-01,…
## $ arr_delay <fct> on_time, on_time, late, on_time, on_time, on_time, on_time, …
## $ time_hour <dttm> 2013-01-01 05:00:00, 2013-01-01 05:00:00, 2013-01-01 05:00:…

我们查看有多少的飞机晚点30分钟。

flight_data %>% 
  count(arr_delay) %>% 
  mutate(prop = n/sum(n))
## # A tibble: 2 × 3
##   arr_delay      n  prop
##   <fct>      <int> <dbl>
## 1 late       52540 0.161
## 2 on_time   273279 0.839

16% 左右的航班会晚点30min。

数据拆分

使用rsample包进行数据的拆分。

set.seed(222)
# Put 3/4 of the data into the training set 
data_split <- initial_split(flight_data, prop = 3/4)

# Create data frames for the two sets:
train_data <- training(data_split)
test_data  <- testing(data_split)

创建配方

使用recipe函数创建配方,配方主要有两个参数,公示和数据

flights_rec <- 
  recipe(arr_delay ~ ., data = train_data) 

flights_rec
## Recipe
## 
## Inputs:
## 
##       role #variables
##    outcome          1
##  predictor          9

这里与我们构建模型时候的参数设置类似。

进一步我们可以设置role,属于role 的变量可以保留在数据中,但不包含在模型中。当模型拟合后,我们想要调查一些预测不佳的值时,这会很方便。

flights_rec <- 
  recipe(arr_delay ~ ., data = train_data) %>% 
  update_role(flight, time_hour, new_role = "ID") 

summary(flights_rec)
## # A tibble: 10 × 4
##    variable  type    role      source  
##    <chr>     <chr>   <chr>     <chr>   
##  1 dep_time  numeric predictor original
##  2 flight    numeric ID        original
##  3 origin    nominal predictor original
##  4 dest      nominal predictor original
##  5 air_time  numeric predictor original
##  6 distance  numeric predictor original
##  7 carrier   nominal predictor original
##  8 date      date    predictor original
##  9 time_hour date    ID        original
## 10 arr_delay nominal outcome   original

我们还可以进行更多的操作。

flights_rec <- flights_rec  %>% 
  step_date(date, features = c("dow", "month")) %>%               
  step_holiday(date, 
               holidays = timeDate::listHolidays("US"), 
               keep_original_cols = FALSE)
flights_rec
## Recipe
## 
## Inputs:
## 
##       role #variables
##         ID          2
##    outcome          1
##  predictor          7
## 
## Operations:
## 
## Date features from date
## Holiday features from date

首先,使用step_date(),我们创建了两个新的因子列,其中包含适当的星期几和月份。

使用step_holiday(),我们创建了一个二进制变量,指示当前日期是否为假日。的参数值timeDate::listHolidays(“US”)使用timeDate 包列出 17 个标准的美国假期。

使用keep_original_cols = FALSE,我们删除了原始date变量,因为我们不再希望它出现在模型中。许多创建新变量的配方步骤都有这个参数。

接下来我们将所有的分类变量转变成为哑变量。

flights_rec <- flights_rec %>% 
  step_dummy(all_nominal_predictors()) %>% 
  step_zv(all_predictors())

step_zv 函数是为了避免测试集中存在但是训练集中不存在的值。这个函数的主要作用是为了删除只有唯一值的列。

使用recipe拟合模型

我们尝试构建逻辑回归模型预测航班是否晚点。

lr_mod <- 
  logistic_reg() %>% 
  set_engine("glm")

使用模型工作流,它将模型和配方配对在一起

flights_wflow <- 
  workflow() %>% 
  add_model(lr_mod) %>% 
  add_recipe(flights_rec)

flights_wflow
## ══ Workflow ════════════════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: logistic_reg()
## 
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 4 Recipe Steps
## 
## • step_date()
## • step_holiday()
## • step_dummy()
## • step_zv()
## 
## ── Model ───────────────────────────────────────────────────────────────────────
## Logistic Regression Model Specification (classification)
## 
## Computational engine: glm

接下来进行拟合模型

flights_fit <- 
  flights_wflow %>% 
  fit(data = train_data)


flights_fit
## ══ Workflow [trained] ══════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: logistic_reg()
## 
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 4 Recipe Steps
## 
## • step_date()
## • step_holiday()
## • step_dummy()
## • step_zv()
## 
## ── Model ───────────────────────────────────────────────────────────────────────
## 
## Call:  stats::glm(formula = ..y ~ ., family = stats::binomial, data = data)
## 
## Coefficients:
##                  (Intercept)                      dep_time  
##                     7.276446                     -0.001664  
##                     air_time                      distance  
##                    -0.044014                      0.005071  
##          date_USChristmasDay            date_USColumbusDay  
##                     1.329336                      0.723927  
##     date_USCPulaskisBirthday  date_USDecorationMemorialDay  
##                     0.807165                      0.584694  
##           date_USElectionDay             date_USGoodFriday  
##                     0.947652                      1.246811  
##       date_USInaugurationDay        date_USIndependenceDay  
##                     0.228947                      2.119747  
##              date_USLaborDay       date_USLincolnsBirthday  
##                    -1.933737                      0.583750  
##           date_USMemorialDay        date_USMLKingsBirthday  
##                     1.519217                      0.428585  
##           date_USNewYearsDay          date_USPresidentsDay  
##                     0.204203                      0.483797  
##       date_USThanksgivingDay            date_USVeteransDay  
##                     0.152978                      0.717895  
##   date_USWashingtonsBirthday                    origin_JFK  
##                     0.043305                      0.107289  
##                   origin_LGA                      dest_ACK  
##                     0.010961                     -1.737954  
##                     dest_ALB                      dest_ANC  
##                    -1.679551                     -1.203995  
##                     dest_ATL                      dest_AUS  
##                    -1.607715                     -0.797304  
##                     dest_AVL                      dest_BDL  
##                    -1.335753                     -1.374705  
##                     dest_BGR                      dest_BHM  
##                    -1.268924                     -0.893688  
##                     dest_BNA                      dest_BOS  
##                    -1.276719                     -1.525764  
##                     dest_BQN                      dest_BTV  
##                    -1.524842                     -1.573955  
##                     dest_BUF                      dest_BUR  
##                    -1.474470                      0.128706  
##                     dest_BWI                      dest_BZN  
##                    -1.678767                     -2.082842  
##                     dest_CAE                      dest_CAK  
##                    -2.001807                     -1.620195  
##                     dest_CHO                      dest_CHS  
##                    -0.634494                     -1.480732  
##                     dest_CLE                      dest_CLT  
##                    -1.387610                     -1.586758  
## 
## ...
## and 116 more lines.

这个对象里面有最终的配方和拟合的模型对象。

可以使用辅助函数extract_fit_parsnip()和extract_recipe() 来抽取模型结果或者配方。我们查看模型结果

flights_fit %>% 
  extract_fit_parsnip() %>% 
  tidy()
## # A tibble: 157 × 5
##    term                         estimate std.error statistic  p.value
##    <chr>                           <dbl>     <dbl>     <dbl>    <dbl>
##  1 (Intercept)                   7.28    2.73           2.67 7.64e- 3
##  2 dep_time                     -0.00166 0.0000141   -118.   0       
##  3 air_time                     -0.0440  0.000563     -78.2  0       
##  4 distance                      0.00507 0.00150        3.38 7.32e- 4
##  5 date_USChristmasDay           1.33    0.177          7.49 6.93e-14
##  6 date_USColumbusDay            0.724   0.170          4.25 2.13e- 5
##  7 date_USCPulaskisBirthday      0.807   0.139          5.80 6.57e- 9
##  8 date_USDecorationMemorialDay  0.585   0.117          4.98 6.32e- 7
##  9 date_USElectionDay            0.948   0.190          4.98 6.25e- 7
## 10 date_USGoodFriday             1.25    0.167          7.45 9.40e-14
## # … with 147 more rows

进行预测

使用predict()或augment()进行预测

flights_aug <- 
  augment(flights_fit, test_data)
predict(flights_fit, test_data,type = "prob")
## # A tibble: 81,455 × 2
##    .pred_late .pred_on_time
##         <dbl>         <dbl>
##  1     0.0547         0.945
##  2     0.0515         0.949
##  3     0.0361         0.964
##  4     0.0386         0.961
##  5     0.0384         0.962
##  6     0.0249         0.975
##  7     0.0366         0.963
##  8     0.0191         0.981
##  9     0.0646         0.935
## 10     0.0687         0.931
## # … with 81,445 more rows

ROC

使用yardstick包

flights_aug %>% 
  roc_curve(truth = arr_delay, .pred_late) %>% 
  ggplot2::autoplot() +theme_classic()

同样,roc_auc()估计曲线下的面积

flights_aug %>% 
  roc_auc(truth = arr_delay, .pred_late)
## # A tibble: 1 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 roc_auc binary         0.764

评估模型

一旦我们训练了一个模型,我们就需要一种方法来衡量该模型对新数据的预测效果。我们来看如何基于重采样统计来表征模型性能。

首先准备相关的包。

library(tidymodels) # for the rsample package, along with the rest of tidymodels

# Helper packages
library(modeldata)  # for the cells data

我们这里使用的是细胞图像数据。

data(cells, package = "modeldata")
cells
## # A tibble: 2,019 × 58
##    case  class angle_ch_1 area_ch_1 avg_inten_ch_1 avg_inten_ch_2 avg_inten_ch_3
##    <fct> <fct>      <dbl>     <int>          <dbl>          <dbl>          <dbl>
##  1 Test  PS        143.         185           15.7           4.95           9.55
##  2 Train PS        134.         819           31.9         207.            69.9 
##  3 Train WS        107.         431           28.0         116.            63.9 
##  4 Train PS         69.2        298           19.5         102.            28.2 
##  5 Test  PS          2.89       285           24.3         112.            20.5 
##  6 Test  WS         40.7        172          326.          654.           129.  
##  7 Test  WS        174.         177          260.          596.           124.  
##  8 Test  PS        180.         251           18.3           5.73          17.2 
##  9 Test  WS         18.9        495           16.1          89.5           13.7 
## 10 Test  WS        153.         384           17.7          89.9           20.4 
## # … with 2,009 more rows, and 51 more variables: avg_inten_ch_4 <dbl>,
## #   convex_hull_area_ratio_ch_1 <dbl>, convex_hull_perim_ratio_ch_1 <dbl>,
## #   diff_inten_density_ch_1 <dbl>, diff_inten_density_ch_3 <dbl>,
## #   diff_inten_density_ch_4 <dbl>, entropy_inten_ch_1 <dbl>,
## #   entropy_inten_ch_3 <dbl>, entropy_inten_ch_4 <dbl>,
## #   eq_circ_diam_ch_1 <dbl>, eq_ellipse_lwr_ch_1 <dbl>,
## #   eq_ellipse_oblate_vol_ch_1 <dbl>, eq_ellipse_prolate_vol_ch_1 <dbl>, …

有 58 个变量。我们在这里感兴趣的主要结果变量称为class,class 中ps表示细胞图片分割不良,ws表示细胞图片分割良好。

cells %>% 
  count(class) %>% 
  mutate(prop = n/sum(n))
## # A tibble: 2 × 3
##   class     n  prop
##   <fct> <int> <dbl>
## 1 PS     1300 0.644
## 2 WS      719 0.356

数据中的标签是不平衡的。

数据拆分

set.seed(123)
cell_split <- initial_split(cells %>% select(-case), 
                            strata = class)

strata参数用于保证训练集和测试集的标签比例是一致的。

cell_train <- training(cell_split)
cell_test  <- testing(cell_split)

nrow(cell_train)
## [1] 1514
nrow(cell_train)/nrow(cells)
## [1] 0.7498762
cell_train %>% 
  count(class) %>% 
  mutate(prop = n/sum(n))
## # A tibble: 2 × 3
##   class     n  prop
##   <fct> <int> <dbl>
## 1 PS      975 0.644
## 2 WS      539 0.356
cell_test %>% 
  count(class) %>% 
  mutate(prop = n/sum(n))
## # A tibble: 2 × 3
##   class     n  prop
##   <fct> <int> <dbl>
## 1 PS      325 0.644
## 2 WS      180 0.356

建模

构建随机森林模型。

数据拆分好了之后进行建模。我们首先定义我们想要创建的模型:

rf_mod <- 
  rand_forest(trees = 1000) %>% 
  set_engine("ranger") %>% 
  set_mode("classification")

准备拟合模型

set.seed(234)
rf_fit <- 
  rf_mod %>% 
  fit(class ~ ., data = cell_train)
rf_fit
## parsnip model object
## 
## Ranger result
## 
## Call:
##  ranger::ranger(x = maybe_data_frame(x), y = y, num.trees = ~1000,      num.threads = 1, verbose = FALSE, seed = sample.int(10^5,          1), probability = TRUE) 
## 
## Type:                             Probability estimation 
## Number of trees:                  1000 
## Sample size:                      1514 
## Number of independent variables:  56 
## Mtry:                             7 
## Target node size:                 10 
## Variable importance mode:         none 
## Splitrule:                        gini 
## OOB prediction error (Brier s.):  0.1189338

评估模型

查看训练集的预测结果

rf_training_pred <- 
  predict(rf_fit, cell_train) %>% 
  bind_cols(predict(rf_fit, cell_train, type = "prob")) %>% 
  # Add the true outcome data back in
  bind_cols(cell_train %>% 
              select(class))

rf_training_pred
## # A tibble: 1,514 × 4
##    .pred_class .pred_PS .pred_WS class
##    <fct>          <dbl>    <dbl> <fct>
##  1 PS             0.739   0.261  PS   
##  2 PS             0.934   0.0659 PS   
##  3 PS             0.929   0.0713 PS   
##  4 PS             0.961   0.0391 PS   
##  5 PS             0.922   0.0781 PS   
##  6 PS             0.711   0.289  PS   
##  7 PS             0.989   0.0112 PS   
##  8 PS             0.802   0.198  PS   
##  9 PS             0.853   0.147  PS   
## 10 PS             0.741   0.259  PS   
## # … with 1,504 more rows

计算auc

rf_training_pred %>%                # training set predictions
  roc_auc(truth = class, .pred_PS)
## # A tibble: 1 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 roc_auc binary          1.00

可以看到预测结果非常的好,但是这只是训练集的结果,我们进一步查看训练集的结果。

rf_testing_pred <- 
  predict(rf_fit, cell_test) %>% 
  bind_cols(predict(rf_fit, cell_test, type = "prob")) %>% 
  bind_cols(cell_test %>% select(class))

rf_testing_pred
## # A tibble: 505 × 4
##    .pred_class .pred_PS .pred_WS class
##    <fct>          <dbl>    <dbl> <fct>
##  1 PS            0.886   0.114   PS   
##  2 PS            0.910   0.0903  PS   
##  3 WS            0.0864  0.914   WS   
##  4 PS            0.838   0.162   WS   
##  5 PS            0.729   0.271   PS   
##  6 WS            0.286   0.714   WS   
##  7 PS            0.939   0.0607  PS   
##  8 WS            0.0279  0.972   WS   
##  9 PS            0.993   0.00749 PS   
## 10 WS            0.305   0.695   WS   
## # … with 495 more rows

计算auc

rf_testing_pred %>%                   # test set predictions
  roc_auc(truth = class, .pred_PS)
## # A tibble: 1 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 roc_auc binary         0.891

可以看到,测试集的auc并不高,为了使得结果更加稳定,我们可以进行交叉验证。

交叉验证

第一步是使用 rsample 创建一个重采样对象。rsample 中实现了几种重采样方法;可以使用以下方法创建交叉验证折叠vfold_cv():

set.seed(345)
folds <- vfold_cv(cell_train, v = 10)
folds
## #  10-fold cross-validation 
## # A tibble: 10 × 2
##    splits             id    
##    <list>             <chr> 
##  1 <split [1362/152]> Fold01
##  2 <split [1362/152]> Fold02
##  3 <split [1362/152]> Fold03
##  4 <split [1362/152]> Fold04
##  5 <split [1363/151]> Fold05
##  6 <split [1363/151]> Fold06
##  7 <split [1363/151]> Fold07
##  8 <split [1363/151]> Fold08
##  9 <split [1363/151]> Fold09
## 10 <split [1363/151]> Fold10

我们使用workflow()将随机森林模型和公式捆绑在一起的

rf_wf <- 
  workflow() %>%
  add_model(rf_mod) %>%
  add_formula(class ~ .)

set.seed(456)
rf_fit_rs <- 
  rf_wf %>% 
  fit_resamples(folds)
rf_fit_rs
## # Resampling results
## # 10-fold cross-validation 
## # A tibble: 10 × 4
##    splits             id     .metrics         .notes          
##    <list>             <chr>  <list>           <list>          
##  1 <split [1362/152]> Fold01 <tibble [2 × 4]> <tibble [0 × 3]>
##  2 <split [1362/152]> Fold02 <tibble [2 × 4]> <tibble [0 × 3]>
##  3 <split [1362/152]> Fold03 <tibble [2 × 4]> <tibble [0 × 3]>
##  4 <split [1362/152]> Fold04 <tibble [2 × 4]> <tibble [0 × 3]>
##  5 <split [1363/151]> Fold05 <tibble [2 × 4]> <tibble [0 × 3]>
##  6 <split [1363/151]> Fold06 <tibble [2 × 4]> <tibble [0 × 3]>
##  7 <split [1363/151]> Fold07 <tibble [2 × 4]> <tibble [0 × 3]>
##  8 <split [1363/151]> Fold08 <tibble [2 × 4]> <tibble [0 × 3]>
##  9 <split [1363/151]> Fold09 <tibble [2 × 4]> <tibble [0 × 3]>
## 10 <split [1363/151]> Fold10 <tibble [2 × 4]> <tibble [0 × 3]>

提取指标信息

collect_metrics(rf_fit_rs)
## # A tibble: 2 × 6
##   .metric  .estimator  mean     n std_err .config             
##   <chr>    <chr>      <dbl> <int>   <dbl> <chr>               
## 1 accuracy binary     0.832    10 0.00952 Preprocessor1_Model1
## 2 roc_auc  binary     0.904    10 0.00610 Preprocessor1_Model1

计算准确度和auc

rf_testing_pred %>%                   # test set predictions
  roc_auc(truth = class, .pred_PS)
## # A tibble: 1 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 roc_auc binary         0.891
rf_testing_pred %>%                   # test set predictions
  accuracy(truth = class, .pred_class)
## # A tibble: 1 × 3
##   .metric  .estimator .estimate
##   <chr>    <chr>          <dbl>
## 1 accuracy binary         0.816

调整模型参数

为了进一步优化我们的模型,我们可以调整模型的参数。

library(tidymodels)  # for the tune package, along with the rest of tidymodels

# Helper packages
library(rpart.plot)  # for visualizing a decision tree
## Loading required package: rpart
## 
## Attaching package: 'rpart'
## The following object is masked from 'package:dials':
## 
##     prune
library(vip)         # for variable importance plots
## 
## Attaching package: 'vip'
## The following object is masked from 'package:utils':
## 
##     vi

创建了一个模型规范来确定我们计划调整哪些超参数。

tune_spec <- 
  decision_tree(
    cost_complexity = tune(),
    tree_depth = tune()
  ) %>% 
  set_engine("rpart") %>% 
  set_mode("classification")

tune_spec
## Decision Tree Model Specification (classification)
## 
## Main Arguments:
##   cost_complexity = tune()
##   tree_depth = tune()
## 
## Computational engine: rpart

将tune()此处视为占位符。然后,创建一个规则的值网格

tree_grid <- grid_regular(cost_complexity(),
                          tree_depth(),
                          levels = 5)

tree_grid
## # A tibble: 25 × 2
##    cost_complexity tree_depth
##              <dbl>      <int>
##  1    0.0000000001          1
##  2    0.0000000178          1
##  3    0.00000316            1
##  4    0.000562              1
##  5    0.1                   1
##  6    0.0000000001          4
##  7    0.0000000178          4
##  8    0.00000316            4
##  9    0.000562              4
## 10    0.1                   4
## # … with 15 more rows

函数grid_regular()来自dials包,这个阐述可以创建不同组合的超参数。

接着创建的重采样对象。

set.seed(234)
cell_folds <- vfold_cv(cell_train)

调整参数

tune_grid()我们为每个调整过的超参数选择的所有不同值来拟合模型。

set.seed(345)

tree_wf <- workflow() %>%
  add_model(tune_spec) %>%
  add_formula(class ~ .)

tree_res <- 
  tree_wf %>% 
  tune_grid(
    resamples = cell_folds,
    grid = tree_grid
    )

tree_res
## # Tuning results
## # 10-fold cross-validation 
## # A tibble: 10 × 4
##    splits             id     .metrics          .notes          
##    <list>             <chr>  <list>            <list>          
##  1 <split [1362/152]> Fold01 <tibble [50 × 6]> <tibble [0 × 3]>
##  2 <split [1362/152]> Fold02 <tibble [50 × 6]> <tibble [0 × 3]>
##  3 <split [1362/152]> Fold03 <tibble [50 × 6]> <tibble [0 × 3]>
##  4 <split [1362/152]> Fold04 <tibble [50 × 6]> <tibble [0 × 3]>
##  5 <split [1363/151]> Fold05 <tibble [50 × 6]> <tibble [0 × 3]>
##  6 <split [1363/151]> Fold06 <tibble [50 × 6]> <tibble [0 × 3]>
##  7 <split [1363/151]> Fold07 <tibble [50 × 6]> <tibble [0 × 3]>
##  8 <split [1363/151]> Fold08 <tibble [50 × 6]> <tibble [0 × 3]>
##  9 <split [1363/151]> Fold09 <tibble [50 × 6]> <tibble [0 × 3]>
## 10 <split [1363/151]> Fold10 <tibble [50 × 6]> <tibble [0 × 3]>

使用collect_metrics()函数查看不同模型的参数结果。

tree_res %>% 
  collect_metrics()
## # A tibble: 50 × 8
##    cost_complexity tree_depth .metric  .estimator  mean     n std_err .config   
##              <dbl>      <int> <chr>    <chr>      <dbl> <int>   <dbl> <chr>     
##  1    0.0000000001          1 accuracy binary     0.732    10  0.0148 Preproces…
##  2    0.0000000001          1 roc_auc  binary     0.777    10  0.0107 Preproces…
##  3    0.0000000178          1 accuracy binary     0.732    10  0.0148 Preproces…
##  4    0.0000000178          1 roc_auc  binary     0.777    10  0.0107 Preproces…
##  5    0.00000316            1 accuracy binary     0.732    10  0.0148 Preproces…
##  6    0.00000316            1 roc_auc  binary     0.777    10  0.0107 Preproces…
##  7    0.000562              1 accuracy binary     0.732    10  0.0148 Preproces…
##  8    0.000562              1 roc_auc  binary     0.777    10  0.0107 Preproces…
##  9    0.1                   1 accuracy binary     0.732    10  0.0148 Preproces…
## 10    0.1                   1 roc_auc  binary     0.777    10  0.0107 Preproces…
## # … with 40 more rows

使用show_best函数可以显示前五个候选模型。

tree_res %>%
  show_best("accuracy")
## # A tibble: 5 × 8
##   cost_complexity tree_depth .metric  .estimator  mean     n std_err .config    
##             <dbl>      <int> <chr>    <chr>      <dbl> <int>   <dbl> <chr>      
## 1    0.0000000001          4 accuracy binary     0.807    10  0.0119 Preprocess…
## 2    0.0000000178          4 accuracy binary     0.807    10  0.0119 Preprocess…
## 3    0.00000316            4 accuracy binary     0.807    10  0.0119 Preprocess…
## 4    0.000562              4 accuracy binary     0.807    10  0.0119 Preprocess…
## 5    0.1                   4 accuracy binary     0.786    10  0.0124 Preprocess…

使用select_best()函数为我们的最佳决策树模型提取一组超参数值。

best_tree <- tree_res %>%
  select_best("accuracy")

best_tree
## # A tibble: 1 × 3
##   cost_complexity tree_depth .config              
##             <dbl>      <int> <chr>                
## 1    0.0000000001          4 Preprocessor1_Model06

选择最优参数构建模型,代码如下所示。

final_wf <- 
  tree_wf %>% 
  finalize_workflow(best_tree)

final_wf
## ══ Workflow ════════════════════════════════════════════════════════════════════
## Preprocessor: Formula
## Model: decision_tree()
## 
## ── Preprocessor ────────────────────────────────────────────────────────────────
## class ~ .
## 
## ── Model ───────────────────────────────────────────────────────────────────────
## Decision Tree Model Specification (classification)
## 
## Main Arguments:
##   cost_complexity = 1e-10
##   tree_depth = 4
## 
## Computational engine: rpart
final_fit <- 
  final_wf %>%
  last_fit(cell_split) 

final_fit %>%
  collect_metrics()
## # A tibble: 2 × 4
##   .metric  .estimator .estimate .config             
##   <chr>    <chr>          <dbl> <chr>               
## 1 accuracy binary         0.802 Preprocessor1_Model1
## 2 roc_auc  binary         0.840 Preprocessor1_Model1
final_fit %>%
  collect_predictions() %>% 
  roc_curve(class, .pred_PS) %>% 
  autoplot()

final_fit对象包含一个最终的、适合的工作流程,可以使用它来预测新数据或进一步了解结果。

final_tree <- extract_workflow(final_fit)
final_tree
## ══ Workflow [trained] ══════════════════════════════════════════════════════════
## Preprocessor: Formula
## Model: decision_tree()
## 
## ── Preprocessor ────────────────────────────────────────────────────────────────
## class ~ .
## 
## ── Model ───────────────────────────────────────────────────────────────────────
## n= 1514 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
##  1) root 1514 539 PS (0.64398943 0.35601057)  
##    2) total_inten_ch_2< 41732.5 642  33 PS (0.94859813 0.05140187)  
##      4) shape_p_2_a_ch_1>=1.251801 631  27 PS (0.95721078 0.04278922) *
##      5) shape_p_2_a_ch_1< 1.251801 11   5 WS (0.45454545 0.54545455) *
##    3) total_inten_ch_2>=41732.5 872 366 WS (0.41972477 0.58027523)  
##      6) fiber_width_ch_1< 11.37318 406 160 PS (0.60591133 0.39408867)  
##       12) avg_inten_ch_1< 145.4883 293  85 PS (0.70989761 0.29010239) *
##       13) avg_inten_ch_1>=145.4883 113  38 WS (0.33628319 0.66371681)  
##         26) total_inten_ch_3>=57919.5 33  10 PS (0.69696970 0.30303030) *
##         27) total_inten_ch_3< 57919.5 80  15 WS (0.18750000 0.81250000) *
##      7) fiber_width_ch_1>=11.37318 466 120 WS (0.25751073 0.74248927)  
##       14) eq_ellipse_oblate_vol_ch_1>=1673.942 30   8 PS (0.73333333 0.26666667)  
##         28) var_inten_ch_3>=41.10858 20   2 PS (0.90000000 0.10000000) *
##         29) var_inten_ch_3< 41.10858 10   4 WS (0.40000000 0.60000000) *
##       15) eq_ellipse_oblate_vol_ch_1< 1673.942 436  98 WS (0.22477064 0.77522936) *

对模型结果进行数据可视化。

final_tree %>%
  extract_fit_engine() %>%
  rpart.plot(roundint = FALSE)

分析变量的重要性。

library(vip)

final_tree %>% 
  extract_fit_parsnip() %>% 
  vip()

rsample

rsample 是与抽样相关的一个包。

重采样方法

  1. initial_split() initial_time_split() training() testing()简单的训练/测试集拆分
  2. bootstraps() 自举抽样
  3. vfold_cv() V 折交叉验证
  4. loo_cv() 留一法交叉验证
  5. mc_cv() 蒙特卡洛交叉验证
  6. validation_split() validation_time_split() 创建验证集
  7. group_vfold_cv() 组 V 折交叉验证
  8. rolling_origin() 滚动原点预测重采样
  9. sliding_window() sliding_index() sliding_period() 基于时间的重采样
  10. sliding_window() sliding_index() sliding_period() 基于时间的重采样
  11. sliding_window() sliding_index() sliding_period() 基于时间的重采样
  12. nested_cv() 嵌套或双重重采样
  13. apparent() 表观错误率抽样
  14. manual_rset() 手动重采样
  15. permutations() 置换抽样

分析相关函数

  1. int_pctl() int_t() int_bca() 自举置信区间
  2. reg_intervals() 具有线性参数模型的置信区间的便利函数

实用函数

1.get_fingerprint() 获取重采样的标识符 2. as.data.frame() analysis() assessment() 将rsplit对象转换为数据框 3. add_resample_id() 使用重采样标识符扩充数据集 4. complement() 确定评估样本 5. form_pred() 从公式或术语中提取预测变量名称 6. labels() labels() 从 rset 对象中查找标签 7. labels() 从 rsplit 对象中查找标签 8. make_splits() 拆分对象的构造函数 9. make_strata() 创建或修改分层变量 10. populate() 添加评估指数 11. rsample2caret() caret2rsample() 将重采样对象转换为其他格式 12. rset_reconstruct() 使用新的 rset 子类扩展 rsample 13. tidy() tidy() tidy() tidy() 整洁的重采样对象

parsnip

parsnip 的目标是为模型提供一个整洁、统一的接口,可用于尝试一系列模型,而不会陷入底层包的语法细节。

在R中有很多的包实现了相同的算法,例如构建随机森林模型。

# From randomForest
rf_1 <- randomForest(
  y ~ ., 
  data = ., 
  mtry = 10, 
  ntree = 2000, 
  importance = TRUE
)

# From ranger
rf_2 <- ranger(
  y ~ ., 
  data = dat, 
  mtry = 10, 
  num.trees = 2000, 
  importance = "impurity"
)

# From sparklyr
rf_3 <- ml_random_forest(
  dat, 
  intercept = FALSE, 
  response = "y", 
  features = names(dat)[names(dat) != "y"], 
  col.sample.rate = 10,
  num.trees = 2000
)

可以看到,不同包的模型语法是不一样的。而使用parsnip包,则可以用统一的结构构建随机森林模型。


# reanger包
library(parsnip)

rand_forest(mtry = 10, trees = 2000) %>%
  set_engine("ranger", importance = "impurity") %>%
  set_mode("regression")
  
# spark  
rand_forest(mtry = 10, trees = 2000) %>%
  set_engine("spark") %>%
  set_mode("regression")

模型函数

  1. bag_mars() MARS 模型的集合
  2. bag_tree() 决策树的集合
  3. bart() 贝叶斯加性回归树 (BART)
  4. boost_tree() 增强树
  5. cubist_rules() Cubist rule-based回归模型
  6. C5_rules() C5.0 基于规则的分类模型
  7. decision_tree() 决策树
  8. discrim_flexible() 灵活的判别分析
  9. discrim_linear() 线性判别分析
  10. discrim_quad() 二次判别分析
  11. discrim_regularized() 正则化判别分析
  12. gen_additive_mod() 广义加法模型 (GAM)
  13. linear_reg() 线性回归
  14. logistic_reg() 逻辑回归
  15. mars() 多元自适应回归样条 (MARS)
  16. mlp() 单层神经网络
  17. multinom_reg() 多项回归
  18. naive_Bayes() 朴素贝叶斯模型
  19. nearest_neighbor() K-最近邻
  20. null_model() 空模型
  21. pls() 偏最小二乘法 (PLS)
  22. poisson_reg() 泊松回归模型
  23. proportional_hazards() 比例风险回归
  24. rand_forest() 随机森林
  25. rule_fit() RuleFit 模型
  26. survival_reg() 生存回归
  27. svm_linear() 线性支持向量机
  28. svm_poly() 多项式支持向量机
  29. svm_rbf() 径向基函数支持向量机

查看帮助则可以查看可以设置哪些engine

基础函数

  1. autoplot() autoplot() 为模型对象创建 ggplot
  2. add_rowindex() 将一列行号添加到数据框中
  3. augment() 使用预测增强数据
  4. .cols() .preds() .obs() .lvls() .facts() .x() .y() .dat() 拟合模型时可用的数据集特征
  5. extract_spec_parsnip() extract_fit_engine() extract_parameter_set_dials() extract_parameter_dials() 提取parsnip模型对象的元素
  6. fit() fit_xy() 将模型规范拟合到数据集
  7. reexports 从其他包导出的对象
  8. control_parsnip() 控制拟合函数
  9. glance() 构建模型、拟合或其他对象的单行摘要“概览”
  10. model_fit 模型拟合对象信息
  11. model_spec 型号规格信息
  12. multi_predict 跨多个子模型的模型预测
  13. parsnip_addin() 启动一个可以编写模型规范的 RStudio Addin
  14. predict() predict_raw() 模型预测
  15. repair_call() 修复模型调用对象
  16. set_args() set_mode() 更改模型规范的元素
  17. set_engine() 声明计算引擎和具体参数
  18. show_engines() 显示模型当前可用的引擎
  19. tidy() 将parsnip模型对象变成整洁的 tibble

recipes

使用recipes ,可以创建一系列的数据处理步骤。

library(recipes)
data(ad_data, package = "modeldata")

ad_rec <- recipe(Class ~ tau + VEGF, data = ad_data) %>%
  step_normalize(all_numeric_predictors())

基本功能

  1. recipe() 创建用于预处理数据的配方
  2. formula() 从准备好的配方创建公式
  3. print() 打印recipes
  4. summary() summary recipes
  5. prep() 估计预处理recipes
  6. bake() 应用经过训练的预处理recipes
  7. juice() 提取转换后的训练集
  8. selections 在step函数中选择变量的方法
  9. has_role() all_predictors() all_numeric_predictors() all_nominal_predictors() all_outcomes() has_type() all_numeric() all_nominal() current_info() 角色选择函数
  10. add_role() update_role() remove_role() 手动更改角色

step函数 - 插补

1.step_impute_bag() step_bagimpute() imp_vars() 通过袋装树估算 2. step_impute_knn() step_knnimpute() 通过 k-最近邻进行估算 3. step_impute_linear() 通过线性模型估算数值变量 4. step_impute_lower() step_lowerimpute() 估算低于测量阈值的数值数据 5. step_impute_mean() step_meanimpute() 使用平均值估算数值数据 6. step_impute_median() step_medianimpute() 使用中位数估算数值数据 7. step_impute_mode() step_modeimpute() 使用最常见的值估算名义数据 8. step_impute_roll() step_rollimpute() 使用滚动窗口统计量估算数值数据 9. step_unknown() 将缺失的类别分配给“未知”

Step Functions -单变量转换

  1. step_BoxCox() 非负数据的 Box-Cox 变换
  2. step_bs() B样条基函数
  3. step_harmonic() 为谐波分析添加 sin 和 cos 项
  4. step_hyperbolic() 双曲变换
  5. step_inverse() 逆变换
  6. step_invlogit() 逆 Logit 变换
  7. step_log() 对数变换
  8. step_logit() logit变换
  9. step_mutate() 使用 dplyr 添加新变量
  10. step_ns() 自然样条基函数
  11. step_poly() 正交多项式基函数
  12. step_relu() 应用(平滑)校正线性变换
  13. step_sqrt() 平方根变换
  14. step_YeoJohnson() Yeo-Johnson 转换

step函数 - 离散化

  1. step_discretize() 离散数值变量
  2. discretize() predict() 离散数值变量
  3. step_cut() 将数值变量切割成因子

step函数 - 虚拟变量和编码

  1. step_bin2factor() 从虚拟变量创建因子
  2. step_count() 使用正则表达式创建模式计数
  3. step_date() 日期特征生成器
  4. step_dummy() 创建传统的虚拟变量
  5. step_dummy_extract() 从名义数据中提取模式
  6. step_dummy_multi_choice() 一起处理多个预测变量中的级别
  7. step_factor2string() 将因子转换为字符串
  8. step_holiday() 假日特征生成器
  9. step_indicate_na() 创建缺失数据列指示器
  10. step_integer() 将值转换为预定义的整数
  11. step_novel() 新因子水平的简单赋值
  12. step_num2factor() 将数字转换为因子
  13. step_ordinalscore() 将序数因子转换为数字分数
  14. step_other() 折叠一些分类级别
  15. step_percentile() 百分位变换
  16. step_regex() 检测正则表达式
  17. step_relevel() 将因子重新调整到所需水平
  18. step_string2factor() 将字符串转换为因子
  19. step_unknown() 将缺失的类别分配给“未知”
  20. step_unorder() 将有序因子转换为无序因子

Step Functions - 交互

  1. step_interact() 创建交互变量

step 函数 - 归一化

  1. step_center() 居中数值数据
  2. step_normalize() 居中并缩放数值数据
  3. step_range() 将数值数据缩放到特定范围
  4. step_scale() 缩放数值数据

step函数 - 多元变换

  1. step_classdist() 到类质心的距离
  2. step_depth() 数据深度
  3. step_geodist() 两个位置之间的距离
  4. step_ica() ICA信号提取
  5. step_isomap() Isomap 嵌入
  6. step_kpca() 内核 PCA 信号提取
  7. step_kpca_poly() 多项式核 PCA 信号提取
  8. step_kpca_rbf() 径向基函数内核 PCA 信号提取
  9. step_mutate_at() 使用 dplyr 改变多个列
  10. step_nnmf() 非负矩阵分解信号提取
  11. step_nnmf_sparse() 带套索惩罚的非负矩阵分解信号提取
  12. step_pca() PCA 信号提取
  13. step_pls() 偏最小二乘特征提取
  14. step_ratio() denom_vars() 比率变量创建
  15. step_spatialsign() 空间符号预处理

Step Functions - 过滤器

  1. step_corr() 高相关过滤器
  2. step_filter_missing() 缺失值列过滤器
  3. step_lincomb() 线性组合滤波器
  4. step_nzv() 近零方差滤波器
  5. step_rm() 通用变量过滤器
  6. step_select() 使用 dplyr 选择变量
  7. step_zv() 零方差滤波器
  8. Step Functions - 行操作
  9. step_arrange() 使用 dplyr 对行进行排序
  10. step_filter() 使用 dplyr 过滤行
  11. step_lag() 创建滞后预测器
  12. step_naomit() 删除具有缺失值的观测值
  13. step_impute_roll() step_rollimpute() 使用滚动窗口统计量估算数值数据
  14. step_sample() 使用 dplyr 的示例行
  15. step_shuffle() 随机变量
  16. step_slice() 使用 dplyr 按位置过滤行

Step Functions - 其他

  1. step_intercept() 添加截距(或常数)列
  2. step_profile() 创建数据集的分析版本
  3. step_rename() 使用 dplyr 按名称重命名变量
  4. step_rename_at() 使用 dplyr 重命名多个列
  5. step_window() 移动窗函数

检查函数

  1. check_class() 检查变量类
  2. check_cols() 检查是否所有列都存在
  3. check_missing() 检查缺失值
  4. check_new_values() 检查新值
  5. check_range() 检查范围一致性

内部处理函数

  1. add_step() add_check() 向当前配方添加新操作
  2. detect_step() 检测配方中是否使用了特定步骤或检查
  3. fully_trained() 检查食谱是否经过培训/准备
  4. names0() dummy_names() dummy_extract_names() 命名函数
  5. prepper() 用于在重采样中准备配方的包装函数
  6. recipes_eval_select() 使用特定于食谱的 tidyselect 语义评估选择 update()

workflow

工作流是一个对象,可以将预处理、建模和后处理请求捆绑在一起。

指定样条曲线

library(recipes)
library(parsnip)
library(workflows)

spline_cars <- recipe(mpg ~ ., data = mtcars) %>% 
  step_ns(disp, deg_free = 10)

指定模型对象

bayes_lm <- linear_reg() %>% 
  set_engine("stan")

拟合模型

spline_cars_prepped <- prep(spline_cars, mtcars)
bayes_lm_fit <- fit(bayes_lm, mpg ~ ., data = juice(spline_cars_prepped))

使用workflow 则可以让代码更加的简洁

car_wflow <- workflow() %>% 
  add_recipe(spline_cars) %>% 
  add_model(bayes_lm)

car_wflow_fit <- fit(car_wflow, data = mtcars)

常用函数

  1. add_formula() remove_formula() update_formula() 将公式添加到工作流
  2. add_model() remove_model() update_model() 将模型添加到工作流
  3. add_recipe() remove_recipe() update_recipe() 将配recipes添加到工作流程
  4. add_variables() remove_variables() update_variables() workflow_variables() 将变量添加到工作流
  5. augment() 使用预测增强数据
  6. control_workflow() 工作流的控制对象
  7. extract_* 函数 提取工作流的元素
  8. fit() 适合工作流对象
  9. glance() 浏览工作流模型
  10. is_trained_workflow() 确定工作流是否已经过训练
  11. predict() 从工作流预测
  12. tidy() tidy工作流程
  13. workflow() 创建工作流

tune

这个包用于调整模型的超参数

常用函数

  1. tune_bayes() 模型参数的贝叶斯优化。
  2. tune_grid() 通过网格搜索进行模型调整
  3. fit_resamples() 通过重采样拟合多个模型
  4. last_fit() 将最终的最佳模型拟合到训练集并评估测试集
  5. prob_improve() exp_improve() conf_bound() 评分参数组合的采集功能
  6. control_bayes() 贝叶斯搜索过程的控制方面
  7. control_grid() control_resamples() 控制网格搜索过程的各个方面
  8. autoplot() 绘制调优搜索结果
  9. augment() augment() augment() 扩展数据
  10. coord_obs_pred() 对观察值与预测值的图使用相同的比例
  11. expo_decay() 指数衰减函数
  12. collect_predictions() collect_metrics() collect_notes() 获取并格式化调整函数产生的结果
  13. filter_parameters() 移除一些调优参数结果
  14. show_best() select_best() select_by_pct_loss() select_by_one_std_err() 调查最佳调整参数
  15. extract_*函数 提取tune对象的元素
  16. extract_model() 提取模型的便捷函数
  17. finalize_model() finalize_recipe() finalize_workflow() 将最终参数拼接到对象中
  18. conf_mat_resampled() 计算重采样的平均混淆矩阵

yardstick

这个包是与模型评估相关的一个包。

二分类

假设我们有这样一份数据

library(yardstick)
library(dplyr)

head(two_class_example)
##    truth      Class1       Class2 predicted
## 1 Class2 0.003589243 0.9964107574    Class2
## 2 Class1 0.678621054 0.3213789460    Class1
## 3 Class2 0.110893522 0.8891064779    Class2
## 4 Class1 0.735161703 0.2648382969    Class1
## 5 Class2 0.016239960 0.9837600397    Class2
## 6 Class1 0.999275071 0.0007249286    Class1

计算指标

metrics(two_class_example, truth, predicted)
## # A tibble: 2 × 3
##   .metric  .estimator .estimate
##   <chr>    <chr>          <dbl>
## 1 accuracy binary         0.838
## 2 kap      binary         0.675
two_class_example %>% 
  roc_auc(truth, Class1)
## # A tibble: 1 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 roc_auc binary         0.939

多分类

data("hpc_cv")
hpc_cv <- as_tibble(hpc_cv)
hpc_cv
## # A tibble: 3,467 × 7
##    obs   pred     VF      F       M          L Resample
##    <fct> <fct> <dbl>  <dbl>   <dbl>      <dbl> <chr>   
##  1 VF    VF    0.914 0.0779 0.00848 0.0000199  Fold01  
##  2 VF    VF    0.938 0.0571 0.00482 0.0000101  Fold01  
##  3 VF    VF    0.947 0.0495 0.00316 0.00000500 Fold01  
##  4 VF    VF    0.929 0.0653 0.00579 0.0000156  Fold01  
##  5 VF    VF    0.942 0.0543 0.00381 0.00000729 Fold01  
##  6 VF    VF    0.951 0.0462 0.00272 0.00000384 Fold01  
##  7 VF    VF    0.914 0.0782 0.00767 0.0000354  Fold01  
##  8 VF    VF    0.918 0.0744 0.00726 0.0000157  Fold01  
##  9 VF    VF    0.843 0.128  0.0296  0.000192   Fold01  
## 10 VF    VF    0.920 0.0728 0.00703 0.0000147  Fold01  
## # … with 3,457 more rows
precision(hpc_cv, obs, pred)
## # A tibble: 1 × 3
##   .metric   .estimator .estimate
##   <chr>     <chr>          <dbl>
## 1 precision macro          0.631
precision(hpc_cv, obs, pred, estimator = "micro")
## # A tibble: 1 × 3
##   .metric   .estimator .estimate
##   <chr>     <chr>          <dbl>
## 1 precision micro          0.709

计算重采样指标

对模型进行了多次重采样,则可以使用分组数据框上的指标来一次计算所有重采样的指标。

hpc_cv %>%
  group_by(Resample) %>%
  roc_auc(obs, VF:L)
## # A tibble: 10 × 4
##    Resample .metric .estimator .estimate
##    <chr>    <chr>   <chr>          <dbl>
##  1 Fold01   roc_auc hand_till      0.813
##  2 Fold02   roc_auc hand_till      0.817
##  3 Fold03   roc_auc hand_till      0.869
##  4 Fold04   roc_auc hand_till      0.849
##  5 Fold05   roc_auc hand_till      0.811
##  6 Fold06   roc_auc hand_till      0.836
##  7 Fold07   roc_auc hand_till      0.825
##  8 Fold08   roc_auc hand_till      0.846
##  9 Fold09   roc_auc hand_till      0.828
## 10 Fold10   roc_auc hand_till      0.812

对结果进行可视化

基于曲线的方法,例如roc_curve(),pr_curve()和gain_curve()都可以使用ggplot2::autoplot()进行数据可视化。

library(ggplot2)

hpc_cv %>%
  group_by(Resample) %>%
  roc_curve(obs, VF:L) %>%
  autoplot()

分类指标

  1. sens() sens_vec() sensitivity() sensitivity_vec() 灵敏度
  2. spec() spec_vec() specificity() specificity_vec() 特异性
  3. recall() recall_vec() 召回
  4. precision() precision_vec() 精度
  5. mcc() mcc_vec() 马修斯相关系数
  6. j_index() j_index_vec() J指数
  7. f_meas() f_meas_vec() F 值
  8. accuracy() accuracy_vec() 准确性
  9. kap() kap_vec() 卡帕
  10. ppv() ppv_vec() 阳性预测值
  11. npv() npv_vec() 负预测值
  12. bal_accuracy() bal_accuracy_vec() 平衡精度
  13. detection_prevalence() detection_prevalence_vec() 检出率

分类概率度量

  1. roc_auc() roc_auc_vec() AUC
  2. roc_aunp() roc_aunp_vec() 每个类别相对于其他类别的 ROC 曲线下面积,使用先验类别分布
  3. roc_aunu() roc_aunu_vec() 每个类别相对于其他类别的 ROC 曲线下面积,使用均匀类别分布
  4. pr_auc() pr_auc_vec() 精确召回曲线下的面积
  5. average_precision() average_precision_vec() 精确召回曲线下的面积
  6. gain_capture() gain_capture_vec() gain
  7. mn_log_loss() mn_log_loss_vec() 多项数据的平均对数损失
  8. classification_cost() classification_cost_vec() 分类不良的成本函数

回归指标

  1. rmse() rmse_vec() 均方根误差
  2. rsq() rsq_vec() R平方
  3. rsq_trad() rsq_trad_vec() R 平方
  4. msd() msd_vec() 平均符号偏差
  5. mae() mae_vec() 平均绝对误差
  6. mpe() mpe_vec() 平均百分比误差
  7. mape() mape_vec() 平均绝对百分比误差
  8. smape() smape_vec() 对称平均绝对百分比误差
  9. mase() mase_vec() 平均绝对比例误差 10.ccc() ccc_vec() 一致性相关系数
  10. rpiq() rpiq_vec()
  11. 绩效与四分位数的比率
  12. rpd() rpd_vec() 绩效与偏差的比率
  13. huber_loss() huber_loss_vec() 胡贝尔损失
  14. huber_loss_pseudo() huber_loss_pseudo_vec() 伪 Huber 损失
  15. iic() iic_vec() 相关系数
  16. poisson_log_loss() poisson_log_loss_vec() 泊松数据的平均对数损失

曲线函数

1.roc_curve() 接受者操作曲线 2. pr_curve() 精确召回曲线 3. gain_curve() 增益曲线 4. lift_curve() 提升曲线

其他功能

  1. metrics() 估计性能的一般函数
  2. metric_set() 组合度量函数
  3. metric_tweak() 调整度量函数
  4. conf_mat() tidy() 分类数据的混淆矩阵
  5. summary() 混淆矩阵的汇总统计

broom

broom包能够将模型结果格式化。broom包有三个关键函数,分别是:

  1. tidy
  2. glance
  3. augment

我们来看一个简单的例子:

library(broom)

fit <- lm(Petal.Width ~ ., iris[,-5])

tidy(fit) # 将模型结果标准化输出
## # A tibble: 4 × 5
##   term         estimate std.error statistic  p.value
##   <chr>           <dbl>     <dbl>     <dbl>    <dbl>
## 1 (Intercept)    -0.240    0.178      -1.35 1.80e- 1
## 2 Sepal.Length   -0.207    0.0475     -4.36 2.41e- 5
## 3 Sepal.Width     0.223    0.0489      4.55 1.10e- 5
## 4 Petal.Length    0.524    0.0245     21.4  7.33e-47
glance(fit) # 报告有关整个模型的信息
## # A tibble: 1 × 12
##   r.squared adj.r.squared sigma statistic  p.value    df logLik   AIC   BIC
##       <dbl>         <dbl> <dbl>     <dbl>    <dbl> <dbl>  <dbl> <dbl> <dbl>
## 1     0.938         0.937 0.192      734. 7.83e-88     3   36.8 -63.5 -48.4
## # … with 3 more variables: deviance <dbl>, df.residual <int>, nobs <int>
augment(fit, data = iris) # 向数据集中添加列
## # A tibble: 150 × 11
##    Sepal.Length Sepal.Width Petal.Length Petal.Width Species .fitted    .resid
##           <dbl>       <dbl>        <dbl>       <dbl> <fct>     <dbl>     <dbl>
##  1          5.1         3.5          1.4         0.2 setosa    0.216 -0.0163  
##  2          4.9         3            1.4         0.2 setosa    0.146  0.0537  
##  3          4.7         3.2          1.3         0.2 setosa    0.180  0.0201  
##  4          4.6         3.1          1.5         0.2 setosa    0.283 -0.0832  
##  5          5           3.6          1.4         0.2 setosa    0.259 -0.0593  
##  6          5.4         3.9          1.7         0.4 setosa    0.400 -0.000428
##  7          4.6         3.4          1.4         0.3 setosa    0.298  0.00240 
##  8          5           3.4          1.5         0.2 setosa    0.267 -0.0671  
##  9          4.4         2.9          1.4         0.2 setosa    0.228 -0.0276  
## 10          4.9         3.1          1.5         0.1 setosa    0.221 -0.121   
## # … with 140 more rows, and 4 more variables: .hat <dbl>, .sigma <dbl>,
## #   .cooksd <dbl>, .std.resid <dbl>

dials

这个包用于管理和调整tidymodels包的参数调整。

参数集

  1. parameters() 对象内调整参数的信息
  2. update() 更新参数集中的单个参数
  3. pull_dials_object() 返回与参数关联的 dials 参数对象
  4. range_validate() range_get() range_set() 处理参数范围的工具
  5. value_validate() value_seq() value_sample() value_transform() value_inverse() value_set() 处理参数值的工具

创建网格

  1. grid_max_entropy() grid_latin_hypercube() 空间填充参数网格
  2. grid_regular() grid_random() 创建调整参数网格

参数对象

  1. activation() values_activation 网络层之间的激活函数
  2. adjust_deg_free() 调整有效自由度的参数
  3. all_neighbors() 用于确定要使用哪些邻居的参数
  4. prior_terminal_node_coef() prior_terminal_node_expo() prior_outcome_range() BART 模型的参数 这些参数用于构建贝叶斯自适应回归树 (BART) 模型。
  5. class_weights() 不平衡问题的类权重参数
  6. conditional_min_criterion() values_test_type conditional_test_type() values_test_statistic conditional_test_statistic() Parameters for possible engine parameters for party models
  7. confidence_factor() no_global_pruning() predictor_winnowing() fuzzy_thresholding() rule_bands() C5.0 可能参数
  8. cost() svm_margin() 支持向量机参数
  9. deg_free() 自由度(整数)
  10. degree() degree_int() spline_degree() prod_degree() 指数参数
  11. dist_power() Minkowski 距离参数
  12. dropout() epochs() hidden_units() batch_size() 神经网络参数
  13. extrapolation() unbiased_rules() max_rules() Cubist 可能的引擎参数的参数
  14. freq_cut() unique_cut() 近零方差参数
  15. Laplace() 拉普拉斯校正参数
  16. learn_rate() 学习率
  17. max_nodes() randomForest 的可能的参数
  18. max_num_terms() earth模型可能参数
  19. max_times() min_times() 删除词频
  20. max_tokens() 最大保留令牌数
  21. min_dist() 嵌入点之间有效最小距离的参数
  22. min_unique() 预处理的唯一值数量
  23. mixture() 混合处罚条款
  24. momentum() 梯度下降动量参数
  25. mtry() mtry_long() 随机抽样的预测变量数
  26. neighbors() 邻居数
  27. num_breaks() 分箱的切割点数
  28. num_comp() num_terms() 新特征的数量
  29. num_hash() signed_hash() 文本哈希参数
  30. num_knots() Number of knots
  31. num_tokens() 用于确定 ngram 中标记数量的参数
  32. over_ratio() under_ratio() 类不平衡抽样的参数
  33. penalty() 正则化/惩罚量
  34. predictor_prop() 预测变量的比例
  35. prior_slab_dispersion() prior_mixture_threshold() 贝叶斯 PCA 参数
  36. prune_method() values_prune_method MARS修剪方法
  37. rbf_sigma() scale_factor() kernel_offset() 内核参数 38.regularization_factor() regularize_depth() significance_threshold() lower_quantile() splitting_rule() ranger_class_rules ranger_reg_rules ranger_split_rules num_random_splits() ranger可能的参数
  38. regularization_method() values_regularization_method 正则化模型的估计方法 40.scale_pos_weight() penalty_L2() penalty_L1() xgboost 可能的参数
  39. select_features() 启用特征选择的参数
  40. shrinkage_correlation() shrinkage_variance() shrinkage_frequencies() diagonal_covariance() sda 模型参数
  41. smoothness() 内核平滑度
  42. stop_iter() 提前停止参数
  43. summary_stat() values_summary_stat 移动窗口的滚动汇总统计
  44. surv_dist() values_surv_dist 删失数据的参数分布
  45. survival_link() values_survival_link 生存模型链接函数
  46. threshold() 一般阈值参数
  47. token() values_token 令牌类型
  48. trees() min_n() sample_size() sample_prop() loss_reduction() tree_depth() prune() cost_complexity() 与基于树和基于规则的模型相关的参数函数。 50.vocabulary_size() 词汇表中的记号数
  49. weight() “double normalization”创建令牌计数时的参数
  50. weight_func() values_weight_func 距离加权的核函数
  51. weight_scheme() values_weight_scheme 词频加权方法
  52. window_size() 移动窗口大小的参数
  53. 最终确定参数 finalize() get_p() get_log_p() get_n_frac() get_n_frac_range() get_n() get_rbf_range() get_batch_sizes() 最终确定数据特定参数范围的函数

infer

该包适用于统计推断的R包。有四个主要函数。

  1. specify() 允许指定感兴趣的变量或变量之间的关系。
  2. hypothesize() 允许声明原假设。
  3. generate() 允许生成反映零假设的数据。
  4. calculate() 允许根据生成的数据计算统计分布以形成空分布。

简单的例子

我们来看一个简单的例子。

data(gss)

str(gss)
## tibble [500 × 11] (S3: tbl_df/tbl/data.frame)
##  $ year   : num [1:500] 2014 1994 1998 1996 1994 ...
##  $ age    : num [1:500] 36 34 24 42 31 32 48 36 30 33 ...
##  $ sex    : Factor w/ 2 levels "male","female": 1 2 1 1 1 2 2 2 2 2 ...
##  $ college: Factor w/ 2 levels "no degree","degree": 2 1 2 1 2 1 1 2 2 1 ...
##  $ partyid: Factor w/ 5 levels "dem","ind","rep",..: 2 3 2 2 3 3 1 2 3 1 ...
##  $ hompop : num [1:500] 3 4 1 4 2 4 2 1 5 2 ...
##  $ hours  : num [1:500] 50 31 40 40 40 53 32 20 40 40 ...
##  $ income : Ord.factor w/ 12 levels "lt $1000"<"$1000 to 2999"<..: 12 11 12 12 12 12 12 12 12 10 ...
##  $ class  : Factor w/ 6 levels "lower class",..: 3 2 2 2 3 3 2 3 3 2 ...
##  $ finrela: Factor w/ 6 levels "far below average",..: 2 2 2 4 4 3 2 4 3 1 ...
##  $ weight : num [1:500] 0.896 1.083 0.55 1.086 1.083 ...

我们对age和partyid进行方差分析

F_hat <- gss %>% 
  specify(age ~ partyid) %>%
  calculate(stat = "F")
## Dropping unused factor levels DK from the supplied explanatory variable 'partyid'.

生成零分布

null_dist <- gss %>%
   specify(age ~ partyid) %>%
   hypothesize(null = "independence") %>%
   generate(reps = 1000, type = "permute") %>%
   calculate(stat = "F")
## Dropping unused factor levels DK from the supplied explanatory variable 'partyid'.

将观察到的统计数据与零分布一起可视化

visualize(null_dist) +
  shade_p_value(obs_stat = F_hat, direction = "greater")

从零分布和观察到的统计量计算 p 值

null_dist %>%
  get_p_value(obs_stat = F_hat, direction = "greater")
## # A tibble: 1 × 1
##   p_value
##     <dbl>
## 1   0.047

infer实现了一种表达语法来执行与设计框架相一致的统计推断,该软件包没有提供特定统计检验的方法,而是将常见假设检验之间共享的原则整合到一组 4 个主要动词(函数)中,并辅以许多实用程序来可视化并从其输出中提取价值。

不管我们进行的是什么假设检验,我们的问题都是:我们观察到的数据差异是真实存在的还是偶然产生。首先,我们假设数据来自于某一个什么都没有发生的世界(即零假设成立的世界),然后计算给出如果原假设为真,我们观察到的数据可能出现的概率。如果此概率低于某个预定义的显着性水平 𝛼,那么我们可以拒绝我们的原假设。

第一步是使用specify函数 用于指定我们感兴趣的变量是哪些。

gss %>%
  specify(response = age)
## Response: age (numeric)
## # A tibble: 500 × 1
##      age
##    <dbl>
##  1    36
##  2    34
##  3    24
##  4    42
##  5    31
##  6    32
##  7    48
##  8    36
##  9    30
## 10    33
## # … with 490 more rows

看起来,这像是数据框中的某一类,但是查看他的类,是infer。

gss %>%
  specify(response = age) %>%
  class()
## [1] "infer"      "tbl_df"     "tbl"        "data.frame"

如果对两个变量感兴趣,通常有两种方法表示,

# as a formula
gss %>%
  specify(age ~ partyid)
## Dropping unused factor levels DK from the supplied explanatory variable 'partyid'.
## Response: age (numeric)
## Explanatory: partyid (factor)
## # A tibble: 500 × 2
##      age partyid
##    <dbl> <fct>  
##  1    36 ind    
##  2    34 rep    
##  3    24 ind    
##  4    42 ind    
##  5    31 rep    
##  6    32 rep    
##  7    48 dem    
##  8    36 ind    
##  9    30 rep    
## 10    33 dem    
## # … with 490 more rows
gss %>%
  specify(response = age, explanatory = partyid)
## Dropping unused factor levels DK from the supplied explanatory variable 'partyid'.
## Response: age (numeric)
## Explanatory: partyid (factor)
## # A tibble: 500 × 2
##      age partyid
##    <dbl> <fct>  
##  1    36 ind    
##  2    34 rep    
##  3    24 ind    
##  4    42 ind    
##  5    31 rep    
##  6    32 rep    
##  7    48 dem    
##  8    36 ind    
##  9    30 rep    
## 10    33 dem    
## # … with 490 more rows

要对一个比例或比例差异进行推断,则需要使用success参数来指定变量的哪个级别response是成功的。

# specifying for inference on proportions
gss %>%
  specify(response = college, success = "degree")
## Response: college (factor)
## # A tibble: 500 × 1
##    college  
##    <fct>    
##  1 degree   
##  2 no degree
##  3 degree   
##  4 no degree
##  5 degree   
##  6 no degree
##  7 no degree
##  8 degree   
##  9 degree   
## 10 no degree
## # … with 490 more rows

第二步是生成原假设

如果零假设假设两个变量之间是独立的,则设置null=“independence”

gss %>%
  specify(college ~ partyid, success = "degree") %>%
  hypothesize(null = "independence")
## Dropping unused factor levels DK from the supplied explanatory variable 'partyid'.
## Response: college (factor)
## Explanatory: partyid (factor)
## Null Hypothesis: independence
## # A tibble: 500 × 2
##    college   partyid
##    <fct>     <fct>  
##  1 degree    ind    
##  2 no degree rep    
##  3 degree    ind    
##  4 no degree ind    
##  5 degree    rep    
##  6 no degree rep    
##  7 no degree dem    
##  8 degree    ind    
##  9 degree    rep    
## 10 no degree dem    
## # … with 490 more rows

如果要对点估计进行推断,还需要提供p(成功的真实比例,介于 0 和 1 之间)、mu(真实均值)、med(真实中位数)或sigma(真实标准差)之一)。例如,如果零假设是我们的人群每周平均工作小时数是 40,我们会写:

gss %>%
  specify(response = hours) %>%
  hypothesize(null = "point", mu = 40)
## Response: hours (numeric)
## Null Hypothesis: point
## # A tibble: 500 × 1
##    hours
##    <dbl>
##  1    50
##  2    31
##  3    40
##  4    40
##  5    40
##  6    53
##  7    32
##  8    20
##  9    40
## 10    40
## # … with 490 more rows

第三部是生成空分布。一旦我们定义好了原假设,我们就可以基于这个假设构造一个零分布。我们可以使用type参数中提供的几种方法之一来做到这一点。

  1. bootstrap:将为每个重复抽取一个引导样本,其中从输入样本数据中抽取(带替换)大小等于输入样本大小的样本。
  2. permute:对于每个重复,每个输入值将被随机重新分配(不替换)到样本中的新输出值。
  3. simulate:将从理论分布中抽取一个值,并hypothesize()为每个重复指定参数。(此选项目前仅适用于测试点估计。)

继续我们上面的例子,关于每周工作的平均小时数,我们可以这样写:

set.seed(1) # 设置随机数,保证结果的可重复性

gss %>%
  specify(response = hours) %>%
  hypothesize(null = "point", mu = 40) %>%
  generate(reps = 1000, type = "bootstrap") # reps 表示抽样的次数
## Response: hours (numeric)
## Null Hypothesis: point
## # A tibble: 500,000 × 2
## # Groups:   replicate [1,000]
##    replicate hours
##        <int> <dbl>
##  1         1 46.6 
##  2         1 43.6 
##  3         1 38.6 
##  4         1 28.6 
##  5         1 38.6 
##  6         1 38.6 
##  7         1  6.62
##  8         1 78.6 
##  9         1 38.6 
## 10         1 38.6 
## # … with 499,990 more rows

为了为两个变量的独立性生成零分布,我们还可以随机重新排列解释变量和响应变量的配对以打破任何现有的关联。例如,在假设政党隶属关系不受年龄影响的情况下,生成 1000 个可用于创建零分布的重复:

gss %>%
  specify(partyid ~ age) %>%
  hypothesize(null = "independence") %>%
  generate(reps = 1000, type = "permute")
## Dropping unused factor levels DK from the supplied response variable 'partyid'.
## Response: partyid (factor)
## Explanatory: age (numeric)
## Null Hypothesis: independence
## # A tibble: 500,000 × 3
## # Groups:   replicate [1,000]
##    partyid   age replicate
##    <fct>   <dbl>     <int>
##  1 rep        36         1
##  2 rep        34         1
##  3 dem        24         1
##  4 dem        42         1
##  5 dem        31         1
##  6 ind        32         1
##  7 ind        48         1
##  8 rep        36         1
##  9 dem        30         1
## 10 rep        33         1
## # … with 499,990 more rows

第四步是计算汇总统计量

calculate()根据 infer 核心函数的输出计算汇总统计信息。该函数接受一个stat参数,目前是“mean”、“median”、“sum”、“sd”、“prop”、“count”、“diff in mean”、“diff in medians”、“diff in props”、“Chisq”、“F”、“t”、“z”、“slope”或“correlation”。例如,继续我们上面的例子来计算每周平均工作时间的零分布:

gss %>%
  specify(response = hours) %>%
  hypothesize(null = "point", mu = 40) %>%
  generate(reps = 1000, type = "bootstrap") %>%
  calculate(stat = "mean")
## Response: hours (numeric)
## Null Hypothesis: point
## # A tibble: 1,000 × 2
##    replicate  stat
##        <int> <dbl>
##  1         1  39.2
##  2         2  39.1
##  3         3  39.0
##  4         4  39.8
##  5         5  41.4
##  6         6  39.4
##  7         7  39.8
##  8         8  40.4
##  9         9  41.5
## 10        10  40.9
## # … with 990 more rows

这里的输出calculate()向我们展示了 1000 次重复中的每一个的样本统计量

如果要对均值、中位数或比例或 t 和 z 统计量的差异进行推断,则需要提供一个order参数,给出应减去解释变量的顺序。例如,要找出拥有大学学位和没有大学学位的人的平均年龄差异,我们可以这样写:

gss %>%
  specify(age ~ college) %>%
  hypothesize(null = "independence") %>%
  generate(reps = 1000, type = "permute") %>%
  calculate("diff in means", order = c("degree", "no degree"))
## Response: age (numeric)
## Explanatory: college (factor)
## Null Hypothesis: independence
## # A tibble: 1,000 × 2
##    replicate   stat
##        <int>  <dbl>
##  1         1 -2.35 
##  2         2 -0.902
##  3         3  0.403
##  4         4 -0.426
##  5         5  0.482
##  6         6 -0.196
##  7         7  1.33 
##  8         8 -1.07 
##  9         9  1.68 
## 10        10  0.888
## # … with 990 more rows

最后就是可视化visualize(),计算p值get_p_value(),计算置信区间get_confidence_interval()。

为了说明,我们将回到确定每周平均工作小时数是否为 40 小时的示例。

# find the point estimate
obs_mean <- gss %>%
  specify(response = hours) %>%
  calculate(stat = "mean")

# generate a null distribution
null_dist <- gss %>%
  specify(response = hours) %>%
  hypothesize(null = "point", mu = 40) %>%
  generate(reps = 1000, type = "bootstrap") %>%
  calculate(stat = "mean")

null_dist %>% summarise(mean(stat))
## # A tibble: 1 × 1
##   `mean(stat)`
##          <dbl>
## 1         40.0

结果的值和我们的的估计值非常接近,我们可能想知道这种差异是否只是由于随机产生,还是真实存在的。

我们可以对零分布进行数据可视化

null_dist %>%
  visualize() +theme_classic()

我们的样本观察到的统计量在这个分布上的什么位置?我们可以使用obs_stat参数来指定这一点。

null_dist %>%
  visualize() +
  shade_p_value(obs_stat = obs_mean, direction = "two-sided") +theme_classic()

计算p value

p_value <- null_dist %>%
  get_p_value(obs_stat = obs_mean, direction = "two-sided")

p_value
## # A tibble: 1 × 1
##   p_value
##     <dbl>
## 1   0.032

计算置信区间

# generate a distribution like the null distribution, 
# though exclude the null hypothesis from the pipeline
boot_dist <- gss %>%
  specify(response = hours) %>%
  generate(reps = 1000, type = "bootstrap") %>%
  calculate(stat = "mean")

# start with the bootstrap distribution
ci <- boot_dist %>%
  # calculate the confidence interval around the point estimate
  get_confidence_interval(point_estimate = obs_mean,
                          # at the 95% confidence level
                          level = .95,
                          # using the standard error
                          type = "se")

ci
## # A tibble: 1 × 2
##   lower_ci upper_ci
##      <dbl>    <dbl>
## 1     40.1     42.7

对结果进行可视化

boot_dist %>%
  visualize() +
  shade_confidence_interval(endpoints = ci)

计算理论分布

计算观察到的𝑡统计量

# calculate an observed t statistic
obs_t <- gss %>%
  specify(response = hours) %>%
  hypothesize(null = "point", mu = 40) %>%
  calculate(stat = "t")

然后,定义一个理论𝑡分布,我们可以写

# switch out calculate with assume to define a distribution
t_dist <- gss %>%
  specify(response = hours) %>%
  assume(distribution = "t")

可视化

# visualize the theoretical null distribution
visualize(t_dist) +
  shade_p_value(obs_stat = obs_t, direction = "greater")

计算p 值

get_p_value(t_dist, obs_t, "greater")
## # A tibble: 1 × 1
##   p_value
##     <dbl>
## 1  0.0188

置信区间位于数据的尺度上,而不是理论分布的标准化尺度上,因此在使用置信区间时,请务必使用非标准化的观察统计量。

# find the theory-based confidence interval
theor_ci <- 
  get_confidence_interval(
    x = t_dist,
    level = .95,
    point_estimate = obs_mean
  )

theor_ci
## # A tibble: 1 × 2
##   lower_ci upper_ci
##      <dbl>    <dbl>
## 1     40.1     42.7

对数据进行数据可视化

# visualize the theoretical sampling distribution
visualize(t_dist) +
  shade_confidence_interval(theor_ci)

多重回归

为了适应具有多个解释变量的基于随机化的推理,该软件包实现了基于模型拟合的替代工作流程。而不是calculate()从重新采样的数据中获取统计信息。在大多数情况下,您只需切换到calculate()基于fit()您calculate()的工作流程。

假设我们想要分析age,college 对于hours 的差异

null_fits <- gss %>%
  specify(hours ~ age + college) %>%
  hypothesize(null = "independence") %>%
  generate(reps = 1000, type = "permute") %>%
  fit()

null_fits
## # A tibble: 3,000 × 3
## # Groups:   replicate [1,000]
##    replicate term          estimate
##        <int> <chr>            <dbl>
##  1         1 intercept     40.3    
##  2         1 age            0.0166 
##  3         1 collegedegree  1.20   
##  4         2 intercept     41.3    
##  5         2 age            0.00664
##  6         2 collegedegree -0.407  
##  7         3 intercept     42.9    
##  8         3 age           -0.0371 
##  9         3 collegedegree  0.00431
## 10         4 intercept     42.7    
## # … with 2,990 more rows

t检验

单样本的t检验

observed_statistic <- gss %>%
  specify(response = hours) %>%
  calculate(stat = "mean") # 计算统计量 

# generate the null distribution
null_dist_1_sample <- gss %>%
  specify(response = hours) %>%
  hypothesize(null = "point", mu = 40) %>%
  generate(reps = 1000, type = "bootstrap") %>%
  calculate(stat = "mean")

为了了解这些分布是什么样的,以及我们观察到的统计数据落在哪里,我们可以使用visualize():

# visualize the null distribution and test statistic!
null_dist_1_sample %>%
  visualize() + 
  shade_p_value(observed_statistic,
                direction = "two-sided")

计算p 值

# calculate the p value from the test statistic and null distribution
p_value_1_sample <- null_dist_1_sample %>%
  get_p_value(obs_stat = observed_statistic,
              direction = "two-sided")

p_value_1_sample
## # A tibble: 1 × 1
##   p_value
##     <dbl>
## 1   0.038

双样本的t检验

observed_statistic <- gss %>%
  specify(hours ~ college) %>%
  calculate(stat = "diff in means", order = c("degree", "no degree"))

observed_statistic
## Response: hours (numeric)
## Explanatory: college (factor)
## # A tibble: 1 × 1
##    stat
##   <dbl>
## 1  1.54
# generate the null distribution with randomization
null_dist_2_sample <- gss %>%
  specify(hours ~ college) %>%
  hypothesize(null = "independence") %>%
  generate(reps = 1000, type = "permute") %>%
  calculate(stat = "diff in means", order = c("degree", "no degree"))

对数据进行可视化

# visualize the randomization-based null distribution and test statistic!
null_dist_2_sample %>%
  visualize() + 
  shade_p_value(observed_statistic,
                direction = "two-sided")

计算p值

# calculate the p value from the randomization-based null 
# distribution and the observed statistic
p_value_2_sample <- null_dist_2_sample %>%
  get_p_value(obs_stat = observed_statistic,
              direction = "two-sided")

p_value_2_sample
## # A tibble: 1 × 1
##   p_value
##     <dbl>
## 1    0.28

方差分析

# calculate the observed statistic
observed_f_statistic <- gss %>%
  specify(age ~ partyid) %>%
  hypothesize(null = "independence") %>%
  calculate(stat = "F")
## Dropping unused factor levels DK from the supplied explanatory variable 'partyid'.
null_dist <- gss %>%
  specify(age ~ partyid) %>%
  hypothesize(null = "independence") %>%
  generate(reps = 1000, type = "permute") %>%
  calculate(stat = "F")
## Dropping unused factor levels DK from the supplied explanatory variable 'partyid'.

对结果进行可视化

# visualize the null distribution and test statistic!
null_dist %>%
  visualize() + 
  shade_p_value(observed_f_statistic,
                direction = "greater")

计算p 值

# calculate the p value from the observed statistic and null distribution
p_value <- null_dist %>%
  get_p_value(obs_stat = observed_f_statistic,
              direction = "greater")

p_value
## # A tibble: 1 × 1
##   p_value
##     <dbl>
## 1   0.052

卡方检验

# calculate the observed statistic
observed_indep_statistic <- gss %>%
  specify(college ~ finrela) %>%
  hypothesize(null = "independence") %>%
  calculate(stat = "Chisq")


# generate the null distribution using randomization
null_dist_sim <- gss %>%
  specify(college ~ finrela) %>%
  hypothesize(null = "independence") %>%
  generate(reps = 1000, type = "permute") %>%
  calculate(stat = "Chisq")

对结果进行数据可视化

# visualize the null distribution and test statistic!
null_dist_sim %>%
  visualize() + 
  shade_p_value(observed_indep_statistic,
                direction = "greater")

计算p值

# calculate the p value from the observed statistic and null distribution
p_value_independence <- null_dist_sim %>%
  get_p_value(obs_stat = observed_indep_statistic,
              direction = "greater")
## Warning: Please be cautious in reporting a p-value of 0. This result is an
## approximation based on the number of `reps` chosen in the `generate()` step. See
## `?get_p_value()` for more information.
p_value_independence
## # A tibble: 1 × 1
##   p_value
##     <dbl>
## 1       0

拟合优度检验

# calculating the null distribution
observed_gof_statistic <- gss %>%
  specify(response = finrela) %>%
  hypothesize(null = "point",
              p = c("far below average" = 1/6,
                    "below average" = 1/6,
                    "average" = 1/6,
                    "above average" = 1/6,
                    "far above average" = 1/6,
                    "DK" = 1/6)) %>%
  calculate(stat = "Chisq")


# generating a null distribution, assuming each income class is equally likely
null_dist_gof <- gss %>%
  specify(response = finrela) %>%
  hypothesize(null = "point",
              p = c("far below average" = 1/6,
                    "below average" = 1/6,
                    "average" = 1/6,
                    "above average" = 1/6,
                    "far above average" = 1/6,
                    "DK" = 1/6)) %>%
  generate(reps = 1000, type = "draw") %>%
  calculate(stat = "Chisq")

对结果进行数据可视化

# visualize the null distribution and test statistic!
null_dist_gof %>%
  visualize() + 
  shade_p_value(observed_gof_statistic,
                direction = "greater")

计算p 值

# calculate the p-value
p_value_gof <- null_dist_gof %>%
  get_p_value(observed_gof_statistic,
              direction = "greater")
## Warning: Please be cautious in reporting a p-value of 0. This result is an
## approximation based on the number of `reps` chosen in the `generate()` step. See
## `?get_p_value()` for more information.
p_value_gof
## # A tibble: 1 × 1
##   p_value
##     <dbl>
## 1       0

更加详细的资料,参考: https://infer.netlify.app/articles/observed_stat_examples.html

核心函数

  1. specify() 指定响应变量和解释变量
  2. hypothesize() hypothesise() 声明零假设
  3. generate() 生成重采样、排列或模拟
  4. calculate() 计算汇总统计
  5. fit() 拟合线性模型以推断对象
  6. assume() 定义理论分布

帮助函数

  1. visualize() visualise() 可视化统计推断
  2. get_p_value() get_pvalue() 计算 p 值
  3. get_confidence_interval() get_ci() 计算置信区间
  4. shade_p_value() shade_pvalue() 超出观察统计量的阴影直方图区域
  5. shade_confidence_interval() shade_ci() 添加有关置信区间的信息

其他函数

  1. observe() 计算观察到的统计数据
  2. chisq_test() 整齐的卡方检验
  3. prop_test() 整齐的比例测试
  4. t_test() 整齐的 t 检验
  5. chisq_stat() 整齐的卡方检验统计量
  6. t_stat() 整齐的 t 检验统计量

corrr

处理相关性的R包。

library(corrr)
## 
## Attaching package: 'corrr'
## The following object is masked from 'package:skimr':
## 
##     focus
x <- correlate(iris[-5])
## 
## Correlation method: 'pearson'
## Missing treated using: 'pairwise.complete.obs'
x
## # A tibble: 4 × 5
##   term         Sepal.Length Sepal.Width Petal.Length Petal.Width
##   <chr>               <dbl>       <dbl>        <dbl>       <dbl>
## 1 Sepal.Length       NA          -0.118        0.872       0.818
## 2 Sepal.Width        -0.118      NA           -0.428      -0.366
## 3 Petal.Length        0.872      -0.428       NA           0.963
## 4 Petal.Width         0.818      -0.366        0.963      NA

常用函数

  1. as_cordf() 数据转换
  2. as_matrix() 将相关数据框转换为矩阵格式
  3. colpair_map() 将函数应用于数据框中的所有列对
  4. correlate() 相关数据框
  5. dice() 返回仅包含选定字段的相关表
  6. fashion() 制作用于打印的相关数据框。
  7. first_col() 将第一列添加到 data.frame
  8. focus() focus_() 关注相关数据框的部分。
  9. focus_if() 有条件地聚焦相关数据框
  10. network_plot() 相关数据框的网络图
  11. pair_n() 成对完整案例的数量。
  12. rearrange() 重新排列相关数据框
  13. retract() 从拉伸的相关表创建数据框
  14. rplot() 绘制相关数据框。
  15. shave() 剃掉上/下三角形。
  16. stretch() 将相关数据帧拉伸为长格式。

spatialsample

这个包提供了空间重采样的一些方法。 我们看一个简单的例子:

library(spatialsample)
data("ames", package = "modeldata")

ames %>% select(Latitude,Longitude) 
## # A tibble: 2,930 × 2
##    Latitude Longitude
##       <dbl>     <dbl>
##  1     42.1     -93.6
##  2     42.1     -93.6
##  3     42.1     -93.6
##  4     42.1     -93.6
##  5     42.1     -93.6
##  6     42.1     -93.6
##  7     42.1     -93.6
##  8     42.1     -93.6
##  9     42.1     -93.6
## 10     42.1     -93.6
## # … with 2,920 more rows
set.seed(1234)
folds <- spatial_clustering_cv(ames, coords = c("Latitude", "Longitude"), v = 5)

folds
## #  5-fold spatial cross-validation 
## # A tibble: 5 × 2
##   splits             id   
##   <list>             <chr>
## 1 <split [2332/598]> Fold1
## 2 <split [2187/743]> Fold2
## 3 <split [2570/360]> Fold3
## 4 <split [2118/812]> Fold4
## 5 <split [2513/417]> Fold5

这个包中的关键函数是spatial_clustering_cv()。

tidypredict

这个包的主要作用是将R模型转换成为sql语句。

library(tidypredict)
model <- lm(mpg ~ wt + cyl, data = mtcars)

tidypredict_sql(model, dbplyr::simulate_mssql())
## <SQL> 39.6862614802529 + (`wt` * -3.19097213898374) + (`cyl` * -1.5077949682598)

tidypredict 工作流程

  1. 使用基本 R 模型或支持的模型中列出的包中的一个来拟合模型

  2. tidypredict读取模型,并创建一个列表对象,其中包含运行预测所需的组件

  3. tidypredict基于列表对象构建 R 公式

  4. dplyr评估由创建的公式tidypredict

  5. dplyr将公式转换为 SQL 语句或任何其他接口。

  6. 数据库执行由创建的 SQL 语句dplyr

  7. tidypredict包涉及到的函数不多。

  8. tidypredict_fit() 返回计算预测的 R 公式

  9. tidypredict_sql() 返回基于公式的 SQL 查询tidypredict_fit()

  10. tidypredict_to_column() 使用公式添加一个新列tidypredict_fit()

  11. tidypredict_test() tidyverse针对模型的本机predict()函数测试预测

  12. tidypredict_interval() 与间隔相同tidypredict_fit()(仅适用于lmand glm)

  13. tidypredict_sql_interval() 与间隔相同tidypredict_sql()(仅适用于lmand glm)

  14. parse_model() 基于 R 模型创建列表规范

  15. as_parsed_model() 准备一个要被识别为已解析模型的对象

支持的模型包括

  1. 线性回归 -lm()
  2. 广义线性模型 -glm()
  3. 随机森林模型 -randomForest::randomForest()
  4. 随机森林模型,通过ranger-ranger::ranger()
  5. MARS 模型 -earth::earth()
  6. XGBoost 模型 -xgboost::xgb.Booster.complete()
  7. cub模型 -Cubist::cubist()
  8. 树模型,通过partykit-partykit::ctree()

tidypredict支持通过parsnip接口安装的模型。目前已确认工作的tidypredict有:

  1. lm()- parsnip:linear_reg()以“lm”作为引擎。
  2. randomForest::randomForest()- parsnip:rand_forest()以“randomForest”作为引擎。
  3. ranger::ranger()- parsnip:rand_forest()以“ranger”为引擎。
  4. earth::earth()- parsnip:mars()以“earth”为引擎。

保存和重新加载模型

假设我们的模型如下所示。

model <- lm(mpg ~ (wt + disp) * cyl, data = mtcars)

使用parse_model解析模型。parse_model()函数将返回一个 R 列表对象,其中包含生成预测计算所需的所有信息

library(tidypredict)

parsed <- parse_model(model)
str(parsed)
## List of 2
##  $ general:List of 6
##   ..$ model   : chr "lm"
##   ..$ version : num 2
##   ..$ type    : chr "regression"
##   ..$ residual: int 26
##   ..$ sigma2  : num 5.91
##   ..$ is_glm  : num 0
##  $ terms  :List of 6
##   ..$ :List of 5
##   .. ..$ label       : chr "(Intercept)"
##   .. ..$ coef        : num 53.5
##   .. ..$ is_intercept: num 1
##   .. ..$ fields      :List of 1
##   .. .. ..$ :List of 2
##   .. .. .. ..$ type: chr "ordinary"
##   .. .. .. ..$ col : chr "(Intercept)"
##   .. ..$ qr          :List of 6
##   .. .. ..$ qr_1: num -0.177
##   .. .. ..$ qr_2: num -0.591
##   .. .. ..$ qr_3: num 0.413
##   .. .. ..$ qr_4: num -0.806
##   .. .. ..$ qr_5: num 2.35
##   .. .. ..$ qr_6: num -0.581
##   ..$ :List of 5
##   .. ..$ label       : chr "wt"
##   .. ..$ coef        : num -6.38
##   .. ..$ is_intercept: num 0
##   .. ..$ fields      :List of 1
##   .. .. ..$ :List of 2
##   .. .. .. ..$ type: chr "ordinary"
##   .. .. .. ..$ col : chr "wt"
##   .. ..$ qr          :List of 6
##   .. .. ..$ qr_1: num 0
##   .. .. ..$ qr_2: num 0.184
##   .. .. ..$ qr_3: num -0.354
##   .. .. ..$ qr_4: num 0.0373
##   .. .. ..$ qr_5: num -0.903
##   .. .. ..$ qr_6: num 1.47
##   ..$ :List of 5
##   .. ..$ label       : chr "disp"
##   .. ..$ coef        : num -0.0458
##   .. ..$ is_intercept: num 0
##   .. ..$ fields      :List of 1
##   .. .. ..$ :List of 2
##   .. .. .. ..$ type: chr "ordinary"
##   .. .. .. ..$ col : chr "disp"
##   .. ..$ qr          :List of 6
##   .. .. ..$ qr_1: num 0
##   .. .. ..$ qr_2: num 0
##   .. .. ..$ qr_3: num 0.00315
##   .. .. ..$ qr_4: num -0.0033
##   .. .. ..$ qr_5: num -0.00208
##   .. .. ..$ qr_6: num -0.0257
##   ..$ :List of 5
##   .. ..$ label       : chr "cyl"
##   .. ..$ coef        : num -3.63
##   .. ..$ is_intercept: num 0
##   .. ..$ fields      :List of 1
##   .. .. ..$ :List of 2
##   .. .. .. ..$ type: chr "ordinary"
##   .. .. .. ..$ col : chr "cyl"
##   .. ..$ qr          :List of 6
##   .. .. ..$ qr_1: num 0
##   .. .. ..$ qr_2: num 0
##   .. .. ..$ qr_3: num 0
##   .. .. ..$ qr_4: num 0.234
##   .. .. ..$ qr_5: num -0.354
##   .. .. ..$ qr_6: num 0.103
##   ..$ :List of 5
##   .. ..$ label       : chr "wt:cyl"
##   .. ..$ coef        : num 0.536
##   .. ..$ is_intercept: num 0
##   .. ..$ fields      :List of 2
##   .. .. ..$ :List of 2
##   .. .. .. ..$ type: chr "ordinary"
##   .. .. .. ..$ col : chr "wt"
##   .. .. ..$ :List of 2
##   .. .. .. ..$ type: chr "ordinary"
##   .. .. .. ..$ col : chr "cyl"
##   .. ..$ qr          :List of 6
##   .. .. ..$ qr_1: num 0
##   .. .. ..$ qr_2: num 0
##   .. .. ..$ qr_3: num 0
##   .. .. ..$ qr_4: num 0
##   .. .. ..$ qr_5: num 0.152
##   .. .. ..$ qr_6: num -0.202
##   ..$ :List of 5
##   .. ..$ label       : chr "disp:cyl"
##   .. ..$ coef        : num 0.00541
##   .. ..$ is_intercept: num 0
##   .. ..$ fields      :List of 2
##   .. .. ..$ :List of 2
##   .. .. .. ..$ type: chr "ordinary"
##   .. .. .. ..$ col : chr "disp"
##   .. .. ..$ :List of 2
##   .. .. .. ..$ type: chr "ordinary"
##   .. .. .. ..$ col : chr "cyl"
##   .. ..$ qr          :List of 6
##   .. .. ..$ qr_1: num 0
##   .. .. ..$ qr_2: num 0
##   .. .. ..$ qr_3: num 0
##   .. .. ..$ qr_4: num 0
##   .. .. ..$ qr_5: num 0
##   .. .. ..$ qr_6: num 0.00334
##  - attr(*, "class")= chr [1:3] "parsed_model" "pm_regression" "list"
class(parsed)
## [1] "parsed_model"  "pm_regression" "list"

通常,我们将 R 模型对象传递给函数,例如:tidypredict_fit()和tidypredict_sql()。这些函数还接受先前解析的模型。

tidypredict_fit(model)
## 53.5256637443325 + (wt * -6.38154597431604) + (disp * -0.0458426921825966) + 
##     (cyl * -3.6302556793944) + (wt * cyl * 0.535604359938273) + 
##     (disp * cyl * 0.00540618405824797)
tidypredict_fit(parsed)
## 53.5256637443325 + (wt * -6.38154597431604) + (disp * -0.0458426921825966) + 
##     (cyl * -3.6302556793944) + (wt * cyl * 0.535604359938273) + 
##     (disp * cyl * 0.00540618405824797)

保存模型

保存模型非常简单,使用包比如yaml将模型对象写成 YAML 文件。

library(yaml)

write_yaml(parsed, "my_model.yml")

重新加载模型

在新的 R 会话中,我们可以将 YAML 文件读入我们的环境。

library(tidypredict)
library(yaml)

loaded_model <- read_yaml("my_model.yml")

loaded_model <- as_parsed_model(loaded_model)

loaded_model %>% tidypredict_fit()
## 53.5256637 + (wt * -6.381546) + (disp * -0.0458427) + (cyl * 
##     -3.6302557) + (wt * cyl * 0.5356044) + (disp * cyl * 0.0054062)
tidypredict_sql(loaded_model, dbplyr::simulate_odbc())
## <SQL> 53.5256637 + (`wt` * -6.381546) + (`disp` * -0.0458427) + (`cyl` * -3.6302557) + (`wt` * `cyl` * 0.5356044) + (`disp` * `cyl` * 0.0054062)

使用broom将结果标准化。

tidy(loaded_model)
## # A tibble: 6 × 2
##   term        estimate
##   <chr>          <dbl>
## 1 (Intercept) 53.5    
## 2 wt          -6.38   
## 3 disp        -0.0458 
## 4 cyl         -3.63   
## 5 wt:cyl       0.536  
## 6 disp:cyl     0.00541

数据库回写

首先,数据在内存中准备好。本文将使用nycflights13::flights数据,并进行一些修改

library(dplyr)
library(tidypredict)
library(randomForest)
## Warning: package 'randomForest' was built under R version 4.1.2
## randomForest 4.7-1
## Type rfNews() to see new features/changes/bug fixes.
## 
## Attaching package: 'randomForest'
## The following object is masked from 'package:ranger':
## 
##     importance
## The following object is masked from 'package:ggplot2':
## 
##     margin
## The following object is masked from 'package:dplyr':
## 
##     combine
library(dbplyr)
## 
## Attaching package: 'dbplyr'
## The following objects are masked from 'package:dplyr':
## 
##     ident, sql
flights_table <- nycflights13::flights %>%
  mutate(
    current_score = 0, 
    flight_id = row_number()
    ) 

使用 . 创建一个新数据库RSQLite。

library(DBI)

con <- dbConnect(RSQLite::SQLite(), path = ":memory:")

con
## <SQLiteConnection>
##   Path: 
##   Extensions: TRUE
db_fligths <- copy_to(con,flights_table) # 将flights_table数据集倒入数据库

从数据库中下载一个样本用于建模。这个例子已经选择了需要的变量。

df <- db_fligths %>%
  select(dep_delay, hour, distance) %>%
  head(1000) %>%
  collect() 

使用拟合线性模型lm()

model <- lm(dep_delay ~ ., data = df)

我们将模型结果更新到数据库,首先生成sql

library(dbplyr)

update_statement <- build_sql("UPDATE flights_table SET current_score  = ", tidypredict_sql(model, con = con), con = con)

update_statement
## <SQL> UPDATE flights_table SET current_score  = -3.5984422918702 + (`hour` * 1.38710560882252) + (`distance` * -0.00307606912118567)

然后运行sql

dbSendQuery(con, update_statement)
## <SQLiteResult>
##   SQL  UPDATE flights_table SET current_score  = -3.5984422918702 + (`hour` * 1.38710560882252) + (`distance` * -0.00307606912118567)
##   ROWS Fetched: 0 [complete]
##        Changed: 336776

我们查看更新的数据:

db_fligths %>% select(current_score)
## Warning: Closing open result set, pending rows
## # Source:   lazy query [?? x 1]
## # Database: sqlite 3.37.0 []
##    current_score
##            <dbl>
##  1       -0.969 
##  2       -1.02  
##  3       -0.0128
##  4       -1.51  
##  5        2.38  
##  6        1.13  
##  7        1.45  
##  8        4.02  
##  9        1.82  
## 10        2.47  
## # … with more rows

更多操作数据库的方法参见:https://db.rstudio.com/r-packages/dbi/

这个包简单来说就是首先创建好模型,然后使用tidypredict_sql(),将模型转换成为sql

modeldb

这个包允许在数据库中拟合模型。

线性回归

首先创建数据库的链接

library(modeldb)
con <- DBI::dbConnect(RSQLite::SQLite(), path = ":memory:")
RSQLite::initExtension(con)
dplyr::copy_to(con, mtcars) # 

训练模型

library(dplyr)

tbl(con, "mtcars") %>%
  select(wt, mpg, qsec) %>%
  linear_regression_db(wt)
## # A tibble: 1 × 3
##   `(Intercept)`    mpg  qsec
##           <dbl>  <dbl> <dbl>
## 1          4.12 -0.156 0.125

目前只支持线性回归和k均值聚类

stacks

这个包用于模型堆叠,就是将结果组合起来。步骤是:

  1. 使用来自rsample、parsnip、工作流、recipes和tune的功能定义候选集成成员
  2. 使用stacks()函数初始化一个data_stack对象
  3. 使用add_candidates()函数迭代地将候选集成成员添加到data_stackwith
  4. 使用blend_predictions()评估合并结果
  5. 使用fit_members()拟合具有非零堆叠系数的模型
  6. 预测新数据predict()

回归模型

加载包

library(tidymodels)
library(stacks)
## Warning: package 'stacks' was built under R version 4.1.2
library(dplyr)
library(purrr)

对数据进行处理

data("tree_frogs")

# subset the data
tree_frogs <- tree_frogs %>%
  filter(!is.na(latency)) %>%
  select(-c(clutch, hatched))

分析age和lateny的关系

library(ggplot2)

ggplot(tree_frogs) +
  aes(x = age, y = latency, color = treatment) +
  geom_point() +
  labs(x = "Embryo Age (s)", y = "Time to Hatch (s)", col = "Treatment")

我们将首先拆分训练数据、生成重采样并设置每个模型定义将使用的一些选项。

# some setup: resampling and a basic recipe
set.seed(1)
tree_frogs_split <- initial_split(tree_frogs)
tree_frogs_train <- training(tree_frogs_split)
tree_frogs_test  <- testing(tree_frogs_split)

set.seed(1)
folds <- rsample::vfold_cv(tree_frogs_train, v = 5)

tree_frogs_rec <- 
  recipe(latency ~ ., data = tree_frogs_train)

metric <- metric_set(rmse)

控制参数

ctrl_grid <- control_stack_grid()
ctrl_res <- control_stack_resamples()

构建三个模型,K-最近邻模型(需要调整超参数)、线性模型和支持向量机模型(同样需要调整超参数)

首先构建k均值模型。

knn_spec <-
  nearest_neighbor(
    mode = "regression", 
    neighbors = tune("k")
  ) %>%
  set_engine("kknn")

knn_spec
## K-Nearest Neighbor Model Specification (regression)
## 
## Main Arguments:
##   neighbors = tune("k")
## 
## Computational engine: kknn

添加recipe

# extend the recipe
knn_rec <-
  tree_frogs_rec %>%
  step_dummy(all_nominal()) %>%
  step_zv(all_predictors(), skip = TRUE) %>%
  step_meanimpute(all_numeric(), skip = TRUE) %>%
  step_normalize(all_numeric(), skip = TRUE)
## Warning: `step_meanimpute()` was deprecated in recipes 0.1.16.
## Please use `step_impute_mean()` instead.
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was generated.
knn_rec
## Recipe
## 
## Inputs:
## 
##       role #variables
##    outcome          1
##  predictor          4
## 
## Operations:
## 
## Dummy variables from all_nominal()
## Zero variance filter on all_predictors()
## Mean imputation for all_numeric()
## Centering and scaling for all_numeric()

生成一个workflow 对象。

# add both to a workflow
knn_wflow <- 
  workflow() %>% 
  add_model(knn_spec) %>%
  add_recipe(knn_rec)

knn_wflow
## ══ Workflow ════════════════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: nearest_neighbor()
## 
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 4 Recipe Steps
## 
## • step_dummy()
## • step_zv()
## • step_impute_mean()
## • step_normalize()
## 
## ── Model ───────────────────────────────────────────────────────────────────────
## K-Nearest Neighbor Model Specification (regression)
## 
## Main Arguments:
##   neighbors = tune("k")
## 
## Computational engine: kknn

设置网格搜索。

# tune k and fit to the 5-fold cv
set.seed(2020)
knn_res <- 
  tune_grid(
    knn_wflow,
    resamples = folds,
    metrics = metric,
    grid = 4,
    control = ctrl_grid
  )

knn_res
## # Tuning results
## # 5-fold cross-validation 
## # A tibble: 5 × 5
##   splits           id    .metrics         .notes           .predictions      
##   <list>           <chr> <list>           <list>           <list>            
## 1 <split [343/86]> Fold1 <tibble [4 × 5]> <tibble [0 × 3]> <tibble [344 × 5]>
## 2 <split [343/86]> Fold2 <tibble [4 × 5]> <tibble [0 × 3]> <tibble [344 × 5]>
## 3 <split [343/86]> Fold3 <tibble [4 × 5]> <tibble [0 × 3]> <tibble [344 × 5]>
## 4 <split [343/86]> Fold4 <tibble [4 × 5]> <tibble [0 × 3]> <tibble [344 × 5]>
## 5 <split [344/85]> Fold5 <tibble [4 × 5]> <tibble [0 × 3]> <tibble [340 × 5]>

构建线性模型

# create a model definition
lin_reg_spec <-
  linear_reg() %>%
  set_engine("lm")

# extend the recipe
lin_reg_rec <-
  tree_frogs_rec %>%
  step_dummy(all_nominal()) %>%
  step_zv(all_predictors(), skip = TRUE)

# add both to a workflow
lin_reg_wflow <- 
  workflow() %>%
  add_model(lin_reg_spec) %>%
  add_recipe(lin_reg_rec)

# fit to the 5-fold cv
set.seed(2020)
lin_reg_res <- 
  fit_resamples(
    lin_reg_wflow,
    resamples = folds,
    metrics = metric,
    control = ctrl_res
  )

lin_reg_res
## # Resampling results
## # 5-fold cross-validation 
## # A tibble: 5 × 5
##   splits           id    .metrics         .notes           .predictions     
##   <list>           <chr> <list>           <list>           <list>           
## 1 <split [343/86]> Fold1 <tibble [1 × 4]> <tibble [0 × 3]> <tibble [86 × 4]>
## 2 <split [343/86]> Fold2 <tibble [1 × 4]> <tibble [0 × 3]> <tibble [86 × 4]>
## 3 <split [343/86]> Fold3 <tibble [1 × 4]> <tibble [0 × 3]> <tibble [86 × 4]>
## 4 <split [343/86]> Fold4 <tibble [1 × 4]> <tibble [0 × 3]> <tibble [86 × 4]>
## 5 <split [344/85]> Fold5 <tibble [1 × 4]> <tibble [0 × 3]> <tibble [85 × 4]>

构建支持向量模型

# create a model definition
svm_spec <- 
  svm_rbf(
    cost = tune("cost"), 
    rbf_sigma = tune("sigma")
  ) %>%
  set_engine("kernlab") %>%
  set_mode("regression")

# extend the recipe
svm_rec <-
  tree_frogs_rec %>%
  step_dummy(all_nominal()) %>%
  step_zv(all_predictors(), skip = TRUE) %>%
  step_meanimpute(all_numeric(), skip = TRUE) %>%
  step_corr(all_predictors(), skip = TRUE) %>%
  step_normalize(all_numeric(), skip = TRUE)

# add both to a workflow
svm_wflow <- 
  workflow() %>% 
  add_model(svm_spec) %>%
  add_recipe(svm_rec)

# tune cost and sigma and fit to the 5-fold cv
set.seed(2020)
svm_res <- 
  tune_grid(
    svm_wflow, 
    resamples = folds, 
    grid = 6,
    metrics = metric,
    control = ctrl_grid
  )
## Warning: package 'kernlab' was built under R version 4.1.2
svm_res
## # Tuning results
## # 5-fold cross-validation 
## # A tibble: 5 × 5
##   splits           id    .metrics         .notes           .predictions      
##   <list>           <chr> <list>           <list>           <list>            
## 1 <split [343/86]> Fold1 <tibble [6 × 6]> <tibble [0 × 3]> <tibble [516 × 6]>
## 2 <split [343/86]> Fold2 <tibble [6 × 6]> <tibble [0 × 3]> <tibble [516 × 6]>
## 3 <split [343/86]> Fold3 <tibble [6 × 6]> <tibble [0 × 3]> <tibble [516 × 6]>
## 4 <split [343/86]> Fold4 <tibble [6 × 6]> <tibble [0 × 3]> <tibble [516 × 6]>
## 5 <split [344/85]> Fold5 <tibble [6 × 6]> <tibble [0 × 3]> <tibble [510 × 6]>

准备将模型堆叠起来。 首先初始化。

stacks()
## # A data stack with 0 model definitions and 0 candidate members.

使用add_candidates()函数将集成成员添加到堆栈中。

tree_frogs_data_st <- 
  stacks() %>%
  add_candidates(knn_res) %>%
  add_candidates(lin_reg_res) %>%
  add_candidates(svm_res)

tree_frogs_data_st
## # A data stack with 3 model definitions and 11 candidate members:
## #   knn_res: 4 model configurations
## #   lin_reg_res: 1 model configuration
## #   svm_res: 6 model configurations
## # Outcome: latency (numeric)

在底层,一个data_stack对象实际上只是一个带有一些额外属性的 tibble。查看实际数据:

as_tibble(tree_frogs_data_st)
## # A tibble: 429 × 12
##    latency knn_res_1_1 knn_res_1_2 knn_res_1_3 knn_res_1_4 lin_reg_res_1_1
##      <dbl>       <dbl>       <dbl>       <dbl>       <dbl>           <dbl>
##  1     142      -0.496      -0.478      -0.492      -0.494           114. 
##  2      79      -0.381      -0.446      -0.542      -0.553            78.6
##  3      50      -0.311      -0.352      -0.431      -0.438            81.5
##  4      68      -0.312      -0.368      -0.463      -0.473            78.6
##  5      64      -0.496      -0.478      -0.492      -0.494            36.5
##  6      52      -0.391      -0.412      -0.473      -0.482           124. 
##  7      39      -0.523      -0.549      -0.581      -0.587            35.2
##  8      46      -0.523      -0.549      -0.581      -0.587            37.1
##  9     137      -0.287      -0.352      -0.447      -0.456            78.8
## 10      73      -0.523      -0.549      -0.581      -0.587            38.8
## # … with 419 more rows, and 6 more variables: svm_res_1_1 <dbl>,
## #   svm_res_1_4 <dbl>, svm_res_1_3 <dbl>, svm_res_1_5 <dbl>, svm_res_1_2 <dbl>,
## #   svm_res_1_6 <dbl>

拟合堆叠stack

tree_frogs_model_st <-
  tree_frogs_data_st %>%
  blend_predictions()

tree_frogs_model_st
## ── A stacked ensemble model ─────────────────────────────────────
## 
## Out of 11 possible candidate members, the ensemble retained 4.
## Penalty: 0.1.
## Mixture: 1.
## 
## The 4 highest weighted members are:
## # A tibble: 4 × 3
##   member          type              weight
##   <chr>           <chr>              <dbl>
## 1 svm_res_1_6     svm_rbf          352.   
## 2 knn_res_1_4     nearest_neighbor  76.8  
## 3 svm_res_1_2     svm_rbf           27.1  
## 4 lin_reg_res_1_1 linear_reg         0.924
## 
## Members have not yet been fitted with `fit_members()`.

该blend_predictions函数通过在数据堆栈上拟合 LASSO 模型,使用每个候选成员的预测来预测真实的评估集结果,从而确定成员模型输出最终将如何组合到最终预测中。

为了确保我们在最小化成员数量和优化性能之间做出正确的权衡,对结果进行数据可视化。

theme_set(theme_bw())
autoplot(tree_frogs_model_st)

我们可以在整个训练集上拟合具有非零堆叠系数的候选者。

tree_frogs_model_st <-
  tree_frogs_model_st %>%
  fit_members()

tree_frogs_model_st
## ── A stacked ensemble model ─────────────────────────────────────
## 
## Out of 11 possible candidate members, the ensemble retained 4.
## Penalty: 0.1.
## Mixture: 1.
## 
## The 4 highest weighted members are:
## # A tibble: 4 × 3
##   member          type              weight
##   <chr>           <chr>              <dbl>
## 1 svm_res_1_6     svm_rbf          352.   
## 2 knn_res_1_4     nearest_neighbor  76.8  
## 3 svm_res_1_2     svm_rbf           27.1  
## 4 lin_reg_res_1_1 linear_reg         0.924

要确定哪些模型配置分配了哪些堆叠系数,可以使用collect_parameters()函数。

collect_parameters(tree_frogs_model_st, "svm_res")
## # A tibble: 6 × 4
##   member         cost    sigma  coef
##   <chr>         <dbl>    <dbl> <dbl>
## 1 svm_res_1_1 0.00143 6.64e- 9   0  
## 2 svm_res_1_2 3.59    3.95e- 4  27.1
## 3 svm_res_1_3 0.0978  1.81e- 2   0  
## 4 svm_res_1_4 0.00849 2.16e-10   0  
## 5 svm_res_1_5 0.256   4.54e- 1   0  
## 6 svm_res_1_6 7.64    3.16e- 7 352.

使用新数据进行预测!

tree_frogs_test <- 
  tree_frogs_test %>%
  bind_cols(predict(tree_frogs_model_st, .))

使用该type = “members”参数从每个集成成员中生成预测。

member_preds <- 
  tree_frogs_test %>%
  select(latency) %>%
  bind_cols(predict(tree_frogs_model_st, tree_frogs_test, members = TRUE))

member_preds
## # A tibble: 143 × 6
##    latency .pred knn_res_1_4 lin_reg_res_1_1 svm_res_1_2 svm_res_1_6
##      <dbl> <dbl>       <dbl>           <dbl>       <dbl>       <dbl>
##  1     180 134.       -0.504           138.       -0.126      -0.327
##  2      39  86.4      -0.448            82.4      -0.126      -0.327
##  3     224 113.       -0.504           116.       -0.126      -0.327
##  4      33  39.0      -0.504            35.8      -0.126      -0.327
##  5      94 108.       -0.504           111.       -0.126      -0.327
##  6      19  41.7      -0.504            38.8      -0.126      -0.327
##  7     126 119.       -0.504           123.       -0.126      -0.327
##  8      79  86.2      -0.448            82.3      -0.126      -0.327
##  9      46  41.7      -0.504            38.7      -0.126      -0.327
## 10      65  83.0      -0.448            78.8      -0.126      -0.327
## # … with 133 more rows

现在,评估每个模型的均方根误差:

map_dfr(member_preds, rmse, truth = latency, data = member_preds) %>%
  mutate(member = colnames(member_preds))
## # A tibble: 6 × 4
##   .metric .estimator .estimate member         
##   <chr>   <chr>          <dbl> <chr>          
## 1 rmse    standard         0   latency        
## 2 rmse    standard        55.3 .pred          
## 3 rmse    standard       114.  knn_res_1_4    
## 4 rmse    standard        55.5 lin_reg_res_1_1
## 5 rmse    standard       114.  svm_res_1_2    
## 6 rmse    standard       114.  svm_res_1_6

正如我们所见,stacked ensemble 优于每个成员模型

分类模型

准备相关的R包。

library(tidymodels)
library(tidyverse)
## ── Attaching packages ─────────────────────────────────────── tidyverse 1.3.1 ──
## ✓ stringr 1.4.0     ✓ forcats 0.5.1
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## x kernlab::alpha()        masks ggplot2::alpha(), scales::alpha()
## x readr::col_factor()     masks scales::col_factor()
## x randomForest::combine() masks dplyr::combine()
## x kernlab::cross()        masks purrr::cross()
## x purrr::discard()        masks scales::discard()
## x Matrix::expand()        masks tidyr::expand()
## x dplyr::filter()         masks stats::filter()
## x stringr::fixed()        masks recipes::fixed()
## x dbplyr::ident()         masks dplyr::ident()
## x dplyr::lag()            masks stats::lag()
## x randomForest::margin()  masks ggplot2::margin()
## x Matrix::pack()          masks tidyr::pack()
## x readr::spec()           masks yardstick::spec()
## x dbplyr::sql()           masks dplyr::sql()
## x Matrix::unpack()        masks tidyr::unpack()
library(stacks)

首先对数据进行简单的处理

data("tree_frogs")

# subset the data
tree_frogs <- tree_frogs %>%
  select(-c(clutch, latency))

首先,拆分训练数据,生成重采样,并设置每个模型定义将使用的一些选项。

# some setup: resampling and a basic recipe
set.seed(1)

tree_frogs_split <- initial_split(tree_frogs)
tree_frogs_train <- training(tree_frogs_split)
tree_frogs_test  <- testing(tree_frogs_split)

folds <- rsample::vfold_cv(tree_frogs_train, v = 5)

tree_frogs_rec <- 
  recipe(reflex ~ ., data = tree_frogs_train) %>%
  step_dummy(all_nominal(), -reflex) %>%
  step_zv(all_predictors())

tree_frogs_wflow <- 
  workflow() %>% 
  add_recipe(tree_frogs_rec)

我们还需要使用与数字响应设置相同的控制设置:

ctrl_grid <- control_stack_grid()

我们将定义两种不同的模型定义来尝试预测reflex——随机森林和神经网络。

开始构建随机森林模型

rand_forest_spec <- 
  rand_forest(
    mtry = tune(),
    min_n = tune(),
    trees = 500
  ) %>%
  set_mode("classification") %>%
  set_engine("ranger")

rand_forest_wflow <-
  tree_frogs_wflow %>%
  add_model(rand_forest_spec)

rand_forest_res <- 
  tune_grid(
    object = rand_forest_wflow, 
    resamples = folds, 
    grid = 10,
    control = ctrl_grid
  )
## i Creating pre-processing data to finalize unknown parameter: mtry

构建神经网络模型

nnet_spec <-
  mlp(hidden_units = tune(), penalty = tune(), epochs = tune()) %>%
  set_mode("classification") %>%
  set_engine("nnet")

nnet_rec <- 
  tree_frogs_rec %>% 
  step_normalize(all_predictors())

nnet_wflow <- 
  tree_frogs_wflow %>%
  add_model(nnet_spec)

nnet_res <-
  tune_grid(
    object = nnet_wflow, 
    resamples = folds, 
    grid = 10,
    control = ctrl_grid
  )

将模型堆叠在一起

tree_frogs_model_st <- 
  # initialize the stack
  stacks() %>%
  # add candidate members
  add_candidates(rand_forest_res) %>%
  add_candidates(nnet_res) %>%
  # determine how to combine their predictions
  blend_predictions() %>%
  # fit the candidates with nonzero stacking coefficients
  fit_members()
## ! Bootstrap01: preprocessor 1/1, model 1/1: from glmnet Fortran code (error code -67); ...
## ! Bootstrap02: preprocessor 1/1, model 1/1: from glmnet Fortran code (error code -67); ...
## ! Bootstrap03: preprocessor 1/1, model 1/1: from glmnet Fortran code (error code -74); ...
## ! Bootstrap04: preprocessor 1/1, model 1/1: from glmnet Fortran code (error code -64); ...
## ! Bootstrap08: preprocessor 1/1, model 1/1: from glmnet Fortran code (error code -56); ...
## ! Bootstrap09: preprocessor 1/1, model 1/1: from glmnet Fortran code (error code -68); ...
## ! Bootstrap10: preprocessor 1/1, model 1/1: from glmnet Fortran code (error code -59); ...
## ! Bootstrap12: preprocessor 1/1, model 1/1: from glmnet Fortran code (error code -57); ...
## ! Bootstrap13: preprocessor 1/1, model 1/1: from glmnet Fortran code (error code -78); ...
## ! Bootstrap14: preprocessor 1/1, model 1/1: from glmnet Fortran code (error code -87); ...
## ! Bootstrap15: preprocessor 1/1, model 1/1: from glmnet Fortran code (error code -63); ...
## ! Bootstrap16: preprocessor 1/1, model 1/1: from glmnet Fortran code (error code -64); ...
## ! Bootstrap18: preprocessor 1/1, model 1/1: from glmnet Fortran code (error code -80); ...
## ! Bootstrap22: preprocessor 1/1, model 1/1: from glmnet Fortran code (error code -62); ...
## ! Bootstrap25: preprocessor 1/1, model 1/1: from glmnet Fortran code (error code -64); ...
tree_frogs_model_st
## ── A stacked ensemble model
## ─────────────────────────────────────
## 
## Out of 40 possible candidate members, the ensemble retained 28.
## Penalty: 1e-05.
## Mixture: 1.
## Across the 3 classes, there are an average of 9.33 coefficients per class.
## 
## The 10 highest weighted member classes are:
## # A tibble: 10 × 4
##    member                         type        weight class
##    <chr>                          <chr>        <dbl> <chr>
##  1 .pred_full_nnet_res_1_05       mlp         1209.  low  
##  2 .pred_full_nnet_res_1_09       mlp          319.  low  
##  3 .pred_mid_nnet_res_1_09        mlp           80.4 mid  
##  4 .pred_full_nnet_res_1_04       mlp           68.0 mid  
##  5 .pred_full_nnet_res_1_04       mlp           67.8 full 
##  6 .pred_mid_rand_forest_res_1_01 rand_forest   46.1 low  
##  7 .pred_mid_rand_forest_res_1_05 rand_forest   32.7 mid  
##  8 .pred_mid_rand_forest_res_1_07 rand_forest   25.3 low  
##  9 .pred_mid_rand_forest_res_1_10 rand_forest   24.3 mid  
## 10 .pred_full_nnet_res_1_07       mlp           22.8 low

为了确保我们在最小化成员数量和优化性能之间做出正确的权衡,我们可以使用autoplot函数

theme_set(theme_bw())
autoplot(tree_frogs_model_st)

为了更直接地显示关系:

autoplot(tree_frogs_model_st, type = "members")

要确定哪些模型配置分配了哪些堆叠系数,我们可以使用以下collect_parameters()函数:

collect_parameters(tree_frogs_model_st, "rand_forest_res")
## # A tibble: 60 × 6
##    member                mtry min_n class terms                             coef
##    <chr>                <int> <int> <chr> <chr>                            <dbl>
##  1 rand_forest_res_1_01     1    26 low   .pred_mid_rand_forest_res_1_01  46.1  
##  2 rand_forest_res_1_01     1    26 low   .pred_full_rand_forest_res_1_01  0    
##  3 rand_forest_res_1_01     1    26 mid   .pred_mid_rand_forest_res_1_01   0    
##  4 rand_forest_res_1_01     1    26 mid   .pred_full_rand_forest_res_1_01  0.406
##  5 rand_forest_res_1_01     1    26 full  .pred_mid_rand_forest_res_1_01   0    
##  6 rand_forest_res_1_01     1    26 full  .pred_full_rand_forest_res_1_01  0    
##  7 rand_forest_res_1_02     2    33 low   .pred_mid_rand_forest_res_1_02   0    
##  8 rand_forest_res_1_02     2    33 low   .pred_full_rand_forest_res_1_02  0    
##  9 rand_forest_res_1_02     2    33 mid   .pred_mid_rand_forest_res_1_02   9.22 
## 10 rand_forest_res_1_02     2    33 mid   .pred_full_rand_forest_res_1_02  5.38 
## # … with 50 more rows

该对象现在已准备好使用新数据进行预测!

tree_frogs_pred <-
  tree_frogs_test %>%
  bind_cols(predict(tree_frogs_model_st, ., type = "prob"))

计算模型的 ROC AUC:

yardstick::roc_auc(
  tree_frogs_pred,
  truth = reflex,
  contains(".pred_")
  )
## # A tibble: 1 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 roc_auc hand_till      0.323

从每个集成成员中生成预测。

tree_frogs_pred <-
  tree_frogs_test %>%
  select(reflex) %>%
  bind_cols(
    predict(
      tree_frogs_model_st,
      tree_frogs_test,
      type = "class",
      members = TRUE
      )
    )

tree_frogs_pred
## # A tibble: 303 × 17
##    reflex .pred_class .pred_class_rand_forest… .pred_class_ran… .pred_class_ran…
##    <fct>  <fct>       <fct>                    <fct>            <fct>           
##  1 full   low         full                     full             full            
##  2 mid    full        low                      low              mid             
##  3 mid    full        mid                      mid              mid             
##  4 mid    full        low                      low              low             
##  5 full   low         full                     full             full            
##  6 full   low         full                     full             full            
##  7 full   low         full                     full             full            
##  8 full   low         full                     full             full            
##  9 full   low         full                     full             full            
## 10 full   low         full                     full             full            
## # … with 293 more rows, and 12 more variables:
## #   .pred_class_rand_forest_res_1_07 <fct>, .pred_class_nnet_res_1_04 <fct>,
## #   .pred_class_nnet_res_1_10 <fct>, .pred_class_nnet_res_1_05 <fct>,
## #   .pred_class_nnet_res_1_08 <fct>, .pred_class_nnet_res_1_07 <fct>,
## #   .pred_class_nnet_res_1_09 <fct>, .pred_class_rand_forest_res_1_10 <fct>,
## #   .pred_class_rand_forest_res_1_02 <fct>,
## #   .pred_class_rand_forest_res_1_09 <fct>, …
map_dfr(
  setNames(colnames(tree_frogs_pred), colnames(tree_frogs_pred)),
  ~mean(tree_frogs_pred$reflex == pull(tree_frogs_pred, .x))
) %>%
  pivot_longer(c(everything(), -reflex))
## # A tibble: 16 × 3
##    reflex name                             value
##     <dbl> <chr>                            <dbl>
##  1      1 .pred_class                      0    
##  2      1 .pred_class_rand_forest_res_1_05 0.835
##  3      1 .pred_class_rand_forest_res_1_01 0.845
##  4      1 .pred_class_rand_forest_res_1_04 0.861
##  5      1 .pred_class_rand_forest_res_1_07 0.875
##  6      1 .pred_class_nnet_res_1_04        0.558
##  7      1 .pred_class_nnet_res_1_10        0.558
##  8      1 .pred_class_nnet_res_1_05        0.558
##  9      1 .pred_class_nnet_res_1_08        0.558
## 10      1 .pred_class_nnet_res_1_07        0.558
## 11      1 .pred_class_nnet_res_1_09        0.558
## 12      1 .pred_class_rand_forest_res_1_10 0.871
## 13      1 .pred_class_rand_forest_res_1_02 0.878
## 14      1 .pred_class_rand_forest_res_1_09 0.884
## 15      1 .pred_class_rand_forest_res_1_06 0.871
## 16      1 .pred_class_rand_forest_res_1_08 0.881

核心函数

1.stacks() 初始化堆栈 2. add_candidates() 将模型定义添加到堆栈 3. blend_predictions() 从堆栈中确定系数 4. fit_members() 拟合堆叠模型

Finetune

finetune包含一些用于模型调整的额外功能。

所有函数

  1. control_race() 网格搜索
  2. control_sim_anneal() 模拟退火搜索过程
  3. plot_race() 绘制race结果
  4. tune_race_anova() 使用 ANOVA 模型通过race进行高效网格搜索
  5. tune_race_win_loss() 通过带有赢/输统计信息的race进行有效的网格搜索
  6. tune_sim_anneal() 通过模拟退火优化模型参数

usemodels

使用这个包可以快速的创建tidymodels 代码

 library(usemodels)
## Warning: package 'usemodels' was built under R version 4.1.2
 library(palmerpenguins)
 data(penguins)
 use_glmnet(body_mass_g ~ ., data = penguins)
## glmnet_recipe <- 
##   recipe(formula = body_mass_g ~ ., data = penguins) %>% 
##   step_novel(all_nominal_predictors()) %>% 
##   step_dummy(all_nominal_predictors()) %>% 
##   step_zv(all_predictors()) %>% 
##   step_normalize(all_numeric_predictors()) 
## 
## glmnet_spec <- 
##   linear_reg(penalty = tune(), mixture = tune()) %>% 
##   set_mode("regression") %>% 
##   set_engine("glmnet") 
## 
## glmnet_workflow <- 
##   workflow() %>% 
##   add_recipe(glmnet_recipe) %>% 
##   add_model(glmnet_spec) 
## 
## glmnet_grid <- tidyr::crossing(penalty = 10^seq(-6, -1, length.out = 20), mixture = c(0.05, 
##     0.2, 0.4, 0.6, 0.8, 1)) 
## 
## glmnet_tune <- 
##   tune_grid(glmnet_workflow, resamples = stop("add your rsample object"), grid = glmnet_grid)

所有函数

  1. use_glmnet()
  2. use_xgboost()
  3. use_kknn()
  4. use_ranger()
  5. use_earth()
  6. use_cubist()
  7. use_kernlab_svm_rbf()
  8. use_kernlab_svm_poly()
  9. use_C5.0()

Functions to create boilerplate code for specific models

probably

THRESHOLDS 阈值

  1. threshold_perf() 跨概率阈值生成性能指标

创建类预测

  1. append_class_pred() 添加一class_pred列
  2. make_class_pred() make_two_class_pred() 从类概率创建class_pred向量

类预测

  1. class_pred() 创建类预测对象
  2. as_class_pred() 强制class_pred对象
  3. is_class_pred() 测试一个对象是否继承自class_pred
  4. reportable_rate()
    Calculate the reportable rate
  5. is_equivocal() which_equivocal() any_equivocal() 找出模棱两可的值

tidyposterior

用重采样和贝叶斯方法在模型之间进行正式的统计比较

重采样结果的贝叶斯分析

  1. perf_mod() 重采样统计的贝叶斯分析
  2. tidy() 提取模型的后验分布
  3. contrast_models() 估计模型之间的差异
  4. summary() 总结模型统计的后验分布
  5. summary() 总结模型差异的后验分布
  6. autoplot() autoplot() autoplot() 可视化模型统计的后验分布
  7. no_trans logit_trans Fisher_trans ln_trans inv_trans 简单的转换函数

shinymodels

探索模型 explore(

)

hardhat

hardhat 是一个以开发人员为中心的包,旨在简化新建模包的创建

butcher

建模管道R偶尔会导致拟合模型对象占用过多内存。 butcher可以轻松删除不再需要的拟合输出部分,而不会牺牲原始模型对象的很多功能。

为了充分利用您的内存,这个包提供了五个 S3 泛型供您删除模型对象的一部分:

  1. axe_call():删除调用对象。
  2. axe_ctrl():删除与培训相关的控制。
  3. axe_data():删除原始训练数据。
  4. axe_env(): 删除环境。
  5. axe_fitted():删除拟合值。

其他函数还包括

  1. weigh() 分析内存
  2. locate() 找到对象的一部分

appliable

appliable可以生成衡量外推的指标