在回归、Logistic、LASSO、随机效应模型都能用的情况下,我们为什么还要学 决策树?主要有三点:
本笔记是我 自己造了两个数据集:
- 一个做回归树(模拟房价);
- 一个做分类树(模拟银行营销,正类比例大概 25%);
并且在每一段代码后面,都加了 “输出结果怎么读”
的解释。
library(rpart) # 决策树核心包
library(rpart.plot) # 更好看的树
library(dplyr) # 数据处理
library(ggplot2) # 画图
library(showtext)
font_add("Heiti TC Light", regular = "STHeiti Light.ttc")
showtext_auto()
我们虚构一个城市里有 500 个小区,每个小区都有以下特征:
rooms:平均房间数,越多越贵;dist_center:距离市中心的距离(公里),越远越便宜;lstat:低社会经济地位人群比例,越高说明小区更偏,越便宜;green:周边绿化覆盖率;crime:犯罪率;medv:我们想要预测的房价中位数(目标变量)。set.seed(1)
n_house <- 500
house_df <- data.frame(
rooms = round(runif(n_house, 3, 8), 2),
dist_center = round(runif(n_house, 0.5, 20), 2),
lstat = round(runif(n_house, 2, 30), 2),
green = round(runif(n_house, 0.1, 0.8), 2),
crime = round(rexp(n_house, rate = 1/5), 2) # 指数分布,凸显“部分区域犯罪率高”
)
# 人为设定生成房价的机制
house_df <- house_df %>%
mutate(
medv = 50 +
6 * rooms + # 房间数多 → 房价升
(-1.3) * dist_center + # 离市中心远 → 房价降
(-0.6) * lstat + # 低社会经济地位比例高 → 房价降
8 * green + # 有绿化 → 房价升
(-1.0) * crime + # 犯罪率高 → 房价降
rnorm(n_house, 0, 3) # 一点噪声,模拟现实里不可观测因素
)
str(house_df)
## 'data.frame': 500 obs. of 6 variables:
## $ rooms : num 4.33 4.86 5.86 7.54 4.01 7.49 7.72 6.3 6.15 3.31 ...
## $ dist_center: num 11.31 13.92 13.33 13.44 9.71 ...
## $ lstat : num 16.86 21.18 12.73 28.74 5.31 ...
## $ green : num 0.15 0.71 0.24 0.47 0.17 0.37 0.62 0.59 0.17 0.53 ...
## $ crime : num 3 3.75 0.83 11.03 0.28 ...
## $ medv : num 49.1 49.8 59.5 54.7 60.7 ...
summary(house_df$medv)
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 19.33 49.76 58.37 58.95 68.79 97.18
输出
str() 会告诉你一共有 500 行、6
个变量,都是数值型;summary(medv) 一般会看到房价大部分集中在 40~80
左右,说明我们模拟得还算合理;set.seed(2025)
id_train <- sample(seq_len(n_house), size = floor(0.7 * n_house)) # 70% 训练
house_train <- house_df[id_train, ]
house_test <- house_df[-id_train, ]
我们保留 30% 不参与建模,只用来 考核模型在没见过的数据上的表现。这是判断树是不是“过拟合”的关键。
fit_house <- rpart(
medv ~ .,
data = house_train,
method = "anova" # 回归树也可以不写,它能识别出 medv 是连续的
)
fit_house
## n= 350
##
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 350 66635.2800 59.32867
## 2) rooms< 5.965 209 28288.5300 53.42257
## 4) dist_center>=11.17 84 7579.0200 45.68170
## 8) rooms< 4.365 41 2827.5120 39.75572
## 16) crime>=2.87 25 782.3538 34.75291 *
## 17) crime< 2.87 16 441.7969 47.57260 *
## 9) rooms>=4.365 43 1938.8620 51.33205 *
## 5) dist_center< 11.17 125 12293.7100 58.62444
## 10) crime>=12.24 13 748.0036 43.81439 *
## 11) crime< 12.24 112 8363.3540 60.34346
## 22) rooms< 4.155 44 2390.8070 55.18801 *
## 23) rooms>=4.155 68 4046.3770 63.67934
## 46) green< 0.61 52 2604.8940 61.34648
## 92) dist_center>=5.575 27 599.9490 56.75208 *
## 93) dist_center< 5.575 25 819.4909 66.30843 *
## 47) green>=0.61 16 238.7483 71.26114 *
## 3) rooms>=5.965 141 20250.1500 68.08311
## 6) dist_center>=11.265 67 5649.2670 60.27653
## 12) lstat>=16.56 27 1617.0430 53.88235 *
## 13) lstat< 16.56 40 2183.1730 64.59260
## 26) crime>=6.74 14 289.4671 57.79021 *
## 27) crime< 6.74 26 897.0673 68.25543 *
## 7) dist_center< 11.265 74 6820.8070 75.15122
## 14) lstat>=19.4 21 1482.3060 66.39956
## 28) crime>=2.78 12 636.7868 61.51844 *
## 29) crime< 2.78 9 178.4093 72.90773 *
## 15) lstat< 19.4 53 3092.7790 78.61886
## 30) dist_center>=6.22 27 770.6203 74.29564 *
## 31) dist_center< 6.22 26 1293.4730 83.10837 *
输出
root ...,表示根节点,也就是“所有训练样本”;n:这个节点里有多少个样本;deviance:回归问题里就是残差平方和;yval:如果停在这个节点上,模型给出的预测值,就是这个子样本里的
平均房价。* 的行,是
终节点,说明不用再往下分了。op <- par(no.readonly = TRUE)
par(mar = c(1, 1, 1, 1))
plot(fit_house, margin = 0.05)
text(fit_house, use.n = TRUE, cex = 0.7)
par(op)
读图
rooms 或者
dist_center 来分,说明这两个变量在这个模拟数据里
确实最有解释力。更好看的版本:
rpart.plot::prp(
fit_house,
type = 2, # 分裂条件放在内部
extra = 101, # 节点里显示预测值 + 节点样本数
fallen.leaves = TRUE
)
plotcp(fit_house)
fit_house$cptable
## CP nsplit rel error xerror xstd
## 1 0.27157694 0 1.0000000 1.0030137 0.07076405
## 2 0.12629647 1 0.7284231 0.7736633 0.05196276
## 3 0.11675604 2 0.6021266 0.7228256 0.05031177
## 4 0.04775778 3 0.4853706 0.5786681 0.04031116
## 5 0.04220956 4 0.4376128 0.5475991 0.03819379
## 6 0.03370170 5 0.3954032 0.5329764 0.03733510
## 7 0.02890616 6 0.3617015 0.4805375 0.03488561
## 8 0.02774882 7 0.3327954 0.4614401 0.03292669
## 9 0.02406175 8 0.3050465 0.4379398 0.03108787
## 10 0.01804953 9 0.2809848 0.4216290 0.03086947
## 11 0.01779018 10 0.2629353 0.4042117 0.03249943
## 12 0.01543755 11 0.2451451 0.3926860 0.03154869
## 13 0.01495662 12 0.2297075 0.3792125 0.03066014
## 14 0.01001136 13 0.2147509 0.3627746 0.02968993
## 15 0.01000000 14 0.2047396 0.3481261 0.02894511
plotcp()?
cptable?
nsplit:分裂次数,= 终节点数 - 1;rel error:在训练集上的相对误差;xerror:交叉验证误差(重点看它);xstd:交叉验证误差的标准误。min_cp <- fit_house$cptable[which.min(fit_house$cptable[, "xerror"]), "CP"]
min_cp
## [1] 0.01
fit_house_best <- prune(fit_house, cp = min_cp)
rpart.plot::prp(
fit_house_best,
type = 2,
extra = 101,
fallen.leaves = TRUE,
main = "修枝后的最优回归树"
)
这一步:我们不是要让树“长到最大”,而是要让树在“解释力”和“泛化能力”之间取得平衡。修枝之后的树,终节点通常会少一些,更稳。
pred_house <- predict(fit_house_best, newdata = house_test)
y_true <- house_test$medv
mse_tree <- mean((pred_house - y_true)^2)
mse_tree
## [1] 84.9063
MSE
画一下实际值 vs 预测值:
plot(pred_house, y_true,
xlab = "预测房价",
ylab = "实际房价",
main = "回归树预测效果(测试集)")
abline(0, 1, lwd = 2)
你会看到:点是“一坨一坨”的,而不是一条平滑的线,这是 树的典型特征 —— 每个终节点只能给一个数字。
lm_house <- lm(medv ~ ., data = house_train)
pred_lm <- predict(lm_house, newdata = house_test)
mse_lm <- mean((pred_lm - y_true)^2)
mse_tree
## [1] 84.9063
mse_lm
## [1] 12.00946
你会看到:mse_lm 比
mse_tree
小,这说明在我们这个模拟设定里,线性回归更贴合真实生成机制。
在本模拟数据中,回归树的可解释性更强,但预测精度不及OLS;如果希望提升树模型的预测性能,可进一步采用基于树的集成学习方法(如随机森林、梯度提升树等)。
set.seed(2)
n_bank <- 1500
job_levels <- c("admin", "blue", "services", "self", "retired")
bank_df <- data.frame(
age = round(runif(n_bank, 20, 65)),
job = factor(sample(job_levels, n_bank, replace = TRUE)),
marital = factor(sample(c("single", "married"), n_bank, replace = TRUE, prob = c(0.35, 0.65))),
last_contact_days = round(rexp(n_bank, rate = 1/15)),
emp_rate = round(rnorm(n_bank, mean = 1.2, sd = 0.6), 2)
)
## 拉强信号:最近联系、已婚、中年、经济好 → 更可能买
linpred <- -2.6 +
0.04 * (40 - abs(bank_df$age - 40)) +
ifelse(bank_df$marital == "married", 0.6, 0) +
ifelse(bank_df$last_contact_days < 10, 1.2, 0) +
0.4 * bank_df$emp_rate
prob <- 1 / (1 + exp(-linpred))
y <- rbinom(n_bank, size = 1, prob = prob)
bank_df$y <- factor(ifelse(y == 1, "yes", "no"))
prop.table(table(bank_df$y))
##
## no yes
## 0.5313333 0.4686667
输出
prop.table(table(bank_df$y))
会告诉你正负样本比例,比如:
no 大概 0.7yes 大概 0.3这就比原来 0.1 的例子好多了,树也更容易学到“yes”长什么样。
set.seed(2026)
id_train_b <- sample(seq_len(n_bank), size = 1000)
bank_train <- bank_df[id_train_b, ]
bank_test <- bank_df[-id_train_b, ]
fit_bank <- rpart(
y ~ .,
data = bank_train,
method = "class",
control = rpart.control(
cp = 0.001, # 比默认值 0.01 小得多
minsplit = 20,
minbucket = 7,
xval = 10
)
)
fit_bank
## n= 1000
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 1000 456 no (0.54400000 0.45600000)
## 2) last_contact_days>=9.5 535 165 no (0.69158879 0.30841121)
## 4) emp_rate< 0.845 140 24 no (0.82857143 0.17142857) *
## 5) emp_rate>=0.845 395 141 no (0.64303797 0.35696203)
## 10) marital=single 155 40 no (0.74193548 0.25806452)
## 20) job=admin,retired 72 12 no (0.83333333 0.16666667) *
## 21) job=blue,self,services 83 28 no (0.66265060 0.33734940)
## 42) emp_rate< 1.055 7 0 no (1.00000000 0.00000000) *
## 43) emp_rate>=1.055 76 28 no (0.63157895 0.36842105)
## 86) age< 37 26 6 no (0.76923077 0.23076923) *
## 87) age>=37 50 22 no (0.56000000 0.44000000)
## 174) age>=48 29 8 no (0.72413793 0.27586207) *
## 175) age< 48 21 7 yes (0.33333333 0.66666667) *
## 11) marital=married 240 101 no (0.57916667 0.42083333)
## 22) age< 24.5 23 2 no (0.91304348 0.08695652) *
## 23) age>=24.5 217 99 no (0.54377880 0.45622120)
## 46) age>=52.5 75 23 no (0.69333333 0.30666667)
## 92) emp_rate< 1.37 38 6 no (0.84210526 0.15789474) *
## 93) emp_rate>=1.37 37 17 no (0.54054054 0.45945946)
## 186) age< 56.5 10 1 no (0.90000000 0.10000000) *
## 187) age>=56.5 27 11 yes (0.40740741 0.59259259)
## 374) job=admin,blue,services 16 7 no (0.56250000 0.43750000) *
## 375) job=retired,self 11 2 yes (0.18181818 0.81818182) *
## 47) age< 52.5 142 66 yes (0.46478873 0.53521127)
## 94) emp_rate>=1.015 123 60 no (0.51219512 0.48780488)
## 188) emp_rate< 1.155 17 3 no (0.82352941 0.17647059) *
## 189) emp_rate>=1.155 106 49 yes (0.46226415 0.53773585)
## 378) age< 30.5 27 9 no (0.66666667 0.33333333) *
## 379) age>=30.5 79 31 yes (0.39240506 0.60759494)
## 758) last_contact_days< 25 48 23 yes (0.47916667 0.52083333)
## 1516) last_contact_days>=10.5 41 19 no (0.53658537 0.46341463)
## 3032) age>=43.5 18 4 no (0.77777778 0.22222222) *
## 3033) age< 43.5 23 8 yes (0.34782609 0.65217391)
## 6066) emp_rate< 1.485 8 2 no (0.75000000 0.25000000) *
## 6067) emp_rate>=1.485 15 2 yes (0.13333333 0.86666667) *
## 1517) last_contact_days< 10.5 7 1 yes (0.14285714 0.85714286) *
## 759) last_contact_days>=25 31 8 yes (0.25806452 0.74193548) *
## 95) emp_rate< 1.015 19 3 yes (0.15789474 0.84210526) *
## 3) last_contact_days< 9.5 465 174 yes (0.37419355 0.62580645)
## 6) age>=59.5 68 29 no (0.57352941 0.42647059)
## 12) marital=single 28 7 no (0.75000000 0.25000000) *
## 13) marital=married 40 18 yes (0.45000000 0.55000000)
## 26) job=admin,services 19 8 no (0.57894737 0.42105263) *
## 27) job=blue,retired,self 21 7 yes (0.33333333 0.66666667)
## 54) emp_rate>=1.655 7 3 no (0.57142857 0.42857143) *
## 55) emp_rate< 1.655 14 3 yes (0.21428571 0.78571429) *
## 7) age< 59.5 397 135 yes (0.34005038 0.65994962)
## 14) age< 28.5 104 48 yes (0.46153846 0.53846154)
## 28) job=retired 21 4 no (0.80952381 0.19047619) *
## 29) job=admin,blue,self,services 83 31 yes (0.37349398 0.62650602)
## 58) emp_rate< 0.485 11 3 no (0.72727273 0.27272727) *
## 59) emp_rate>=0.485 72 23 yes (0.31944444 0.68055556) *
## 15) age>=28.5 293 87 yes (0.29692833 0.70307167)
## 30) age>=49.5 96 38 yes (0.39583333 0.60416667)
## 60) last_contact_days< 2.5 33 14 no (0.57575758 0.42424242)
## 120) emp_rate< 1.33 17 4 no (0.76470588 0.23529412) *
## 121) emp_rate>=1.33 16 6 yes (0.37500000 0.62500000) *
## 61) last_contact_days>=2.5 63 19 yes (0.30158730 0.69841270)
## 122) emp_rate< 0.57 10 4 no (0.60000000 0.40000000) *
## 123) emp_rate>=0.57 53 13 yes (0.24528302 0.75471698) *
## 31) age< 49.5 197 49 yes (0.24873096 0.75126904)
## 62) age< 40.5 115 35 yes (0.30434783 0.69565217)
## 124) job=admin,self,services 72 27 yes (0.37500000 0.62500000)
## 248) emp_rate>=0.465 63 26 yes (0.41269841 0.58730159)
## 496) last_contact_days>=7.5 14 6 no (0.57142857 0.42857143) *
## 497) last_contact_days< 7.5 49 18 yes (0.36734694 0.63265306)
## 994) emp_rate< 0.885 13 6 no (0.53846154 0.46153846) *
## 995) emp_rate>=0.885 36 11 yes (0.30555556 0.69444444) *
## 249) emp_rate< 0.465 9 1 yes (0.11111111 0.88888889) *
## 125) job=blue,retired 43 8 yes (0.18604651 0.81395349) *
## 63) age>=40.5 82 14 yes (0.17073171 0.82926829)
## 126) emp_rate>=1.3 42 11 yes (0.26190476 0.73809524)
## 252) emp_rate< 1.385 7 3 no (0.57142857 0.42857143) *
## 253) emp_rate>=1.385 35 7 yes (0.20000000 0.80000000)
## 506) age>=46.5 7 3 no (0.57142857 0.42857143) *
## 507) age< 46.5 28 3 yes (0.10714286 0.89285714) *
## 127) emp_rate< 1.3 40 3 yes (0.07500000 0.92500000) *
输出
- 第一行还是根节点:会告诉你有 1000 个样本,多少是
no,多少是 yes;
- 如果你看到下面又出现了
last_contact_days < 某个值、marital = married
之类的分裂,说明我们前面造的“信号”被树找到了。
-
如果这里仍然只剩一个根节点,那就说明你这一次随机出来的数据恰好比较“弱”,可以把
cp 再降一点,比如 0.0005。
plotcp(fit_bank)
fit_bank$cptable
## CP nsplit rel error xerror xstd
## 1 0.256578947 0 1.0000000 1.0000000 0.03453958
## 2 0.021929825 1 0.7434211 0.7434211 0.03282734
## 3 0.014254386 2 0.7214912 0.7587719 0.03298846
## 4 0.010964912 4 0.6929825 0.7390351 0.03277984
## 5 0.008771930 5 0.6820175 0.7543860 0.03294324
## 6 0.006578947 6 0.6732456 0.7675439 0.03307698
## 7 0.005482456 7 0.6666667 0.7785088 0.03318403
## 8 0.004385965 23 0.5394737 0.7697368 0.03309871
## 9 0.003837719 25 0.5307018 0.7543860 0.03294324
## 10 0.002192982 29 0.5153509 0.7412281 0.03280368
## 11 0.001253133 30 0.5131579 0.7785088 0.03318403
## 12 0.001000000 38 0.5021930 0.7960526 0.03334711
和回归树一样,看哪一个 cp 的交叉验证误差最低:
min_cp_b <- fit_bank$cptable[which.min(fit_bank$cptable[, "xerror"]), "CP"]
min_cp_b
## [1] 0.01096491
fit_bank_best <- prune(fit_bank, cp = min_cp_b)
rpart.plot::prp(
fit_bank_best,
type = 2,
extra = 104,
fallen.leaves = TRUE,
main = "修枝后的最优分类树"
)
图
# 默认阈值 0.5
pred_cls <- predict(fit_bank_best, newdata = bank_test, type = "class")
tab <- table(pred_cls, bank_test$y)
tab
##
## pred_cls no yes
## no 183 123
## yes 70 124
accuracy <- sum(diag(tab)) / sum(tab)
accuracy
## [1] 0.614
# 防呆:有时候树会把所有人都判成一个类,下面这样写就不会报错
TP <- if ("yes" %in% rownames(tab)) tab["yes", "yes"] else 0
FN <- if ("no" %in% rownames(tab)) tab["no", "yes"] else 0
sensitivity <- TP / (TP + FN)
sensitivity
## [1] 0.5020243
解释结果
accuracy 是整体准确率;sensitivity 是对 “yes” 这一类的识别能力;# 取出预测概率
prob_test <- predict(fit_bank_best, newdata = bank_test, type = "prob")[, "yes"]
# 把阈值从0.5降到0.25
pred_025 <- ifelse(prob_test > 0.25, "yes", "no")
tab2 <- table(pred_025, bank_test$y)
tab2
##
## pred_025 no yes
## no 2 7
## yes 251 240
acc2 <- sum(diag(tab2)) / sum(tab2)
TP2 <- if ("yes" %in% rownames(tab2)) tab2["yes", "yes"] else 0
FN2 <- if ("no" %in% rownames(tab2)) tab2["no", "yes"] else 0
sens2 <- TP2 / (TP2 + FN2)
acc2
## [1] 0.484
sens2
## [1] 0.9716599
报告:
将分类阈值从0.5下调至0.25后,模型的总体准确率略有下降,但对“有购买意向”的识别率(灵敏度)明显提升,说明在该业务场景下通过牺牲部分准确率可换取更高的召回率,这与营销业务“宁可多打几个错的电话,也不能漏掉潜在客户”的目标是一致的。