線形回帰

重回帰により, 医療費を予測する

# データ収集
# Read the insurance data. String columns (sex, smoker, region) become
# factors so that lm() treats them as categorical predictors.
insurance <- read.csv("insurance.csv", stringsAsFactors = TRUE)
## 基本統計量
### 要約
# Per-column overview: quartiles/mean for numeric columns,
# level counts for factor columns.
summary(object = insurance)
##       age            sex           bmi           children     smoker    
##  Min.   :18.00   female:662   Min.   :16.00   Min.   :0.000   no :1064  
##  1st Qu.:27.00   male  :676   1st Qu.:26.30   1st Qu.:0.000   yes: 274  
##  Median :39.00                Median :30.40   Median :1.000             
##  Mean   :39.21                Mean   :30.67   Mean   :1.095             
##  3rd Qu.:51.00                3rd Qu.:34.70   3rd Qu.:2.000             
##  Max.   :64.00                Max.   :53.10   Max.   :5.000             
##        region       expenses    
##  northeast:324   Min.   : 1122  
##  northwest:325   1st Qu.: 4740  
##  southeast:364   Median : 9382  
##  southwest:325   Mean   :13270  
##                  3rd Qu.:16640  
##                  Max.   :63770
### ヒストグラム
# Distribution of the target variable `expenses`.
# bins = 30 is the value stat_bin() was already defaulting to; stating it
# explicitly keeps the same plot without the "Pick better value" message.
ggplot(data = insurance, aes(x = expenses)) +
  geom_histogram(
    bins = 30, alpha = 0.8, fill = "lightblue", color = "gray"
  ) +
  xlab("expenses") +
  ggtitle("Histogram of expenses")

# GGally supplies ggpairs(); loading it prints an S3-method override
# notice for `+.gg`.
library("GGally")
## Registered S3 method overwritten by 'GGally':
##   method from   
##   +.gg   ggplot2
## 散布図と相関
# Scatter-plot matrix with pairwise correlations for the numeric columns.
columns <- c("age", "bmi", "children", "expenses")
ggpairs(columns = columns, data = insurance)

## 相関楕円とLoess曲線
# psych supplies pairs.panels(); it masks ggplot2's alpha() and %+%.
library("psych")
## 
## Attaching package: 'psych'
## The following objects are masked from 'package:ggplot2':
## 
##     %+%, alpha
# Scatter-plot matrix with correlation ellipses and loess curves for the
# same numeric columns as the ggpairs() plot. Reuse `columns` instead of
# repeating the literal vector, so both plots cannot drift apart.
pairs.panels(insurance[columns])

  • 相関楕円は楕円が引き伸ばされているほど相関が強い.

  • Loess曲線は2変数の間の大まかな関係(非線形な傾向も含む)を示している. 例えば曲線が右上がりなら, 横軸の変数が大きいほど縦軸の変数も大きい傾向がある.

# モデルを訓練する
# Multiple regression of expenses on every predictor. Factor predictors
# (sex, smoker, region) are dummy-coded against their first level.
lm_ins <- lm(
  expenses ~ age + children + bmi + sex + smoker + region,
  data = insurance
)
lm_ins
## 
## Call:
## lm(formula = expenses ~ age + children + bmi + sex + smoker + 
##     region, data = insurance)
## 
## Coefficients:
##     (Intercept)              age         children              bmi  
##        -11941.6            256.8            475.7            339.3  
##         sexmale        smokeryes  regionnorthwest  regionsoutheast  
##          -131.4          23847.5           -352.8          -1035.6  
## regionsouthwest  
##          -959.3
  • 喫煙者は医療費が高そう(smokeryes の係数は約23,848で, 他の変数が同じなら非喫煙者より約23,848高い).

  • 子供が増えると医療費は上がりそう(子供1人あたり約476増).

  • 肥満も医療費増に影響してそう(BMIが1上がるごとに約339増).

# モデルの性能を評価する
# Coefficient table with standard errors and p-values, plus R-squared.
summary(object = lm_ins)
## 
## Call:
## lm(formula = expenses ~ age + children + bmi + sex + smoker + 
##     region, data = insurance)
## 
## Residuals:
##      Min       1Q   Median       3Q      Max 
## -11302.7  -2850.9   -979.6   1383.9  29981.7 
## 
## Coefficients:
##                 Estimate Std. Error t value Pr(>|t|)    
## (Intercept)     -11941.6      987.8 -12.089  < 2e-16 ***
## age                256.8       11.9  21.586  < 2e-16 ***
## children           475.7      137.8   3.452 0.000574 ***
## bmi                339.3       28.6  11.864  < 2e-16 ***
## sexmale           -131.3      332.9  -0.395 0.693255    
## smokeryes        23847.5      413.1  57.723  < 2e-16 ***
## regionnorthwest   -352.8      476.3  -0.741 0.458976    
## regionsoutheast  -1035.6      478.7  -2.163 0.030685 *  
## regionsouthwest   -959.3      477.9  -2.007 0.044921 *  
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 6062 on 1329 degrees of freedom
## Multiple R-squared:  0.7509, Adjusted R-squared:  0.7494 
## F-statistic: 500.9 on 8 and 1329 DF,  p-value: < 2.2e-16
# モデルの性能を向上させる
insurance <- insurance %>%
  mutate(
    # squared age, so the model can capture a non-linear age effect
    age2 = age^2,
    # obesity indicator: 1 when bmi >= 30, otherwise 0
    bmi30 = if_else(bmi >= 30, 1, 0)
  )

## 相互作用も追加する
# Improved model: adds age2, and bmi30 * smoker, which expands to
# bmi30 + smoker + bmi30:smoker (an extra effect for obese smokers).
lm_ins2 <- lm(
  expenses ~ age + age2 + children + bmi + sex + bmi30 * smoker + region,
  data = insurance
)
summary(object = lm_ins2)
## 
## Call:
## lm(formula = expenses ~ age + age2 + children + bmi + sex + bmi30 * 
##     smoker + region, data = insurance)
## 
## Residuals:
##      Min       1Q   Median       3Q      Max 
## -17297.1  -1656.0  -1262.7   -727.8  24161.6 
## 
## Coefficients:
##                   Estimate Std. Error t value Pr(>|t|)    
## (Intercept)       139.0053  1363.1359   0.102 0.918792    
## age               -32.6181    59.8250  -0.545 0.585690    
## age2                3.7307     0.7463   4.999 6.54e-07 ***
## children          678.6017   105.8855   6.409 2.03e-10 ***
## bmi               119.7715    34.2796   3.494 0.000492 ***
## sexmale          -496.7690   244.3713  -2.033 0.042267 *  
## bmi30            -997.9355   422.9607  -2.359 0.018449 *  
## smokeryes       13404.5952   439.9591  30.468  < 2e-16 ***
## regionnorthwest  -279.1661   349.2826  -0.799 0.424285    
## regionsoutheast  -828.0345   351.6484  -2.355 0.018682 *  
## regionsouthwest -1222.1619   350.5314  -3.487 0.000505 ***
## bmi30:smokeryes 19810.1534   604.6769  32.762  < 2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 4445 on 1326 degrees of freedom
## Multiple R-squared:  0.8664, Adjusted R-squared:  0.8653 
## F-statistic: 781.7 on 11 and 1326 DF,  p-value: < 2.2e-16
## 予測
# Store in-sample predictions from lm_ins2 as a `pred` column, then check
# how closely they track the observed expenses.
insurance <- mutate(insurance, pred = predict(lm_ins2))

cor(x = insurance$pred, y = insurance$expenses)
## [1] 0.9307999

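訓練データにおける予測値と実測値の相関は約0.93と高く, 当てはまりはかなり良い(0.93² ≈ 0.866 は lm_ins2 の決定係数 R² と一致する). ただし同じデータで学習と評価をしているため, 新しいデータでの性能は別途検証が必要である.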

# Observed vs. in-sample predicted expenses.
# The smoother formula is given explicitly, so geom_smooth() no longer
# prints its "using formula" message. For an OLS fit, regressing y on its
# own fitted values gives slope 1 and intercept 0, so this line is y = x.
ggplot(data = insurance, aes(x = pred, y = expenses)) +
  geom_point() +
  geom_smooth(method = "lm", formula = y ~ x, se = FALSE)

## 新しい加入者の医療費を予測する
# Grid of hypothetical new members: every combination of the values below,
# plus the derived age2 / bmi30 columns that lm_ins2 needs, computed the
# same way as for the training data.
# NOTE(review): ages 10 and 70-80 lie outside the training range (18-64 in
# summary(insurance)), and bmi 15-55 slightly exceeds 16.0-53.1. Predictions
# there are extrapolations, amplified by the age^2 term; consider
# restricting the grid.
new_dat <- expand.grid(
  age      = seq(10, 80, by = 10),
  children = seq(0, 5, by = 1),
  bmi      = seq(from = 15, to = 55, length.out = 40),
  sex      = c("female", "male"),
  smoker   = c("no", "yes"),
  region   = c("northwest", "northeast", "southwest", "southeast")
)
new_dat <- new_dat %>%
  mutate(age2 = age^2, bmi30 = if_else(bmi >= 30, 1, 0))

## 各モデルで予測値を入れた
pred <- predict(lm_ins2, newdata = new_dat)
# Name the prediction column directly instead of bind_cols() + rename(9).
# Binding an unnamed vector triggers a "New names" repair message, and the
# positional rename breaks silently if new_dat gains or loses a column.
new_dat_pred <- new_dat %>%
  mutate(prediction = pred)
head(new_dat_pred)
##   age children bmi    sex smoker    region age2 bmi30 prediction
## 1  10        0  15 female     no northwest  100     0   1703.305
## 2  20        0  15 female     no northwest  400     0   2496.347
## 3  30        0  15 female     no northwest  900     0   4035.537
## 4  40        0  15 female     no northwest 1600     0   6320.876
## 5  50        0  15 female     no northwest 2500     0   9352.364
## 6  60        0  15 female     no northwest 3600     0  13130.000

回帰木とモデル木を理解する

専門家によるワイン評価を回帰木とモデル木で模倣できるシステムを作る(ここでは白ワインのみ).
データには酸度や糖度などの測定値と, 評価(quality, 0~10)が格納されている.

# データ収集と前処理
# White-wine data: measurements plus the expert `quality` score.
wine <- read.csv("whitewines.csv")
# quality only takes whole-number scores, so plot one bar per score
# (geom_bar) rather than a 30-bin histogram of a discrete variable.
ggplot(data = wine, aes(x = quality)) +
  geom_bar(fill = "blue", alpha = 0.5) +
  scale_x_continuous(breaks = seq(3, 9, 1))

set.seed(1216)

# Random hold-out split. Indices are drawn from nrow(wine) instead of a
# hard-coded 4898, so the split follows the data actually read; with 4898
# rows this draws exactly the same sample as before.
n_train <- 3750
train_sample <- sample(nrow(wine), n_train)
# training set: 3750 rows (about 77% of 4898)
wine_train <- wine[train_sample, ]
# test set: the remaining rows (about 23%)
wine_test <- wine[-train_sample, ]
# モデルを訓練する
## 回帰木モデル
library(rpart)
# Regression tree predicting quality from all other columns.
m.rpart <- rpart(quality ~ ., data = wine_train)
print(m.rpart)
## n= 3750 
## 
## node), split, n, deviance, yval
##       * denotes terminal node
## 
##  1) root 3750 2979.43000 5.881600  
##    2) alcohol< 10.85 2354 1420.22300 5.609601  
##      4) volatile.acidity>=0.2575 1215  590.97450 5.357202 *
##      5) volatile.acidity< 0.2575 1139  669.28010 5.878841  
##       10) volatile.acidity>=0.2075 574  270.05750 5.702091 *
##       11) volatile.acidity< 0.2075 565  363.07260 6.058407  
##         22) density< 0.99788 467  258.55250 5.944325 *
##         23) density>=0.99788 98   69.47959 6.602041 *
##    3) alcohol>=10.85 1396 1091.37800 6.340258  
##      6) free.sulfur.dioxide< 11.5 85   92.58824 5.411765 *
##      7) free.sulfur.dioxide>=11.5 1311  920.75970 6.400458  
##       14) alcohol< 11.74167 634  448.95900 6.198738  
##         28) volatile.acidity>=0.48 8    4.87500 4.125000 *
##         29) volatile.acidity< 0.48 626  409.24120 6.225240 *
##       15) alcohol>=11.74167 677  421.84340 6.589365 *

品質にとって一番重要な指標はアルコール度数だとわかる(最初の分岐が alcohol であるため).
*がついているノードは葉ノード(終端ノード)であり, そこで分岐が終わることを示す. 例えば alcohol < 10.85 かつ 0.2075 <= volatile.acidity < 0.2575 だと(ノード10), quality の予測値は5.702091となる.

## 回帰木の可視化
library(rpart.plot)
# digits = 4 controls how many digits are shown in the node labels
rpart.plot(m.rpart, digits = 4)

# Alternative layout of the same tree. Per the rpart.plot documentation,
# fallen.leaves aligns the leaves at the bottom, and type / extra choose
# the label style and extra node information.
rpart.plot(
  m.rpart,
  digits = 4,
  fallen.leaves = TRUE,
  type = 3,
  extra = 101
)

# モデルの性能を評価する
# Regression-tree predictions for the held-out wines, and their spread.
p.rpart <- predict(m.rpart, newdata = wine_test)
summary(object = p.rpart)
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##   4.125   5.357   5.702   5.860   6.225   6.602
# Observed test-set quality, for comparison with the predictions above.
summary(wine_test[["quality"]])
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##   3.000   5.000   6.000   5.866   6.000   9.000
# Correlation between tree predictions and observed quality.
cor(p.rpart, wine_test[["quality"]])
## [1] 0.5299503

予測値の範囲(4.125~6.602)が実測値の範囲(3~9)より相当狭い. つまり, 評価が極端に高い, あるいは低いワインをうまく予測できていなさそう. また, 相関係数は予測値と実測値が連動しているかを示すだけで, 予測値が実測値からどの程度離れているかはわからない.

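そこで, 予測値が正解からどの程度離れているのかを平均絶対誤差(MAE)を使って調べる. 実測値を \(y_i\), 予測値を \(\hat{y}_i\), 誤差を \(e_i = y_i - \hat{y}_i\) とすると\[\mathrm{MAE} = \frac{1}{n}\sum_{i=1}^{n}\lvert e_i \rvert\]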

# Mean absolute error between observed values and predictions:
#   MAE = mean(|actual - predict|)
#
# actual  : numeric vector of observed values.
# predict : numeric vector of predictions. Either argument may be length 1
#           (e.g. a constant baseline such as the training mean), which is
#           recycled against the other. The name shadows stats::predict()
#           only inside this function.
# na.rm   : drop missing values before averaging. Defaults to FALSE, so
#           missing values propagate as NA.
# Returns a single number. Errors on incompatible lengths instead of
# silently recycling a shorter vector.
MAE <- function(actual, predict, na.rm = FALSE) {
  stopifnot(is.numeric(actual), is.numeric(predict))
  n_actual <- length(actual)
  n_predict <- length(predict)
  if (n_actual != n_predict && n_actual != 1L && n_predict != 1L) {
    stop(
      "`actual` and `predict` must have the same length ",
      "(or one of them length 1).",
      call. = FALSE
    )
  }
  mean(abs(actual - predict), na.rm = na.rm)
}

# Observed values first, predictions second, matching MAE(actual, predict).
MAE(wine_test$quality, p.rpart)
## [1] 0.5986939

予測値と正解値の差は平均0.6くらいで, それほど精度は悪くなさそうに見える. しかし多くのスコアは5~6に集中しているので, 常に平均値を予測するだけのモデルでもMAEはかなり小さくなるはずである. そこで, 訓練データの平均値を常に予測した場合のMAEと比較する.

# Training-set mean quality, used below as a constant baseline prediction.
mean(wine_train[["quality"]])
## [1] 5.8816
# Baseline that always predicts the training mean. The mean is computed
# rather than pasted from printed output, and arguments follow
# MAE(actual, predict).
MAE(wine_test$quality, mean(wine_train$quality))
## [1] 0.6464948

常に平均値を予測する場合(MAE 0.646)よりは回帰木(MAE 0.599)の方が良いが, 差は小さく, まだ改善の余地はありそう(MAEをもっと0に近づけられるか?).

# モデルの性能を向上させる
## モデル木
# Cubist supplies cubist() for model trees; loading it also attaches lattice.
library("Cubist")
## Loading required package: lattice
# cubist() takes predictors (x) and target (y) separately. Predictors are
# selected by name rather than by column position 12, so a different
# column order cannot drop a real predictor and leak `quality` into x.
m.cubist <- cubist(
  x = wine_train[setdiff(names(wine_train), "quality")],
  y = wine_train$quality
)
m.cubist
## 
## Call:
## cubist.default(x = wine_train[-12], y = wine_train$quality)
## 
## Number of samples: 3750 
## Number of predictors: 11 
## 
## Number of committees: 1 
## Number of rules: 16

16個のルールを作成したとわかる

## 決定ルール(線形モデルによる予測)
# Prints each rule's conditions and linear model, then the training-set
# error and attribute usage.
summary(object = m.cubist)
## 
## Call:
## cubist.default(x = wine_train[-12], y = wine_train$quality)
## 
## 
## Cubist [Release 2.07 GPL Edition]  Fri Mar 26 01:36:13 2021
## ---------------------------------
## 
##     Target attribute `outcome'
## 
## Read 3750 cases (12 attributes) from undefined.data
## 
## Model:
## 
##   Rule 1: [47 cases, mean 5.2, range 3 to 7, est err 0.6]
## 
##     if
##  volatile.acidity > 0.155
##  total.sulfur.dioxide > 182
##  sulphates > 0.64
##     then
##  outcome = 152.4 - 155 density + 0.07 residual.sugar + 0.202 alcohol
##            + 1.04 pH - 1.27 volatile.acidity + 0.006 free.sulfur.dioxide
##            + 0.11 fixed.acidity + 0.76 sulphates
##            - 0.0018 total.sulfur.dioxide + 0.15 citric.acid
## 
##   Rule 2: [452 cases, mean 5.2, range 3 to 7, est err 0.4]
## 
##     if
##  fixed.acidity <= 8.4
##  volatile.acidity > 0.305
##  pH <= 3.24
##  sulphates <= 0.64
##  alcohol <= 10.8
##     then
##  outcome = 13.9 + 0.303 alcohol + 0.0026 total.sulfur.dioxide
##            + 0.1 fixed.acidity - 13 density + 0.006 residual.sugar
##            + 0.13 sulphates - 0.13 volatile.acidity
## 
##   Rule 3: [58 cases, mean 5.3, range 3 to 7, est err 0.6]
## 
##     if
##  fixed.acidity > 7.4
##  volatile.acidity > 0.205
##  residual.sugar <= 17.85
##  alcohol <= 9.1
##     then
##  outcome = 126.3 - 3.83 volatile.acidity - 124 density
##            + 0.42 fixed.acidity + 0.064 residual.sugar - 0.75 citric.acid
##            + 0.012 alcohol
## 
##   Rule 4: [215 cases, mean 5.3, range 3 to 7, est err 0.5]
## 
##     if
##  fixed.acidity <= 7.4
##  volatile.acidity > 0.205
##  density <= 0.99848
##  alcohol <= 9.1
##     then
##  outcome = 300.7 - 305 density + 0.503 alcohol + 0.117 residual.sugar
##            + 0.4 fixed.acidity + 1.26 sulphates - 1.35 volatile.acidity
## 
##   Rule 5: [63 cases, mean 5.3, range 5 to 6, est err 0.3]
## 
##     if
##  volatile.acidity > 0.205
##  residual.sugar > 17.85
##     then
##  outcome = 4.1 + 0.048 residual.sugar
## 
##   Rule 6: [124 cases, mean 5.3, range 3 to 9, est err 0.7]
## 
##     if
##  fixed.acidity > 8.4
##  volatile.acidity > 0.155
##  total.sulfur.dioxide <= 208
##  alcohol > 9.1
##     then
##  outcome = 3.6 - 2.14 volatile.acidity + 0.141 alcohol + 1.07 sulphates
##            + 0.0044 free.sulfur.dioxide - 0.0017 total.sulfur.dioxide
##            + 0.06 fixed.acidity
## 
##   Rule 7: [162 cases, mean 5.5, range 3 to 7, est err 0.5]
## 
##     if
##  total.sulfur.dioxide > 208
##  alcohol > 9.1
##     then
##  outcome = 37.8 - 0.0061 total.sulfur.dioxide - 2.51 volatile.acidity
##            - 31 density - 0.0042 free.sulfur.dioxide
##            + 0.013 residual.sugar + 0.17 pH + 0.03 fixed.acidity
##            + 0.017 alcohol + 0.14 sulphates
## 
##   Rule 8: [460 cases, mean 5.7, range 3 to 8, est err 0.5]
## 
##     if
##  fixed.acidity <= 8.4
##  volatile.acidity <= 0.305
##  free.sulfur.dioxide > 30
##  total.sulfur.dioxide <= 208
##  pH <= 3.24
##  sulphates <= 0.64
##  alcohol > 9.1
##  alcohol <= 10.8
##     then
##  outcome = 3.9 - 2.24 volatile.acidity - 0.0046 total.sulfur.dioxide
##            + 1.56 sulphates + 0.104 alcohol + 0.0075 free.sulfur.dioxide
##            + 0.14 fixed.acidity
## 
##   Rule 9: [1454 cases, mean 5.8, range 3 to 9, est err 0.6]
## 
##     if
##  free.sulfur.dioxide <= 30
##  alcohol > 9.1
##     then
##  outcome = 118.9 + 0.0305 free.sulfur.dioxide + 0.072 residual.sugar
##            - 120 density + 0.243 alcohol - 1.66 volatile.acidity
##            + 0.83 pH + 0.43 citric.acid + 0.36 sulphates
## 
##   Rule 10: [69 cases, mean 5.9, range 5 to 7, est err 0.6]
## 
##     if
##  fixed.acidity <= 7.4
##  volatile.acidity > 0.205
##  residual.sugar <= 17.85
##  density > 0.99848
##  alcohol <= 9.1
##     then
##  outcome = 69.8 + 0.322 alcohol - 2.74 volatile.acidity
##            + 0.051 residual.sugar - 68 density + 0.19 fixed.acidity
##            - 0.41 citric.acid
## 
##   Rule 11: [475 cases, mean 6.2, range 4 to 8, est err 0.6]
## 
##     if
##  free.sulfur.dioxide > 30
##  total.sulfur.dioxide <= 208
##  pH > 3.24
##  sulphates <= 0.64
##  alcohol > 9.1
##     then
##  outcome = 459.5 - 463 density + 0.164 residual.sugar
##            + 0.48 fixed.acidity - 0.228 alcohol + 1.44 pH
##            + 0.98 sulphates - 0.26 volatile.acidity
##            - 0.0003 total.sulfur.dioxide
## 
##   Rule 12: [192 cases, mean 6.3, range 4 to 8, est err 0.6]
## 
##     if
##  volatile.acidity <= 0.155
##  alcohol > 9.1
##     then
##  outcome = 280.9 - 283 density + 0.104 residual.sugar
##            + 0.23 fixed.acidity + 1.25 pH - 1.15 volatile.acidity
##            + 0.85 sulphates
## 
##   Rule 13: [118 cases, mean 6.3, range 5 to 8, est err 0.5]
## 
##     if
##  volatile.acidity > 0.155
##  free.sulfur.dioxide > 30
##  total.sulfur.dioxide <= 182
##  sulphates > 0.64
##  alcohol > 9.1
##     then
##  outcome = 22.2 + 2.46 sulphates + 0.0143 free.sulfur.dioxide
##            - 10.8 chlorides - 0.0043 total.sulfur.dioxide
##            - 1.44 citric.acid + 0.138 alcohol - 19 density
##            + 0.007 residual.sugar + 0.11 pH + 0.02 fixed.acidity
##            - 0.09 volatile.acidity
## 
##   Rule 14: [29 cases, mean 6.3, range 5 to 8, est err 0.5]
## 
##     if
##  volatile.acidity <= 0.205
##  density > 0.9992
##  alcohol <= 9.1
##     then
##  outcome = 42.4 - 1.897 alcohol - 0.0299 total.sulfur.dioxide
##            + 12.29 volatile.acidity + 0.096 residual.sugar
##            - 0.0248 free.sulfur.dioxide - 17 density + 0.04 fixed.acidity
##            - 0.14 citric.acid + 0.7 chlorides
## 
##   Rule 15: [388 cases, mean 6.4, range 3 to 8, est err 0.6]
## 
##     if
##  free.sulfur.dioxide > 30
##  total.sulfur.dioxide <= 208
##  pH <= 3.24
##  sulphates <= 0.64
##  alcohol > 10.8
##     then
##  outcome = 121.8 + 0.072 residual.sugar - 120 density + 0.226 alcohol
##            + 1.66 sulphates - 7.3 chlorides
## 
##   Rule 16: [73 cases, mean 6.6, range 5 to 8, est err 0.4]
## 
##     if
##  volatile.acidity <= 0.205
##  density <= 0.9992
##  alcohol <= 9.1
##     then
##  outcome = 8.5 + 10.38 volatile.acidity + 26.7 chlorides
##            + 0.11 residual.sugar - 3.08 citric.acid + 0.22 fixed.acidity
##            - 7 density
## 
## 
## Evaluation on training data (3750 cases):
## 
##     Average  |error|                0.4
##     Relative |error|               0.66
##     Correlation coefficient        0.67
## 
## 
##  Attribute usage:
##    Conds  Model
## 
##     97%    93%    alcohol
##     66%    55%    free.sulfur.dioxide
##     44%    93%    sulphates
##     43%    90%    volatile.acidity
##     41%    56%    pH
##     41%    43%    total.sulfur.dioxide
##     31%    56%    fixed.acidity
##      9%    85%    density
##      4%    87%    residual.sugar
##            42%    citric.acid
##            14%    chlorides
## 
## 
## Time: 0.3 secs

それぞれの条件(if)を満たすデータに対して, 別々の線形回帰式(then)を推定している. 例えば Rule 1 は, volatile.acidity > 0.155, total.sulfur.dioxide > 182, sulphates > 0.64 を満たすデータに対する回帰式である.

## 結果の予測
# Model-tree predictions for the held-out wines, and their spread.
p.cubist <- predict(m.cubist, newdata = wine_test)
summary(object = p.cubist)
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##   4.178   5.372   5.830   5.807   6.232   7.607
# Correlation between model-tree predictions and observed quality.
cor(p.cubist, wine_test[["quality"]])
## [1] 0.6121094
# Observed values first, predictions second, matching MAE(actual, predict).
MAE(wine_test$quality, p.cubist)
## [1] 0.5243365

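回帰木と比べて, 相関は0.530から0.612に上がり, MAEは0.599から0.524に下がったので, モデル木の方が性能が上がったといえる. 予測値の範囲(4.178~7.607)も回帰木(4.125~6.602)より広がっている.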