Abstract
重回帰分析により, 医療費を予測する.
# データ収集
# Load the insurance data set. String columns are converted to factors so the
# categorical columns can be used directly in the regression models.
# TRUE is spelled out because the shorthand T is an ordinary variable that can
# be reassigned.
insurance <- read.csv("insurance.csv", stringsAsFactors = TRUE)
## 基本統計量
### 要約
# Print per-column summary statistics of the insurance data.
summary(insurance)
## age sex bmi children smoker
## Min. :18.00 female:662 Min. :16.00 Min. :0.000 no :1064
## 1st Qu.:27.00 male :676 1st Qu.:26.30 1st Qu.:0.000 yes: 274
## Median :39.00 Median :30.40 Median :1.000
## Mean :39.21 Mean :30.67 Mean :1.095
## 3rd Qu.:51.00 3rd Qu.:34.70 3rd Qu.:2.000
## Max. :64.00 Max. :53.10 Max. :5.000
## region expenses
## northeast:324 Min. : 1122
## northwest:325 1st Qu.: 4740
## southeast:364 Median : 9382
## southwest:325 Mean :13270
## 3rd Qu.:16640
## Max. :63770
### ヒストグラム
# Histogram of the response variable (expenses).
ggplot(data = insurance, aes(x = expenses)) +
  geom_histogram(alpha = 0.8, fill = "lightblue", color = "gray") +
  xlab("expenses") +
  ggtitle("Histogram of expenses")
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
library(GGally)
## Registered S3 method overwritten by 'GGally':
## method from
## +.gg ggplot2
## 散布図と相関
# Numeric columns to compare pairwise.
columns <- c("age", "bmi", "children", "expenses")
# Scatter-plot matrix with pairwise correlations for the selected columns.
ggpairs(insurance, columns = columns)
## 相関楕円とLoess曲線
library(psych)
##
## Attaching package: 'psych'
## The following objects are masked from 'package:ggplot2':
##
## %+%, alpha
# Scatter-plot matrix of the numeric columns drawn with psych::pairs.panels.
pairs.panels(
  insurance[, c("age", "bmi", "children", "expenses")]
)
相関楕円は楕円が引き伸ばされているほど相関が強い.
Loess曲線は2変数間の全体的な関係(非線形な関係を含む)を平滑化して示しており, どこにピークがあるかも読み取れる.
# モデルを訓練する
# Fit a multiple linear regression of expenses on the six predictors
# age, children, bmi, sex, smoker and region.
lm_ins <- lm(formula = expenses ~ age + children + bmi + sex + smoker + region,
data = insurance)
# Print the call and the estimated coefficients.
lm_ins
##
## Call:
## lm(formula = expenses ~ age + children + bmi + sex + smoker +
## region, data = insurance)
##
## Coefficients:
## (Intercept) age children bmi
## -11941.6 256.8 475.7 339.3
## sexmale smokeryes regionnorthwest regionsoutheast
## -131.4 23847.5 -352.8 -1035.6
## regionsouthwest
## -959.3
喫煙者は医療費が高そう.
子供が増えると医療費は上がりそう.
肥満も医療費増に影響してそう.
# モデルの性能を評価する
# Residual quantiles, coefficient table and fit statistics of the first model.
summary(lm_ins)
##
## Call:
## lm(formula = expenses ~ age + children + bmi + sex + smoker +
## region, data = insurance)
##
## Residuals:
## Min 1Q Median 3Q Max
## -11302.7 -2850.9 -979.6 1383.9 29981.7
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) -11941.6 987.8 -12.089 < 2e-16 ***
## age 256.8 11.9 21.586 < 2e-16 ***
## children 475.7 137.8 3.452 0.000574 ***
## bmi 339.3 28.6 11.864 < 2e-16 ***
## sexmale -131.3 332.9 -0.395 0.693255
## smokeryes 23847.5 413.1 57.723 < 2e-16 ***
## regionnorthwest -352.8 476.3 -0.741 0.458976
## regionsoutheast -1035.6 478.7 -2.163 0.030685 *
## regionsouthwest -959.3 477.9 -2.007 0.044921 *
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 6062 on 1329 degrees of freedom
## Multiple R-squared: 0.7509, Adjusted R-squared: 0.7494
## F-statistic: 500.9 on 8 and 1329 DF, p-value: < 2.2e-16
# モデルの性能を向上させる
# Add two derived predictors to the insurance data:
#   age2  - the square of age
#   bmi30 - 1 when bmi >= 30, otherwise 0
insurance <- mutate(
  insurance,
  age2 = age^2,
  bmi30 = if_else(bmi >= 30, 1, 0)
)
## 相互作用も追加する
# Refit with the squared age term and the bmi30 indicator.
# bmi30*smoker expands to both main effects plus the bmi30:smoker interaction
# (all three appear in the coefficient table printed by summary()).
lm_ins2 <- lm(formula = expenses ~ age + age2 + children + bmi + sex +
bmi30*smoker + region,
data = insurance)
# Coefficient table and fit statistics of the extended model.
summary(lm_ins2)
##
## Call:
## lm(formula = expenses ~ age + age2 + children + bmi + sex + bmi30 *
## smoker + region, data = insurance)
##
## Residuals:
## Min 1Q Median 3Q Max
## -17297.1 -1656.0 -1262.7 -727.8 24161.6
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 139.0053 1363.1359 0.102 0.918792
## age -32.6181 59.8250 -0.545 0.585690
## age2 3.7307 0.7463 4.999 6.54e-07 ***
## children 678.6017 105.8855 6.409 2.03e-10 ***
## bmi 119.7715 34.2796 3.494 0.000492 ***
## sexmale -496.7690 244.3713 -2.033 0.042267 *
## bmi30 -997.9355 422.9607 -2.359 0.018449 *
## smokeryes 13404.5952 439.9591 30.468 < 2e-16 ***
## regionnorthwest -279.1661 349.2826 -0.799 0.424285
## regionsoutheast -828.0345 351.6484 -2.355 0.018682 *
## regionsouthwest -1222.1619 350.5314 -3.487 0.000505 ***
## bmi30:smokeryes 19810.1534 604.6769 32.762 < 2e-16 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 4445 on 1326 degrees of freedom
## Multiple R-squared: 0.8664, Adjusted R-squared: 0.8653
## F-statistic: 781.7 on 11 and 1326 DF, p-value: < 2.2e-16
## 予測
# Store the fitted values of the extended model next to the observations.
insurance$pred <- predict(lm_ins2)
# Correlation between fitted and observed expenses.
cor(insurance$pred, insurance$expenses)
## [1] 0.9307999
予測値と実測値の相関係数は約0.93と高く, モデルの当てはまりはかなり良いと思われる.
# Observed expenses against fitted values, with a least-squares line
# (no confidence band).
ggplot(insurance, aes(x = pred, y = expenses)) +
  geom_point() +
  geom_smooth(method = "lm", se = FALSE)
## `geom_smooth()` using formula 'y ~ x'
## 新しい加入者の医療費を予測する
# Build a grid of hypothetical new subscribers covering every combination of
# the predictor values below, then add the derived columns used by lm_ins2.
# NOTE(review): the grid reaches beyond the ranges reported by
# summary(insurance) (age 18-64, bmi 16-53.1), so part of the predictions are
# extrapolations.
new_dat <- expand.grid(
  age = seq(from = 10, to = 80, by = 10),
  children = seq(from = 0, to = 5, by = 1),
  bmi = seq(from = 15, to = 55, length.out = 40),
  sex = c("female", "male"),
  smoker = c("no", "yes"),
  region = c("northwest", "northeast", "southwest", "southeast")
) %>%
  mutate(
    age2 = age^2,
    bmi30 = if_else(bmi >= 30, 1, 0)
  )
## 各モデルで予測値を入れた
# Predict expenses for every row of the grid with the extended model.
pred <- predict(lm_ins2, new_dat)
# Attach the predictions as an explicitly named column. Binding the bare
# vector with bind_cols() gives it an auto-generated name (with a "New names"
# message) that then has to be renamed by column position, which breaks as
# soon as the number of columns in new_dat changes.
new_dat_pred <- new_dat %>%
  mutate(prediction = unname(pred))
## New names:
## * NA -> ...9
# Show the first rows of the grid together with their predicted expenses.
head(new_dat_pred)
## age children bmi sex smoker region age2 bmi30 prediction
## 1 10 0 15 female no northwest 100 0 1703.305
## 2 20 0 15 female no northwest 400 0 2496.347
## 3 30 0 15 female no northwest 900 0 4035.537
## 4 40 0 15 female no northwest 1600 0 6320.876
## 5 50 0 15 female no northwest 2500 0 9352.364
## 6 60 0 15 female no northwest 3600 0 13130.000
専門家によるワイン評価を回帰木とモデル木で模倣できるシステムを作る(ここでは白ワインのみ).
データには酸度や糖度, 評価(0~10)などが格納されている
# データ収集と前処理
# Load the white-wine data and plot the distribution of the quality scores,
# with one axis tick per integer score from 3 to 9.
wine <- read.csv("whitewines.csv")
ggplot(wine, aes(x = quality)) +
  geom_histogram(fill = "blue", alpha = 0.5) +
  scale_x_continuous(breaks = seq(3, 9, 1))
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
# Fix the RNG seed so the train/test split is reproducible.
set.seed(1216)
# Draw 3750 row indices for training from the rows actually present in `wine`
# instead of assuming the data set has exactly 4898 rows.
train_sample <- sample(nrow(wine), 3750)
# Training data: the sampled rows (roughly three quarters of the data)
wine_train <- wine[train_sample, ]
# Test data: all remaining rows
wine_test <- wine[-train_sample, ]
# モデルを訓練する
## 回帰木モデル
library(rpart)
# Grow a regression tree that predicts quality from all other columns.
m.rpart <- rpart(quality ~ ., data = wine_train)
# Print the fitted tree as text.
m.rpart
## n= 3750
##
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 3750 2979.43000 5.881600
## 2) alcohol< 10.85 2354 1420.22300 5.609601
## 4) volatile.acidity>=0.2575 1215 590.97450 5.357202 *
## 5) volatile.acidity< 0.2575 1139 669.28010 5.878841
## 10) volatile.acidity>=0.2075 574 270.05750 5.702091 *
## 11) volatile.acidity< 0.2075 565 363.07260 6.058407
## 22) density< 0.99788 467 258.55250 5.944325 *
## 23) density>=0.99788 98 69.47959 6.602041 *
## 3) alcohol>=10.85 1396 1091.37800 6.340258
## 6) free.sulfur.dioxide< 11.5 85 92.58824 5.411765 *
## 7) free.sulfur.dioxide>=11.5 1311 920.75970 6.400458
## 14) alcohol< 11.74167 634 448.95900 6.198738
## 28) volatile.acidity>=0.48 8 4.87500 4.125000 *
## 29) volatile.acidity< 0.48 626 409.24120 6.225240 *
## 15) alcohol>=11.74167 677 421.84340 6.589365 *
品質にとって, 一番重要な指標はアルコール度数とわかる
*がついているノードは葉ノードであり, そこで分岐が終わっていることを示す. 例えばアルコール<10.85かつ0.2075<=volatile.acidity<0.2575だと, qualityの予測値は5.702091とわかる.
## 回帰木の可視化
library(rpart.plot)
## digits sets the number of digits shown for the numbers in the plot
rpart.plot(m.rpart, digits = 4)
## Variation of the figure with different layout and labelling options
rpart.plot(
  m.rpart,
  type = 3,
  extra = 101,
  digits = 4,
  fallen.leaves = TRUE
)
# モデルの性能を評価する
# Predict quality for the held-out test data with the regression tree.
p.rpart <- predict(m.rpart, newdata = wine_test)
# Distribution of the predicted scores ...
summary(p.rpart)
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 4.125 5.357 5.702 5.860 6.225 6.602
# ... compared with the distribution of the true scores.
summary(wine_test$quality)
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 3.000 5.000 6.000 5.866 6.000 9.000
# Correlation between predicted and true quality.
cor(p.rpart, wine_test$quality)
## [1] 0.5299503
予測値の範囲が実際の値より相当狭い→最も評価が高いなど極端なケースをうまくカバーできていなさそう. 相関では予測値が実測値からどの程度離れているかはわからない
予測値が正解からどの程度離れているのかを, 平均絶対誤差(MAE)を使って調べる.\[MAE = \frac{1}{n}\sum_{i=1}^n|e_i|\] ここで \(n\) はデータ数, \(e_i\) は \(i\) 番目のデータの予測誤差(実測値と予測値の差)である.
# Mean absolute error between observed and predicted values.
#
# Args:
#   actual:  numeric vector of observed values (or a single value).
#   predict: numeric vector of predicted values (or a single value).
#   na.rm:   if TRUE, pairs whose difference is NA are ignored
#            (default FALSE, i.e. any NA makes the result NA).
#
# Returns:
#   The mean of |actual - predict| as a single number.
#
# A length-1 argument is recycled against the other one (e.g. to score a
# constant prediction). Any other length mismatch is an error, because R would
# otherwise recycle the shorter vector silently and return a meaningless value.
MAE <- function(actual, predict, na.rm = FALSE) {
  n_actual <- length(actual)
  n_predict <- length(predict)
  if (n_actual != n_predict && n_actual != 1L && n_predict != 1L) {
    stop(
      "`actual` and `predict` must have the same length, ",
      "or one of them must have length 1",
      call. = FALSE
    )
  }
  mean(abs(actual - predict), na.rm = na.rm)
}
# Mean absolute error of the regression tree on the test data. Arguments are
# passed by name so the true scores go to `actual` and the predictions to
# `predict`; the value is unchanged because the absolute difference is
# symmetric.
MAE(actual = wine_test$quality, predict = p.rpart)
## [1] 0.5986939
予測値と正解値の差は平均で0.6くらいでそれほど精度は悪くなさそうだが, 一般的なスコアが5~6なので, 平均値しか予測しないモデルでもかなり性能がいいことになる.
# Baseline: always predict the mean quality of the training data.
mean(wine_train$quality)
## [1] 5.8816
# Score the baseline with the computed mean rather than a hand-copied rounded
# constant, so it stays correct if the seed or the split changes.
MAE(actual = wine_test$quality, predict = mean(wine_train$quality))
## [1] 0.6464948
平均値だけを返す予測よりはマシだが, まだ改善の余地はありそう(もっと0に近づけられる?)
# モデルの性能を向上させる
## モデル木
library(Cubist)
## Loading required package: lattice
### 独立変数と従属変数を指定
# Fit a Cubist model tree: the predictors are all columns of the training data
# except column 12, the response is quality.
# NOTE(review): presumably column 12 is `quality` (the printed model reports
# 11 predictors) - confirm against the column order of whitewines.csv.
m.cubist <- cubist(x = wine_train[-12], y = wine_train$quality)
# Print the call, sample/predictor counts and the number of rules.
m.cubist
##
## Call:
## cubist.default(x = wine_train[-12], y = wine_train$quality)
##
## Number of samples: 3750
## Number of predictors: 11
##
## Number of committees: 1
## Number of rules: 16
16個のルールを作成したとわかる
## 決定ルール(線形モデルによる予測)
# Print every rule with its conditions and linear model, followed by the
# training-set error and the attribute usage table.
summary(m.cubist)
##
## Call:
## cubist.default(x = wine_train[-12], y = wine_train$quality)
##
##
## Cubist [Release 2.07 GPL Edition] Fri Mar 26 01:36:13 2021
## ---------------------------------
##
## Target attribute `outcome'
##
## Read 3750 cases (12 attributes) from undefined.data
##
## Model:
##
## Rule 1: [47 cases, mean 5.2, range 3 to 7, est err 0.6]
##
## if
## volatile.acidity > 0.155
## total.sulfur.dioxide > 182
## sulphates > 0.64
## then
## outcome = 152.4 - 155 density + 0.07 residual.sugar + 0.202 alcohol
## + 1.04 pH - 1.27 volatile.acidity + 0.006 free.sulfur.dioxide
## + 0.11 fixed.acidity + 0.76 sulphates
## - 0.0018 total.sulfur.dioxide + 0.15 citric.acid
##
## Rule 2: [452 cases, mean 5.2, range 3 to 7, est err 0.4]
##
## if
## fixed.acidity <= 8.4
## volatile.acidity > 0.305
## pH <= 3.24
## sulphates <= 0.64
## alcohol <= 10.8
## then
## outcome = 13.9 + 0.303 alcohol + 0.0026 total.sulfur.dioxide
## + 0.1 fixed.acidity - 13 density + 0.006 residual.sugar
## + 0.13 sulphates - 0.13 volatile.acidity
##
## Rule 3: [58 cases, mean 5.3, range 3 to 7, est err 0.6]
##
## if
## fixed.acidity > 7.4
## volatile.acidity > 0.205
## residual.sugar <= 17.85
## alcohol <= 9.1
## then
## outcome = 126.3 - 3.83 volatile.acidity - 124 density
## + 0.42 fixed.acidity + 0.064 residual.sugar - 0.75 citric.acid
## + 0.012 alcohol
##
## Rule 4: [215 cases, mean 5.3, range 3 to 7, est err 0.5]
##
## if
## fixed.acidity <= 7.4
## volatile.acidity > 0.205
## density <= 0.99848
## alcohol <= 9.1
## then
## outcome = 300.7 - 305 density + 0.503 alcohol + 0.117 residual.sugar
## + 0.4 fixed.acidity + 1.26 sulphates - 1.35 volatile.acidity
##
## Rule 5: [63 cases, mean 5.3, range 5 to 6, est err 0.3]
##
## if
## volatile.acidity > 0.205
## residual.sugar > 17.85
## then
## outcome = 4.1 + 0.048 residual.sugar
##
## Rule 6: [124 cases, mean 5.3, range 3 to 9, est err 0.7]
##
## if
## fixed.acidity > 8.4
## volatile.acidity > 0.155
## total.sulfur.dioxide <= 208
## alcohol > 9.1
## then
## outcome = 3.6 - 2.14 volatile.acidity + 0.141 alcohol + 1.07 sulphates
## + 0.0044 free.sulfur.dioxide - 0.0017 total.sulfur.dioxide
## + 0.06 fixed.acidity
##
## Rule 7: [162 cases, mean 5.5, range 3 to 7, est err 0.5]
##
## if
## total.sulfur.dioxide > 208
## alcohol > 9.1
## then
## outcome = 37.8 - 0.0061 total.sulfur.dioxide - 2.51 volatile.acidity
## - 31 density - 0.0042 free.sulfur.dioxide
## + 0.013 residual.sugar + 0.17 pH + 0.03 fixed.acidity
## + 0.017 alcohol + 0.14 sulphates
##
## Rule 8: [460 cases, mean 5.7, range 3 to 8, est err 0.5]
##
## if
## fixed.acidity <= 8.4
## volatile.acidity <= 0.305
## free.sulfur.dioxide > 30
## total.sulfur.dioxide <= 208
## pH <= 3.24
## sulphates <= 0.64
## alcohol > 9.1
## alcohol <= 10.8
## then
## outcome = 3.9 - 2.24 volatile.acidity - 0.0046 total.sulfur.dioxide
## + 1.56 sulphates + 0.104 alcohol + 0.0075 free.sulfur.dioxide
## + 0.14 fixed.acidity
##
## Rule 9: [1454 cases, mean 5.8, range 3 to 9, est err 0.6]
##
## if
## free.sulfur.dioxide <= 30
## alcohol > 9.1
## then
## outcome = 118.9 + 0.0305 free.sulfur.dioxide + 0.072 residual.sugar
## - 120 density + 0.243 alcohol - 1.66 volatile.acidity
## + 0.83 pH + 0.43 citric.acid + 0.36 sulphates
##
## Rule 10: [69 cases, mean 5.9, range 5 to 7, est err 0.6]
##
## if
## fixed.acidity <= 7.4
## volatile.acidity > 0.205
## residual.sugar <= 17.85
## density > 0.99848
## alcohol <= 9.1
## then
## outcome = 69.8 + 0.322 alcohol - 2.74 volatile.acidity
## + 0.051 residual.sugar - 68 density + 0.19 fixed.acidity
## - 0.41 citric.acid
##
## Rule 11: [475 cases, mean 6.2, range 4 to 8, est err 0.6]
##
## if
## free.sulfur.dioxide > 30
## total.sulfur.dioxide <= 208
## pH > 3.24
## sulphates <= 0.64
## alcohol > 9.1
## then
## outcome = 459.5 - 463 density + 0.164 residual.sugar
## + 0.48 fixed.acidity - 0.228 alcohol + 1.44 pH
## + 0.98 sulphates - 0.26 volatile.acidity
## - 0.0003 total.sulfur.dioxide
##
## Rule 12: [192 cases, mean 6.3, range 4 to 8, est err 0.6]
##
## if
## volatile.acidity <= 0.155
## alcohol > 9.1
## then
## outcome = 280.9 - 283 density + 0.104 residual.sugar
## + 0.23 fixed.acidity + 1.25 pH - 1.15 volatile.acidity
## + 0.85 sulphates
##
## Rule 13: [118 cases, mean 6.3, range 5 to 8, est err 0.5]
##
## if
## volatile.acidity > 0.155
## free.sulfur.dioxide > 30
## total.sulfur.dioxide <= 182
## sulphates > 0.64
## alcohol > 9.1
## then
## outcome = 22.2 + 2.46 sulphates + 0.0143 free.sulfur.dioxide
## - 10.8 chlorides - 0.0043 total.sulfur.dioxide
## - 1.44 citric.acid + 0.138 alcohol - 19 density
## + 0.007 residual.sugar + 0.11 pH + 0.02 fixed.acidity
## - 0.09 volatile.acidity
##
## Rule 14: [29 cases, mean 6.3, range 5 to 8, est err 0.5]
##
## if
## volatile.acidity <= 0.205
## density > 0.9992
## alcohol <= 9.1
## then
## outcome = 42.4 - 1.897 alcohol - 0.0299 total.sulfur.dioxide
## + 12.29 volatile.acidity + 0.096 residual.sugar
## - 0.0248 free.sulfur.dioxide - 17 density + 0.04 fixed.acidity
## - 0.14 citric.acid + 0.7 chlorides
##
## Rule 15: [388 cases, mean 6.4, range 3 to 8, est err 0.6]
##
## if
## free.sulfur.dioxide > 30
## total.sulfur.dioxide <= 208
## pH <= 3.24
## sulphates <= 0.64
## alcohol > 10.8
## then
## outcome = 121.8 + 0.072 residual.sugar - 120 density + 0.226 alcohol
## + 1.66 sulphates - 7.3 chlorides
##
## Rule 16: [73 cases, mean 6.6, range 5 to 8, est err 0.4]
##
## if
## volatile.acidity <= 0.205
## density <= 0.9992
## alcohol <= 9.1
## then
## outcome = 8.5 + 10.38 volatile.acidity + 26.7 chlorides
## + 0.11 residual.sugar - 3.08 citric.acid + 0.22 fixed.acidity
## - 7 density
##
##
## Evaluation on training data (3750 cases):
##
## Average |error| 0.4
## Relative |error| 0.66
## Correlation coefficient 0.67
##
##
## Attribute usage:
## Conds Model
##
## 97% 93% alcohol
## 66% 55% free.sulfur.dioxide
## 44% 93% sulphates
## 43% 90% volatile.acidity
## 41% 56% pH
## 41% 43% total.sulfur.dioxide
## 31% 56% fixed.acidity
## 9% 85% density
## 4% 87% residual.sugar
## 42% citric.acid
## 14% chlorides
##
##
## Time: 0.3 secs
それぞれの条件での予測式を算出している. 例えば1つめは, volatile.acidity > 0.155, total.sulfur.dioxide > 182, sulphates > 0.64という条件での回帰式である.
## 結果の予測
# Predict quality for the held-out test data with the Cubist model tree.
p.cubist <- predict(m.cubist, wine_test)
# Distribution of the predicted scores.
summary(p.cubist)
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 4.178 5.372 5.830 5.807 6.232 7.607
# Correlation between predicted and true quality.
cor(p.cubist, wine_test$quality)
## [1] 0.6121094
# Mean absolute error on the test data. Arguments are passed by name so the
# true scores go to `actual` and the predictions to `predict`; the value is
# unchanged because the absolute difference is symmetric.
MAE(actual = wine_test$quality, predict = p.cubist)
## [1] 0.5243365
先ほどよりは性能があがったように思える