policytree introduction
Binary treatment effect estimation and policy learning
library(policytree)
library(grf)

n <- 10000  # 10,000 rows
p <- 10
X <- matrix(rnorm(n * p), n, p)
ee <- 1 / (1 + exp(X[, 3]))
head(ee)
[1] 0.6392099 0.7619908 0.7124756 0.6863965 0.8472532 0.1378475

tt <- 1 / (1 + exp((X[, 1] + X[, 2]) / 2)) - 0.5  # treatment effect
head(tt)
[1] -0.20507306 -0.17951938 -0.01481747 -0.14825344  0.06430986  0.21154274

W <- rbinom(n, 1, ee)  # treatment assignment with propensity ee
table(W)
W
   0    1
4908 5092
Y <- X[, 3] + W * tt + rnorm(n)
# Simulate X, Y, W and fit a causal forest
cf <- causal_forest(X, Y, W)
plot(tt, predict(cf)$predictions)

head(X)
             [,1]        [,2]       [,3]        [,4]        [,5]       [,6]
[1,]  0.784426777  0.95872075 -0.5719365 -1.29523073  1.49015980 -0.3758897
[2,]  0.779628799  0.72349907 -1.1636251  1.43670437 -2.36100510  0.7750034
[3,] -0.017241292  0.13581579 -0.9074379  0.09009495 -1.02175957 -0.1619632
[4,] -0.004632896  1.22737455 -0.7833259  0.98047781  1.31845154  1.2058850
[5,] -0.451543128 -0.06580124 -1.7132178 -1.13824828  1.26972966 -0.1876438
[6,] -1.078862699 -0.72691445  1.8332843  1.22486612  0.02184471 -1.4344583
            [,7]       [,8]        [,9]      [,10]
[1,] -0.02572405  2.6317785 -0.04566386 -0.3197411
[2,] -2.26148341  0.4765856  2.29394108 -0.5321064
[3,] -0.39681674  0.4286189  1.21300794  0.3447469
[4,]  1.23718453 -1.1726293  1.82930248  0.6559452
[5,]  2.14352705 -1.5545898  0.37919106 -1.1356779
[6,] -1.85855945 -0.8198074  0.89547417 -1.9386003
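The depth-2 tree printed below is fit on doubly robust scores from the causal forest. The original call is not shown here, so the following is a minimal sketch of that step using the policytree workflow (the names dr and tree are illustrative):

dr <- double_robust_scores(cf)         # doubly robust scores, one column per action
tree <- policy_tree(X, dr, depth = 2)  # fit a depth-2 policy tree on the scores
tree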
policy_tree object
Tree depth: 2
Actions: 1: control 2: treated
Variable splits:
(1) split_variable: X2 split_value: 0.534778
(2) split_variable: X1 split_value: 0.627123
(4) * action: 2
(5) * action: 1
(3) split_variable: X1 split_value: -1.31902
(6) * action: 2
(7) * action: 1
Multi-action treatment effect estimation (Zhou, Athey and Wager, 2018)
The following example uses the 3-action DGP from section 6.4.1 of Zhou, Athey and Wager (2018).
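The data frame printed below comes from policytree's bundled generator for this DGP. The original simulation call is not shown, so here is a minimal sketch, reusing n and p from above (the column subsetting is an assumption to match the printed output):

data <- gen_data_mapl(n, p)  # 3-action DGP from Zhou, Athey and Wager (2018)
head(data.frame(data))[, 1:6]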
action Y X.1 X.2 X.3 X.4
1 0 -1.2247222 0.75254947 0.8783117 0.3180328 0.07811779
2 2 3.7847432 0.21775649 0.9671132 0.7302644 0.26106458
3 1 2.8208815 0.75926056 0.7084679 0.8069856 0.87489642
4 0 0.2818341 0.01368435 0.3487733 0.2184962 0.12315543
5 0 -0.1698533 0.75518226 0.2770627 0.5429598 0.13204391
6 1 2.2243984 0.33342050 0.5498841 0.5284039 0.71499512
X <- data$X
Y <- data$Y
W <- data$action
multi.forest <- multi_causal_forest(X, Y, W)
# tau.hats:
head(predict(multi.forest)$predictions)
           0          1           2
1 -2.1456331 0.28933957 2.02875155
2 -2.5025713 0.27356647 2.45900008
3 1.2483877 0.07367266 -1.15622199
4 -0.6198447 0.29013045 -0.04416951
5 -2.2018440 0.29136768 2.29718431
6 1.1851891 0.32668998 -1.13604132
# Each region with optimal action
region.pp <- data$region + 1
plot(X[, 5], X[, 7], col = region.pp)
leg <- sort(unique(region.pp))
legend("topleft", legend = leg - 1, col = leg, pch = 10)Policy learning
Cross-fitted Augmented Inverse Propensity Weighted Learning (CAIPWL) with the optimal depth-2 tree:
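double_robust_scores() on the multi-action forest returns one column of cross-fitted AIPW scores per action, roughly mu.hat_a(X) + 1{W = a} * (Y - mu.hat_a(X)) / e.hat_a(X). The call that produced the matrix printed below is not shown; a minimal sketch (the name Gamma.matrix matches its use in the policy_tree() call further down):

Gamma.matrix <- double_robust_scores(multi.forest)
head(Gamma.matrix)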
0 1 2
1 -1.712859 0.6430332 1.8534440
2 -1.507405 0.1254383 8.2745133
3 2.723504 3.6378895 0.8491398
4 -3.802252 1.9665333 1.8141032
5 1.477444 0.3473808 1.6365853
6 2.646381 2.5034549 0.8105063
train <- sample(1:n, 9000)
opt.tree <- policy_tree(X[train, ], Gamma.matrix[train, ], depth = 2)
opt.tree
policy_tree object
Tree depth: 2
Actions: 1: 0 2: 1 3: 2
Variable splits:
(1) split_variable: X5 split_value: 0.602384
(2) split_variable: X7 split_value: 0.352353
(4) * action: 3
(5) * action: 1
(3) split_variable: X7 split_value: 0.744548
(6) * action: 2
(7) * action: 3
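The action vector printed below is the tree's prediction on the held-out observations. The original code is not shown; a minimal sketch, assuming the 9000/1000 split from above (the names X.test and pp match their use in the plotting code below):

X.test <- X[-train, ]
pp <- predict(opt.tree, X.test)
head(pp)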
[1] 3 3 1 3 1 3
plot(X.test[, 5], X.test[, 7], col = pp)
leg <- sort(unique(pp))
legend("topleft", legend = leg - 1, col = leg, pch = 10)Efficient Policy Learning - Binary Treatment and Instrumental Variables (Wager and Athey, 2017)
The following example is from section 5.2 of Athey and Wager (2017).
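The data frame printed below comes from policytree's bundled generator for this DGP. The original simulation call is not shown; a minimal sketch (reusing n from above, the type = "continuous" argument, and the column subsetting are all assumptions):

data <- gen_data_epl(n, type = "continuous")  # DGP from section 5.2
head(data.frame(data))[, 1:6]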
W Z tau Y X.1 X.2
1 0 0 -0.2666420 0.5927068 -0.8102200 0.46671594
2 0 0 0.5412761 0.9463330 2.0825523 -0.07374747
3 0 1 -0.3644244 0.4684493 0.1018880 0.16926315
4 0 0 -0.5000000 -0.3173146 -1.0826638 -2.02495981
5 0 0 -0.2409265 2.3584731 -0.6996321 0.51814692
6 0 1 -0.4681670 0.5878648 -0.6293105 0.06366600
iv.forest <- grf::instrumental_forest(X = data$X, Y = data$Y, W = data$W, Z = data$Z)
gamma <- double_robust_scores(iv.forest)
head(gamma)
        control     treated
[1,] 0.3631107 -0.3631107
[2,] 1.7595999 -1.7595999
[3,] -1.8003242 1.8003242
[4,] -1.8187049 1.8187049
[5,] 4.8398261 -4.8398261
[6,] -3.9837949 3.9837949
Find the depth-2 tree which solves (2):
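The tree printed below is fit on the score matrix gamma computed above. The original code is not shown; a minimal sketch, assuming a 90/10 train/test split (the split size and the name opt.tree are assumptions):

train <- sample(nrow(data$X), floor(0.9 * nrow(data$X)))
opt.tree <- policy_tree(data$X[train, ], gamma[train, ], depth = 2)
opt.tree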
policy_tree object
Tree depth: 2
Actions: 1: control 2: treated
Variable splits:
(1) split_variable: X1 split_value: -0.22769
(2) split_variable: X6 split_value: 0.580364
(4) * action: 2
(5) * action: 1
(3) split_variable: X1 split_value: 0.149089
(6) * action: 1
(7) * action: 2
Evaluate the policy on held out data:
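The original evaluation code is not shown; a minimal sketch of how the held-out predictions below could be obtained (the name piX and the recoding of actions {1, 2} to {0, 1} are assumptions):

piX <- predict(opt.tree, data$X[-train, ]) - 1  # 0 = control, 1 = treated
head(piX)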
[1] 1 0 1 1 1 0
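and, as one assumed way to score the learned policy against the simulated true effect tau (the advantage-style rule (2 * piX - 1) * tau follows the policy-value comparison in Athey and Wager):

reward.policy <- (2 * piX - 1) * data$tau[-train]
mean(reward.policy)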
[1] 0.06127335