1 前言:为什么还要看“树”?

在回归、Logistic、LASSO、随机效应模型都能用的情况下,我们为什么还要学 决策树?主要有三点:

  1. 可解释性极好:树给出的就是一套“如果…那么…”的规则,老师、领导、客户都能看懂;
  2. 能自动做变量选择:真正参与分裂的变量往往就那么几个;
  3. 是很多集成算法的底座:随机森林、GBM、XGBoost 都是树的“加强版”。

本笔记是我 自己造了两个数据集
- 一个做回归树(模拟房价);
- 一个做分类树(模拟银行营销,正类比例大概 25%);
并且在每一段代码后面,都加了 “输出结果怎么读” 的解释。

2 准备工作

library(rpart)       # 决策树核心包
library(rpart.plot)  # 更好看的树
library(dplyr)       # 数据处理
library(ggplot2)     # 画图
library(showtext)
font_add("Heiti TC Light", regular = "STHeiti Light.ttc")
showtext_auto()

3 数据集A:模拟房价数据 → 回归树

3.1 构造数据

我们虚构一个城市里有 500 个小区,每个小区都有以下特征:

  • rooms:平均房间数,越多越贵;
  • dist_center:距离市中心的距离(公里),越远越便宜;
  • lstat:低社会经济地位人群比例,越高说明小区更偏,越便宜;
  • green:周边绿化覆盖率;
  • crime:犯罪率;
  • medv:我们想要预测的房价中位数(目标变量)。
set.seed(1)
n_house <- 500

house_df <- data.frame(
  rooms        = round(runif(n_house, 3, 8), 2),
  dist_center  = round(runif(n_house, 0.5, 20), 2),
  lstat        = round(runif(n_house, 2, 30), 2),
  green        = round(runif(n_house, 0.1, 0.8), 2),
  crime        = round(rexp(n_house, rate = 1/5), 2)  # 指数分布,凸显“部分区域犯罪率高”
)

# 人为设定生成房价的机制
house_df <- house_df %>% 
  mutate(
    medv = 50 +
      6  * rooms +             # 房间数多 → 房价升
      (-1.3) * dist_center +   # 离市中心远 → 房价降
      (-0.6) * lstat +         # 低社会经济地位比例高 → 房价降
      8  * green +             # 有绿化 → 房价升
      (-1.0) * crime +         # 犯罪率高 → 房价降
      rnorm(n_house, 0, 3)     # 一点噪声,模拟现实里不可观测因素
  )

str(house_df)
## 'data.frame':    500 obs. of  6 variables:
##  $ rooms      : num  4.33 4.86 5.86 7.54 4.01 7.49 7.72 6.3 6.15 3.31 ...
##  $ dist_center: num  11.31 13.92 13.33 13.44 9.71 ...
##  $ lstat      : num  16.86 21.18 12.73 28.74 5.31 ...
##  $ green      : num  0.15 0.71 0.24 0.47 0.17 0.37 0.62 0.59 0.17 0.53 ...
##  $ crime      : num  3 3.75 0.83 11.03 0.28 ...
##  $ medv       : num  49.1 49.8 59.5 54.7 60.7 ...
summary(house_df$medv)
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##   19.33   49.76   58.37   58.95   68.79   97.18

输出

  • str() 会告诉你一共有 500 行、6 个变量,都是数值型;
  • summary(medv) 一般会看到房价大部分集中在 40~80 左右,说明我们模拟得还算合理;
  • 我们后面要做的树,就是要从这 5 个自变量里挑出最能分房价的。

3.2 划分训练集和测试集

set.seed(2025)
id_train <- sample(seq_len(n_house), size = floor(0.7 * n_house))  # 70% 训练
house_train <- house_df[id_train, ]
house_test  <- house_df[-id_train, ]

我们保留 30% 不参与建模,只用来 考核模型在没见过的数据上的表现。这是判断树是不是“过拟合”的关键。

3.3 拟合回归树

fit_house <- rpart(
  medv ~ .,
  data = house_train,
  method = "anova"   # 回归树也可以不写,它能识别出 medv 是连续的
)

fit_house
## n= 350 
## 
## node), split, n, deviance, yval
##       * denotes terminal node
## 
##  1) root 350 66635.2800 59.32867  
##    2) rooms< 5.965 209 28288.5300 53.42257  
##      4) dist_center>=11.17 84  7579.0200 45.68170  
##        8) rooms< 4.365 41  2827.5120 39.75572  
##         16) crime>=2.87 25   782.3538 34.75291 *
##         17) crime< 2.87 16   441.7969 47.57260 *
##        9) rooms>=4.365 43  1938.8620 51.33205 *
##      5) dist_center< 11.17 125 12293.7100 58.62444  
##       10) crime>=12.24 13   748.0036 43.81439 *
##       11) crime< 12.24 112  8363.3540 60.34346  
##         22) rooms< 4.155 44  2390.8070 55.18801 *
##         23) rooms>=4.155 68  4046.3770 63.67934  
##           46) green< 0.61 52  2604.8940 61.34648  
##             92) dist_center>=5.575 27   599.9490 56.75208 *
##             93) dist_center< 5.575 25   819.4909 66.30843 *
##           47) green>=0.61 16   238.7483 71.26114 *
##    3) rooms>=5.965 141 20250.1500 68.08311  
##      6) dist_center>=11.265 67  5649.2670 60.27653  
##       12) lstat>=16.56 27  1617.0430 53.88235 *
##       13) lstat< 16.56 40  2183.1730 64.59260  
##         26) crime>=6.74 14   289.4671 57.79021 *
##         27) crime< 6.74 26   897.0673 68.25543 *
##      7) dist_center< 11.265 74  6820.8070 75.15122  
##       14) lstat>=19.4 21  1482.3060 66.39956  
##         28) crime>=2.78 12   636.7868 61.51844 *
##         29) crime< 2.78 9   178.4093 72.90773 *
##       15) lstat< 19.4 53  3092.7790 78.61886  
##         30) dist_center>=6.22 27   770.6203 74.29564 *
##         31) dist_center< 6.22 26  1293.4730 83.10837 *

输出

  • 第一行通常是 root ...,表示根节点,也就是“所有训练样本”;
  • 后面每一行都是一个节点:会写“用哪个变量在哪个阈值上分裂”;
  • 每行里常见的:
    • n:这个节点里有多少个样本;
    • deviance:回归问题里就是残差平方和;
    • yval:如果停在这个节点上,模型给出的预测值,就是这个子样本里的 平均房价
  • * 的行,是 终节点,说明不用再往下分了。

3.4 可视化树结构

op <- par(no.readonly = TRUE)
par(mar = c(1, 1, 1, 1))
plot(fit_house, margin = 0.05)
text(fit_house, use.n = TRUE, cex = 0.7)

par(op)

读图

  • 从上往下:越往下表示样本被分得越细;
  • 左边是“条件为假”,右边是“条件为真”(和你书上的是一回事);
  • 每个终节点上写的数,就是该节点的预测房价;
  • 如果你看到树里反复用 rooms 或者 dist_center 来分,说明这两个变量在这个模拟数据里 确实最有解释力

更好看的版本:

rpart.plot::prp(
  fit_house,
  type = 2,          # 分裂条件放在内部
  extra = 101,       # 节点里显示预测值 + 节点样本数
  fallen.leaves = TRUE
)

3.5 用交叉验证选树大小

plotcp(fit_house)

fit_house$cptable
##            CP nsplit rel error    xerror       xstd
## 1  0.27157694      0 1.0000000 1.0030137 0.07076405
## 2  0.12629647      1 0.7284231 0.7736633 0.05196276
## 3  0.11675604      2 0.6021266 0.7228256 0.05031177
## 4  0.04775778      3 0.4853706 0.5786681 0.04031116
## 5  0.04220956      4 0.4376128 0.5475991 0.03819379
## 6  0.03370170      5 0.3954032 0.5329764 0.03733510
## 7  0.02890616      6 0.3617015 0.4805375 0.03488561
## 8  0.02774882      7 0.3327954 0.4614401 0.03292669
## 9  0.02406175      8 0.3050465 0.4379398 0.03108787
## 10 0.01804953      9 0.2809848 0.4216290 0.03086947
## 11 0.01779018     10 0.2629353 0.4042117 0.03249943
## 12 0.01543755     11 0.2451451 0.3926860 0.03154869
## 13 0.01495662     12 0.2297075 0.3792125 0.03066014
## 14 0.01001136     13 0.2147509 0.3627746 0.02968993
## 15 0.01000000     14 0.2047396 0.3481261 0.02894511

plotcp()

  • 横轴上面那个数字是“树的大小”(终节点数);
  • 横轴下面那个是 cp(复杂度参数);
  • 纵轴是交叉验证误差,越低越好;
  • 通常我们会选最低点,或者用“1SE 规则”选一个更简单的点。

cptable

  • nsplit:分裂次数,= 终节点数 - 1;
  • rel error:在训练集上的相对误差;
  • xerror:交叉验证误差(重点看它);
  • xstd:交叉验证误差的标准误。

3.6 修枝(得到最优大小的树)

min_cp <- fit_house$cptable[which.min(fit_house$cptable[, "xerror"]), "CP"]
min_cp
## [1] 0.01
fit_house_best <- prune(fit_house, cp = min_cp)

rpart.plot::prp(
  fit_house_best,
  type = 2,
  extra = 101,
  fallen.leaves = TRUE,
  main = "修枝后的最优回归树"
)

这一步:我们不是要让树“长到最大”,而是要让树在“解释力”和“泛化能力”之间取得平衡。修枝之后的树,终节点通常会少一些,更稳。

3.7 测试集预测 + MSE

pred_house <- predict(fit_house_best, newdata = house_test)
y_true     <- house_test$medv

mse_tree <- mean((pred_house - y_true)^2)
mse_tree
## [1] 84.9063

MSE

  • 这是在“没见过的数据”上的平均平方误差;
  • 数值没有绝对好坏,要看你数据的量纲。比如房价在 40~80 之间,MSE 如果是 10~20 之间就还算OK;
  • 真正的比较意义是在“和别的模型比”的时候。

画一下实际值 vs 预测值:

plot(pred_house, y_true,
     xlab = "预测房价",
     ylab = "实际房价",
     main = "回归树预测效果(测试集)")
abline(0, 1, lwd = 2)

你会看到:点是“一坨一坨”的,而不是一条平滑的线,这是 树的典型特征 —— 每个终节点只能给一个数字。

3.8 和线性回归比较

lm_house <- lm(medv ~ ., data = house_train)
pred_lm  <- predict(lm_house, newdata = house_test)
mse_lm   <- mean((pred_lm - y_true)^2)

mse_tree
## [1] 84.9063
mse_lm
## [1] 12.00946

你会看到mse_lmmse_tree 小,这说明在我们这个模拟设定里,线性回归更贴合真实生成机制。

在本模拟数据中,回归树的可解释性更强,但预测精度不及OLS;如果希望提升树模型的预测性能,可进一步采用基于树的集成学习方法(如随机森林、梯度提升树等)。


4 数据集B:模拟银行营销数据 → 分类树

4.1 构造数据

set.seed(2)
n_bank <- 1500
job_levels <- c("admin", "blue", "services", "self", "retired")

bank_df <- data.frame(
  age = round(runif(n_bank, 20, 65)),
  job = factor(sample(job_levels, n_bank, replace = TRUE)),
  marital = factor(sample(c("single", "married"), n_bank, replace = TRUE, prob = c(0.35, 0.65))),
  last_contact_days = round(rexp(n_bank, rate = 1/15)),
  emp_rate = round(rnorm(n_bank, mean = 1.2, sd = 0.6), 2)
)

## 拉强信号:最近联系、已婚、中年、经济好 → 更可能买
linpred <- -2.6 +
  0.04 * (40 - abs(bank_df$age - 40)) +
  ifelse(bank_df$marital == "married", 0.6, 0) +
  ifelse(bank_df$last_contact_days < 10, 1.2, 0) +
  0.4 * bank_df$emp_rate

prob <- 1 / (1 + exp(-linpred))
y <- rbinom(n_bank, size = 1, prob = prob)
bank_df$y <- factor(ifelse(y == 1, "yes", "no"))

prop.table(table(bank_df$y))
## 
##        no       yes 
## 0.5313333 0.4686667

输出
prop.table(table(bank_df$y)) 会告诉你正负样本比例,比如:

  • no 大概 0.7
  • yes 大概 0.3

这就比原来 0.1 的例子好多了,树也更容易学到“yes”长什么样。

4.2 划分训练/测试集

set.seed(2026)
id_train_b <- sample(seq_len(n_bank), size = 1000)
bank_train <- bank_df[id_train_b, ]
bank_test  <- bank_df[-id_train_b, ]

4.3 拟合分类树(宽松控制)

fit_bank <- rpart(
  y ~ .,
  data = bank_train,
  method = "class",
  control = rpart.control(
    cp = 0.001,      # 比默认值 0.01 小得多
    minsplit = 20,
    minbucket = 7,
    xval = 10
  )
)

fit_bank
## n= 1000 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
##    1) root 1000 456 no (0.54400000 0.45600000)  
##      2) last_contact_days>=9.5 535 165 no (0.69158879 0.30841121)  
##        4) emp_rate< 0.845 140  24 no (0.82857143 0.17142857) *
##        5) emp_rate>=0.845 395 141 no (0.64303797 0.35696203)  
##         10) marital=single 155  40 no (0.74193548 0.25806452)  
##           20) job=admin,retired 72  12 no (0.83333333 0.16666667) *
##           21) job=blue,self,services 83  28 no (0.66265060 0.33734940)  
##             42) emp_rate< 1.055 7   0 no (1.00000000 0.00000000) *
##             43) emp_rate>=1.055 76  28 no (0.63157895 0.36842105)  
##               86) age< 37 26   6 no (0.76923077 0.23076923) *
##               87) age>=37 50  22 no (0.56000000 0.44000000)  
##                174) age>=48 29   8 no (0.72413793 0.27586207) *
##                175) age< 48 21   7 yes (0.33333333 0.66666667) *
##         11) marital=married 240 101 no (0.57916667 0.42083333)  
##           22) age< 24.5 23   2 no (0.91304348 0.08695652) *
##           23) age>=24.5 217  99 no (0.54377880 0.45622120)  
##             46) age>=52.5 75  23 no (0.69333333 0.30666667)  
##               92) emp_rate< 1.37 38   6 no (0.84210526 0.15789474) *
##               93) emp_rate>=1.37 37  17 no (0.54054054 0.45945946)  
##                186) age< 56.5 10   1 no (0.90000000 0.10000000) *
##                187) age>=56.5 27  11 yes (0.40740741 0.59259259)  
##                  374) job=admin,blue,services 16   7 no (0.56250000 0.43750000) *
##                  375) job=retired,self 11   2 yes (0.18181818 0.81818182) *
##             47) age< 52.5 142  66 yes (0.46478873 0.53521127)  
##               94) emp_rate>=1.015 123  60 no (0.51219512 0.48780488)  
##                188) emp_rate< 1.155 17   3 no (0.82352941 0.17647059) *
##                189) emp_rate>=1.155 106  49 yes (0.46226415 0.53773585)  
##                  378) age< 30.5 27   9 no (0.66666667 0.33333333) *
##                  379) age>=30.5 79  31 yes (0.39240506 0.60759494)  
##                    758) last_contact_days< 25 48  23 yes (0.47916667 0.52083333)  
##                     1516) last_contact_days>=10.5 41  19 no (0.53658537 0.46341463)  
##                       3032) age>=43.5 18   4 no (0.77777778 0.22222222) *
##                       3033) age< 43.5 23   8 yes (0.34782609 0.65217391)  
##                         6066) emp_rate< 1.485 8   2 no (0.75000000 0.25000000) *
##                         6067) emp_rate>=1.485 15   2 yes (0.13333333 0.86666667) *
##                     1517) last_contact_days< 10.5 7   1 yes (0.14285714 0.85714286) *
##                    759) last_contact_days>=25 31   8 yes (0.25806452 0.74193548) *
##               95) emp_rate< 1.015 19   3 yes (0.15789474 0.84210526) *
##      3) last_contact_days< 9.5 465 174 yes (0.37419355 0.62580645)  
##        6) age>=59.5 68  29 no (0.57352941 0.42647059)  
##         12) marital=single 28   7 no (0.75000000 0.25000000) *
##         13) marital=married 40  18 yes (0.45000000 0.55000000)  
##           26) job=admin,services 19   8 no (0.57894737 0.42105263) *
##           27) job=blue,retired,self 21   7 yes (0.33333333 0.66666667)  
##             54) emp_rate>=1.655 7   3 no (0.57142857 0.42857143) *
##             55) emp_rate< 1.655 14   3 yes (0.21428571 0.78571429) *
##        7) age< 59.5 397 135 yes (0.34005038 0.65994962)  
##         14) age< 28.5 104  48 yes (0.46153846 0.53846154)  
##           28) job=retired 21   4 no (0.80952381 0.19047619) *
##           29) job=admin,blue,self,services 83  31 yes (0.37349398 0.62650602)  
##             58) emp_rate< 0.485 11   3 no (0.72727273 0.27272727) *
##             59) emp_rate>=0.485 72  23 yes (0.31944444 0.68055556) *
##         15) age>=28.5 293  87 yes (0.29692833 0.70307167)  
##           30) age>=49.5 96  38 yes (0.39583333 0.60416667)  
##             60) last_contact_days< 2.5 33  14 no (0.57575758 0.42424242)  
##              120) emp_rate< 1.33 17   4 no (0.76470588 0.23529412) *
##              121) emp_rate>=1.33 16   6 yes (0.37500000 0.62500000) *
##             61) last_contact_days>=2.5 63  19 yes (0.30158730 0.69841270)  
##              122) emp_rate< 0.57 10   4 no (0.60000000 0.40000000) *
##              123) emp_rate>=0.57 53  13 yes (0.24528302 0.75471698) *
##           31) age< 49.5 197  49 yes (0.24873096 0.75126904)  
##             62) age< 40.5 115  35 yes (0.30434783 0.69565217)  
##              124) job=admin,self,services 72  27 yes (0.37500000 0.62500000)  
##                248) emp_rate>=0.465 63  26 yes (0.41269841 0.58730159)  
##                  496) last_contact_days>=7.5 14   6 no (0.57142857 0.42857143) *
##                  497) last_contact_days< 7.5 49  18 yes (0.36734694 0.63265306)  
##                    994) emp_rate< 0.885 13   6 no (0.53846154 0.46153846) *
##                    995) emp_rate>=0.885 36  11 yes (0.30555556 0.69444444) *
##                249) emp_rate< 0.465 9   1 yes (0.11111111 0.88888889) *
##              125) job=blue,retired 43   8 yes (0.18604651 0.81395349) *
##             63) age>=40.5 82  14 yes (0.17073171 0.82926829)  
##              126) emp_rate>=1.3 42  11 yes (0.26190476 0.73809524)  
##                252) emp_rate< 1.385 7   3 no (0.57142857 0.42857143) *
##                253) emp_rate>=1.385 35   7 yes (0.20000000 0.80000000)  
##                  506) age>=46.5 7   3 no (0.57142857 0.42857143) *
##                  507) age< 46.5 28   3 yes (0.10714286 0.89285714) *
##              127) emp_rate< 1.3 40   3 yes (0.07500000 0.92500000) *

输出
- 第一行还是根节点:会告诉你有 1000 个样本,多少是 no,多少是 yes
- 如果你看到下面又出现了 last_contact_days < 某个值marital = married 之类的分裂,说明我们前面造的“信号”被树找到了。
- 如果这里仍然只剩一个根节点,那就说明你这一次随机出来的数据恰好比较“弱”,可以把 cp 再降一点,比如 0.0005。

4.4 交叉验证图 & 修枝

plotcp(fit_bank)

fit_bank$cptable
##             CP nsplit rel error    xerror       xstd
## 1  0.256578947      0 1.0000000 1.0000000 0.03453958
## 2  0.021929825      1 0.7434211 0.7434211 0.03282734
## 3  0.014254386      2 0.7214912 0.7587719 0.03298846
## 4  0.010964912      4 0.6929825 0.7390351 0.03277984
## 5  0.008771930      5 0.6820175 0.7543860 0.03294324
## 6  0.006578947      6 0.6732456 0.7675439 0.03307698
## 7  0.005482456      7 0.6666667 0.7785088 0.03318403
## 8  0.004385965     23 0.5394737 0.7697368 0.03309871
## 9  0.003837719     25 0.5307018 0.7543860 0.03294324
## 10 0.002192982     29 0.5153509 0.7412281 0.03280368
## 11 0.001253133     30 0.5131579 0.7785088 0.03318403
## 12 0.001000000     38 0.5021930 0.7960526 0.03334711

和回归树一样,看哪一个 cp 的交叉验证误差最低:

min_cp_b <- fit_bank$cptable[which.min(fit_bank$cptable[, "xerror"]), "CP"]
min_cp_b
## [1] 0.01096491
fit_bank_best <- prune(fit_bank, cp = min_cp_b)

rpart.plot::prp(
  fit_bank_best,
  type = 2,
  extra = 104,
  fallen.leaves = TRUE,
  main = "修枝后的最优分类树"
)

  • 每个内部节点是一个“筛人规则”;
  • 叶子节点会显示预测的类别(yes/no)、该叶子里 yes/no 的比例、样本量;
  • 如果你看到某个叶子里 yes 的比例明显比总体 30% 高,比如 55%,那就是一群比较“值得营销”的人。

4.5 测试集预测 + 混淆矩阵

# 默认阈值 0.5
pred_cls <- predict(fit_bank_best, newdata = bank_test, type = "class")
tab <- table(pred_cls, bank_test$y)
tab
##         
## pred_cls  no yes
##      no  183 123
##      yes  70 124
accuracy <- sum(diag(tab)) / sum(tab)
accuracy
## [1] 0.614
# 防呆:有时候树会把所有人都判成一个类,下面这样写就不会报错
TP <- if ("yes" %in% rownames(tab)) tab["yes", "yes"] else 0
FN <- if ("no"  %in% rownames(tab)) tab["no",  "yes"] else 0

sensitivity <- TP / (TP + FN)
sensitivity
## [1] 0.5020243

解释结果

  • accuracy 是整体准确率;
  • sensitivity 是对 “yes” 这一类的识别能力;
  • 营销/医学筛查更看重的是 “sensitivity”。

4.6 调低阈值,看灵敏度怎么变

# 取出预测概率
prob_test <- predict(fit_bank_best, newdata = bank_test, type = "prob")[, "yes"]

# 把阈值从0.5降到0.25
pred_025  <- ifelse(prob_test > 0.25, "yes", "no")

tab2 <- table(pred_025, bank_test$y)
tab2
##         
## pred_025  no yes
##      no    2   7
##      yes 251 240
acc2 <- sum(diag(tab2)) / sum(tab2)

TP2 <- if ("yes" %in% rownames(tab2)) tab2["yes", "yes"] else 0
FN2 <- if ("no"  %in% rownames(tab2)) tab2["no",  "yes"] else 0

sens2 <- TP2 / (TP2 + FN2)

acc2
## [1] 0.484
sens2
## [1] 0.9716599

报告

将分类阈值从0.5下调至0.25后,模型的总体准确率略有下降,但对“有购买意向”的识别率(灵敏度)明显提升,说明在该业务场景下通过牺牲部分准确率可换取更高的召回率,这与营销业务“宁可多打几个错的电话,也不能漏掉潜在客户”的目标是一致的。


5 小结与延伸

  1. rpart 会根据因变量自动识别是回归树还是分类树;
  2. 树的三个关键步骤:建 → 看 cp → 修枝
  3. 模拟数据时,如果树老长不出来,通常是 信号不够强 或者 cp 太大
  4. 决策树的预测值是一段一段的,如果想要更平滑的预测,要上随机森林/GBM;
  5. 分类问题下一定要看混淆矩阵和灵敏度,不能只看准确率。