Rでベイジアンネットワークメモ

Rでベイジアンネットワークやりかたメモです。

データとしては、MASS::birthwt データを使います。
これは、妊娠時の喫煙有無などの状況が低体重児につながるかどうかの調査データみたいです。

library(MASS)
library(data.table)
library(dplyr)
data <- birthwt %>% data.table()
data
##      low age lwt race smoke ptl ht ui ftv  bwt
##   1:   0  19 182    2     0   0  0  1   0 2523
##   2:   0  33 155    3     0   0  0  0   3 2551
##   3:   0  20 105    1     1   0  0  0   1 2557
##   4:   0  21 108    1     1   0  0  1   2 2594
##   5:   0  18 107    1     1   0  0  1   0 2600
##  ---                                          
## 185:   1  28  95    1     1   0  0  0   2 2466
## 186:   1  14 100    3     0   0  0  0   2 2495
## 187:   1  23  94    3     1   0  0  0   0 2495
## 188:   1  17 142    2     0   0  1  0   0 2495
## 189:   1  21 130    1     1   0  1  0   3 2495

このデータから、体重が低いかどうか(low)、人種(race)、妊娠時の喫煙の有無(smoke)、高血圧の病歴の有無(ht)、子宮の過敏性の有無(ui)の5つの変数を取り出し、ベイジアンネットワークを作成します。
ちなみに人種は、1=白人、2=黒人、3=その他です。

data <- data %>% select(low, race, smoke, ht, ui)
data
##      low race smoke ht ui
##   1:   0    2     0  0  1
##   2:   0    3     0  0  0
##   3:   0    1     1  0  0
##   4:   0    1     1  0  1
##   5:   0    1     1  0  1
##  ---                     
## 185:   1    1     1  0  0
## 186:   1    3     0  0  0
## 187:   1    3     1  0  0
## 188:   1    2     0  1  0
## 189:   1    1     1  1  0

さて、ベイジアンネットワークを作成するためには deal パッケージを使って次のように行います。

library(deal)
data[] <- lapply(data, as.factor)
pre.network <- network(data)
prior.dist <- jointprior(pre.network)
## Imaginary sample size: 96
update <- learn(pre.network, data, prior.dist)
post.network <- autosearch(getnetwork(update), data, prior.dist, trace=FALSE)
## [Autosearch (1) -619.5 [low][race|smoke][smoke][ht][ui]
## (2) -616.5 [low][race|smoke][smoke][ht|ui][ui]
## (3) -613.4 [low][race|smoke][smoke][ht|low:ui][ui]
## (4) -610.5 [low][race|smoke][smoke][ht|low:ui][ui|low]
## (5) -607.9 [low][race|low:smoke][smoke][ht|low:ui][ui|low]
## (6) -606.3 [low|smoke][race|low:smoke][smoke][ht|low:ui][ui|low]
## (7) -605.3 [low|smoke][race|low:smoke:ht][smoke][ht|low:ui][ui|low]
## ...(8) -605.3 [low|smoke][race|low:smoke:ht][smoke][ht|low:ui][ui|low:smoke]
## ....Total 0.39 add 0.18 rem 0.06 turn 0.03 sort 0.03 choose 0.02 rest 0.07 ]
plot(getnetwork(post.network))

plot of chunk unnamed-chunk-3

グラフができましたが、分かりにくいので変数名を日本語にしてやり直します。

setnames(data, "low", "低体重")
setnames(data, "race", "人種")
setnames(data, "smoke", "喫煙")
setnames(data, "ht", "高血圧")
setnames(data, "ui", "過敏性")
pre.network <- network(data)
prior.dist <- jointprior(pre.network)
## Imaginary sample size: 96
update <- learn(pre.network, data, prior.dist)
post.network <- autosearch(getnetwork(update), data, prior.dist, trace=FALSE)
## [Autosearch (1) -619.5 [低体重][人種|喫煙][喫煙][高血圧][過敏性]
## (2) -616.5 [低体重][人種|喫煙][喫煙][高血圧|過敏性][過敏性]
## (3) -613.4 [低体重][人種|喫煙][喫煙][高血圧|低体重:過敏性][過敏性]
## (4) -610.5 [低体重][人種|喫煙][喫煙][高血圧|低体重:過敏性][過敏性|低体重]
## (5) -607.9 [低体重][人種|低体重:喫煙][喫煙][高血圧|低体重:過敏性][過敏性|低体重]
## (6) -606.3 [低体重|喫煙][人種|低体重:喫煙][喫煙][高血圧|低体重:過敏性][過敏性|低体重]
## (7) -605.3 [低体重|喫煙][人種|低体重:喫煙:高血圧][喫煙][高血圧|低体重:過敏性][過敏性|低体重]
## ...(8) -605.3 [低体重|喫煙][人種|低体重:喫煙:高血圧][喫煙][高血圧|低体重:過敏性][過敏性|低体重:喫煙]
## ....Total 0.39 add 0.15 rem 0.04 turn 0.07 sort 0 choose 0 rest 0.13 ]
plot(getnetwork(post.network))

plot of chunk unnamed-chunk-4

これ見ると、低体重であることが高血圧の原因になっていますが、低体重は結果変数なので、原因にはなりえません。また、喫煙の有無が人種の原因になっていますが、これもおかしいですね。

というわけで、ベイジアンネットワークを作成するときに、ここからここへは辺をつなげないという制約を加えることができます。

ここでは、「低体重はどの変数の原因にもならない」と「人種はどの変数の結果にもならない」という制約を置いてみましょう。

低体重のノード番号は1、人種のノード番号は2なので、

pre.network <- getnetwork(update)
banlist(pre.network) <- rbind(cbind(1, 2:5), cbind(1:5, 2))
post.network <- autosearch(pre.network, data, prior.dist, trace=FALSE)
## [Autosearch (1) -619.5 [低体重][人種][喫煙|人種][高血圧][過敏性]
## (2) -616.5 [低体重][人種][喫煙|人種][高血圧|過敏性][過敏性]
## (3) -613.5 [低体重|過敏性][人種][喫煙|人種][高血圧|過敏性][過敏性]
## (4) -610.5 [低体重|高血圧:過敏性][人種][喫煙|人種][高血圧|過敏性][過敏性]
## (5) -608.9 [低体重|高血圧:過敏性][人種][喫煙|人種][高血圧|人種:過敏性][過敏性]
## (6) -607.8 [低体重|高血圧:過敏性][人種][喫煙|人種:過敏性][高血圧|人種:過敏性][過敏性]
## (7) -607 [低体重|喫煙:高血圧:過敏性][人種][喫煙|人種:過敏性][高血圧|人種:過敏性][過敏性]
## (8) -606.1 [低体重|喫煙:高血圧:過敏性][人種][喫煙|人種:過敏性][高血圧|人種][過敏性|高血圧]
## .Total 0.24 add 0.1 rem 0.04 turn 0.04 sort 0 choose 0 rest 0.06 ]
plot(getnetwork(post.network), showban=FALSE)

plot of chunk unnamed-chunk-5

localprob(getnetwork(post.network))
## $低体重
## , , 0, 0
## 
##        0      1
## 0 0.7714 0.6176
## 1 0.2286 0.3824
## 
## , , 1, 0
## 
##        0      1
## 0 0.4737 0.4706
## 1 0.5263 0.5294
## 
## , , 0, 1
## 
##        0      1
## 0 0.5185 0.4800
## 1 0.4815 0.5200
## 
## , , 1, 1
## 
##        0      1
## 0 0.5000 0.5000
## 1 0.5000 0.5000
## 
## 
## $人種
## 
##      1      2      3 
## 0.4491 0.2035 0.3474 
## 
## $喫煙
## , , 0
## 
##        0      1
## 1 0.3871 0.6000
## 2 0.1694 0.2118
## 3 0.4435 0.1882
## 
## , , 1
## 
##        0      1
## 1 0.3077 0.4595
## 2 0.2821 0.2162
## 3 0.4103 0.3243
## 
## 
## $高血圧
##        0      1
## 1 0.4756 0.3500
## 2 0.1733 0.3167
## 3 0.3511 0.3333
## 
## $過敏性
##        0      1
## 0 0.8278 0.6842
## 1 0.1722 0.3158