决策树:像做选择题一样理解分类与回归
决策树(Decision Tree)是一种非常符合人类直觉的机器学习算法。它的预测过程就像做选择题:先问一个问题,根据答案走到下一步,再继续问问题,直到得到最终结论。
例如判断一个学生考试是否通过,可以问:
- 学习时间是否大于 3 小时?
- 睡眠时间是否大于 7 小时?
- 是否完成复习?
最后走到某个叶子节点,得到“通过”或“不通过”的预测结果。
1. 决策树的基本思想
决策树的核心思想是:不断选择一个最合适的特征,把数据集切分得越来越“纯”。
这里的“纯”可以理解为:一个节点里的样本类别越统一,就越纯。
例如一个节点里有 10 个样本:
- 如果 10 个都是“通过”,这个节点非常纯。
- 如果 5 个“通过”、5 个“不通过”,这个节点就很混乱。
训练决策树时,算法会不断寻找最佳划分条件,例如:
学习时间 >= 3 小时?
如果这个问题能把“通过”和“不通过”分得更开,就说明它是一个不错的划分。
2. 决策树由哪些部分组成
一棵决策树通常包含三类节点:
| 结构 | 含义 |
|---|---|
| 根节点 | 整棵树的起点,包含全部训练数据 |
| 内部节点 | 一个判断条件,例如“年龄是否大于 18” |
| 叶子节点 | 最终预测结果,例如“通过”或“不通过” |
从根节点走到叶子节点,就是模型预测一个样本的过程。
例如:
学习时间 >= 3 小时?
是 -> 完成复习?
是 -> 预测:通过
否 -> 预测:不通过
否 -> 睡眠时间 >= 7 小时?
是 -> 预测:通过
否 -> 预测:不通过
这种结构的好处是非常容易解释。相比逻辑回归、支持向量机等模型,决策树更像是一套可读的判断规则。
3. 如何判断哪个特征更适合划分
决策树训练时最关键的问题是:
当前节点应该用哪个特征来划分?
常见的划分标准包括:
- 信息增益
- 信息增益率
- 基尼指数
- 均方误差
其中分类树常用信息熵、信息增益和基尼指数;回归树常用均方误差。
4. 信息熵:衡量混乱程度
信息熵(Entropy)可以用来衡量一个数据集的混乱程度。
对于分类问题,信息熵定义为:
$$
H(D) = -\sum_{k=1}^{K} p_k \log_2 p_k
$$
其中:
- $D$ 表示当前数据集。
- $K$ 表示类别数量。
- $p_k$ 表示第 $k$ 类样本所占比例。
如果一个节点里所有样本都属于同一类,那么熵为 0,说明完全不混乱。
如果不同类别各占一半,熵会比较大,说明节点比较混乱。
以二分类为例:
| 通过 | 不通过 | 熵 |
|---|---|---|
| 10 | 0 | 0 |
| 8 | 2 | 较小 |
| 5 | 5 | 较大 |
决策树希望每次划分后,子节点的熵尽可能小。
5. ID3 决策树:用信息增益选择特征
ID3(Iterative Dichotomiser 3)是比较早期的决策树算法。它的核心想法很直接:
每次选择“让信息熵下降最多”的特征。
这个“下降最多”就是信息增益(Information Gain)。假设用特征 $A$ 对数据集 $D$ 进行划分,信息增益为:
$$
Gain(D, A) = H(D) - \sum_{v=1}^{V}\frac{|D_v|}{|D|}H(D_v)
$$
其中:
- $H(D)$ 是划分前的熵。
- $D_v$ 是按特征 $A$ 划分后的第 $v$ 个子集。
- $\frac{|D_v|}{|D|}$ 表示子集占原数据集的比例。
信息增益越大,说明这个特征越能把混乱的数据分清楚。
来看一个小例子。假设我们要判断是否适合打球:
| 编号 | 天气 | 是否周末 | 是否打球 |
|---|---|---|---|
| 1 | 晴 | 是 | 否 |
| 2 | 晴 | 否 | 否 |
| 3 | 阴 | 是 | 是 |
| 4 | 阴 | 否 | 是 |
| 5 | 雨 | 是 | 是 |
| 6 | 雨 | 否 | 是 |
当前数据集中有 4 个“是”、2 个“否”,所以整体信息熵为:
$$
H(D) = -\frac{4}{6}\log_2\frac{4}{6} - \frac{2}{6}\log_2\frac{2}{6} \approx 0.918
$$
如果按照“天气”划分:
| 天气 | 子集情况 | 子集熵 |
|---|---|---|
| 晴 | 0 个是,2 个否 | 0 |
| 阴 | 2 个是,0 个否 | 0 |
| 雨 | 2 个是,0 个否 | 0 |
计算其中一个$H(D_晴)$,因天气=晴只有[否, 否],数据绝对纯净,代入计算熵的公式可知:
$$
H(D_{晴}) = - \frac{2}{2}\log_{2}\frac{2}{2} - 0 = 0
$$
故划分后的加权熵为 0
$$
加权熵 =\sum_{v=1}^{V}\frac{|D_v|}{|D|}H(D_v) = (\frac{2}{6} \times 0) + (\frac{2}{6} \times 0) + (\frac{2}{6} \times 0) = 0
$$
因此:
$$
Gain(D, 天气) = 0.918 - 0 = 0.918
$$
这说明“天气”这个特征一分下去,数据立刻变得非常纯。ID3 会优先选择它作为根节点。
ID3 的实现步骤可以概括为:
- 计算当前数据集的信息熵。
- 分别计算每个候选特征的信息增益。
- 选择信息增益最大的特征作为当前节点的划分条件。
- 按该特征的不同取值生成多个分支。
- 对每个子集递归重复上述过程,直到样本已经纯净、特征用完,或者达到停止条件。
ID3 的问题也很明显:它容易偏爱取值很多的特征。比如“用户 ID”这种特征,几乎每个样本都不同,一划分就能让子节点很纯,但这种划分对新数据没有泛化能力。
6. C4.5 决策树:用信息增益率修正偏好
C4.5 可以看成是 ID3 的改进版。它仍然关心信息增益,但不会只看信息增益,而是使用信息增益率(Gain Ratio)。
信息增益率的定义为:
$$
GainRatio(D, A) = \frac{Gain(D, A)}{SplitInfo(D, A)}
$$
其中 $SplitInfo(D,A)$ 表示特征 $A$ 自身把数据切得有多碎:
$$
SplitInfo(D, A) = -\sum_{v=1}^{V}\frac{|D_v|}{|D|}\log_2\frac{|D_v|}{|D|}
$$
可以把它理解成一种惩罚项:如果一个特征把数据切得太零散,那么分母会变大,最终的信息增益率会被压低。
继续用刚才的数据。如果额外加入一个“编号 ID”特征:
| 编号ID | 天气 | 是否打球 |
|---|---|---|
| 1 | 晴 | 否 |
| 2 | 晴 | 否 |
| 3 | 阴 | 是 |
| 4 | 阴 | 是 |
| 5 | 雨 | 是 |
| 6 | 雨 | 是 |
如果按“编号 ID”划分,每个子节点只有一个样本,子节点熵全是 0,所以信息增益会很大:
$$
Gain(D, 编号ID) = H(D) \approx 0.918
$$
但它把 6 个样本切成了 6 份,划分信息为:
$$
SplitInfo(D, 编号ID) = -6 \times \frac{1}{6}\log_2\frac{1}{6} = \log_2 6 \approx 2.585
$$
所以信息增益率为:
$$
GainRatio(D, 编号ID) \approx \frac{0.918}{2.585} \approx 0.355
$$
C4.5 不会轻易被“编号 ID”这种看起来很纯、实际没意义的特征骗走。它更倾向于选择既能降低混乱程度,又不会把数据切得过碎的特征。
C4.5 的实现步骤可以概括为:
- 计算当前数据集的信息熵。
- 对每个候选特征计算信息增益。
- 对每个候选特征计算划分信息。
- 用信息增益除以划分信息,得到信息增益率。
- 选择信息增益率较高且信息增益不太低的特征作为划分条件。
- 如果是连续特征,尝试多个候选阈值,把连续值转成“是否小于等于某个阈值”的划分。
- 递归生成子树,并可以配合剪枝减少过拟合。
相比 ID3,C4.5 更适合真实数据:它可以处理连续特征,也对缺失值和剪枝有更完整的考虑。
7. CART 树:用基尼指数或平方误差做二叉划分
CART(Classification And Regression Tree)既可以做分类,也可以做回归。它和 ID3、C4.5 一个很大的区别是:
CART 每次都做二叉划分。
即使一个特征有很多取值,CART 也会把问题拆成“左子树”和“右子树”。例如:
学习时间 <= 3.5 小时?
是 -> 左子树
否 -> 右子树
做分类时,CART 常用基尼指数(Gini Index)作为划分标准。
基尼指数定义为:
$$
Gini(D) = 1 - \sum_{k=1}^{K}p_k^2
$$
$p_k$ 表示第 $k$ 类样本所占比例。
它可以理解为:从当前节点随机抽一个样本,再随机按类别比例猜它的类别,猜错的概率大不大。
如果节点越纯,基尼指数越小。
二分类下:
- 如果一个节点全是正类,$Gini=0$。
- 如果正负类各占一半,$Gini=0.5$。
使用特征 $A$ 划分后的基尼指数为:
$$
Gini(D, A) = \sum_{v=1}^{V}\frac{|D_v|}{|D|}Gini(D_v)
$$
训练时会选择让划分后基尼指数最小的特征或阈值。
来看一个 CART 分类树的例子:
| 编号 | 学习时间 | 是否通过 |
|---|---|---|
| 1 | 1 | 否 |
| 2 | 2 | 否 |
| 3 | 3 | 否 |
| 4 | 4 | 是 |
| 5 | 5 | 是 |
| 6 | 6 | 是 |
如果尝试阈值 学习时间 <= 3.5:
- 左子树:3 个“否”,$Gini=0$。
- 右子树:3 个“是”,$Gini=0$。
- 加权基尼指数:$0$。
如果尝试阈值 学习时间 <= 2.5:
- 左子树:2 个“否”,$Gini=0$。
- 右子树:3 个“是”、1 个“否”,$Gini=1-(\frac{3}{4})^2-(\frac{1}{4})^2=0.375$。
- 加权基尼指数:$\frac{2}{6}\times0+\frac{4}{6}\times0.375=0.25$。
因为 0 小于 0.25,所以在这个例子中,学习时间 <= 3.5 是更好的划分。
CART 分类树的实现步骤可以概括为:
- 枚举候选特征。
- 对连续特征枚举候选阈值,对离散特征枚举可能的二叉划分。
- 按候选条件把数据分成左子树和右子树。
- 计算划分后的加权基尼指数。
- 选择加权基尼指数最小的划分。
- 递归生成左右子树,叶子节点输出多数类别。
做回归时,CART 不再使用基尼指数,而是常用平方误差作为划分标准:
$$
SSE(D) = \sum_{i=1}^{m}(y^{(i)}-\bar{y})^2
$$
它会选择让左右子树平方误差之和最小的划分,叶子节点通常输出该叶子中样本标签的平均值。
三种经典决策树可以这样对比:
| 算法 | 核心指标 | 树结构 | 主要任务 | 特点 |
|---|---|---|---|---|
| ID3 | 信息增益 | 多叉树 | 分类 | 简单直观,但容易偏爱取值多的特征 |
| C4.5 | 信息增益率 | 多叉树 | 分类 | 修正 ID3 的偏好,可处理连续特征和缺失值 |
| CART | 基尼指数或平方误差 | 二叉树 | 分类与回归 | 工程中很常用,也是随机森林、GBDT 等方法的重要基础 |
8. 回归树怎么划分
决策树不仅能做分类,也能做回归。
分类树的叶子节点输出类别;回归树的叶子节点输出数值。
例如预测房价时,一个叶子节点中可能有多个训练样本:
90 万、95 万、100 万
那么这个叶子节点的预测值通常取平均值:
$$
\hat{y} = \frac{1}{m}\sum_{i=1}^{m}y^{(i)}
$$
回归树在选择划分时,通常希望划分后的平方误差最小:
$$
MSE = \frac{1}{m}\sum_{i=1}^{m}(y^{(i)}-\hat{y})^2
$$
所以分类树更关心“类别是否更纯”,回归树更关心“数值是否更接近”。
9. 决策树为什么容易过拟合
决策树很容易把训练数据记得太细。
如果不加限制,它可以一直划分,直到每个叶子节点几乎只剩一个样本。这样训练集准确率可能非常高,但对新数据的泛化能力会下降。
这就是过拟合。
常见控制方法包括:
- 限制树的最大深度
max_depth。 - 限制叶子节点最少样本数
min_samples_leaf。 - 限制内部节点继续划分所需的最少样本数
min_samples_split。 - 剪枝,去掉贡献不大的分支。
可以把这些参数理解为“不要让树长得太复杂”。
10. Python 手写:计算基尼指数并选择划分
下面不急着手写完整决策树,而是先手写决策树最核心的一步:计算某个划分是否让数据更纯。
from collections import Counter
def gini(labels):
# 对应模块:计算当前节点的基尼指数,值越小表示类别越纯
total = len(labels)
counts = Counter(labels)
impurity = 1.0
for count in counts.values():
prob = count / total
impurity -= prob ** 2
return impurity
def split_dataset(X, y, feature_index, threshold):
# 对应模块:按照某个特征阈值把数据切成左右两个子节点
left_y = []
right_y = []
for sample, label in zip(X, y):
if sample[feature_index] <= threshold:
left_y.append(label)
else:
right_y.append(label)
return left_y, right_y
def weighted_gini(left_y, right_y):
# 对应模块:计算划分后的加权基尼指数
total = len(left_y) + len(right_y)
left_weight = len(left_y) / total
right_weight = len(right_y) / total
return left_weight * gini(left_y) + right_weight * gini(right_y)
X = [
[2, 6], # 学习 2 小时,睡眠 6 小时
[3, 7],
[4, 6],
[5, 8],
[6, 7],
]
y = ["不通过", "不通过", "通过", "通过", "通过"]
# 尝试用第 0 个特征,也就是学习时间,以 3.5 为阈值划分
left_y, right_y = split_dataset(X, y, feature_index=0, threshold=3.5)
print("左节点标签:", left_y)
print("右节点标签:", right_y)
print("划分前基尼指数:", gini(y))
print("划分后加权基尼指数:", weighted_gini(left_y, right_y))
这段代码展示了决策树选择划分的核心逻辑:
- 先选一个特征和阈值。
- 按条件把样本分成左右两组。
- 分别计算左右节点的纯度。
- 如果划分后的加权基尼指数更小,说明这个划分更好。
完整决策树就是不断重复这个过程,直到满足停止条件。
11. 使用 sklearn 训练决策树分类器
实际项目中可以直接使用 sklearn.tree.DecisionTreeClassifier。
import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_text
# 特征:学习时间、睡眠时间
X = np.array([
[2, 6],
[3, 7],
[4, 6],
[5, 8],
[6, 7],
[1, 5],
[7, 8],
])
y = np.array(["不通过", "不通过", "通过", "通过", "通过", "不通过", "通过"])
# max_depth 限制树的深度,避免树过于复杂
model = DecisionTreeClassifier(
criterion="gini",
max_depth=2,
random_state=42,
)
model.fit(X, y)
new_student = np.array([[4, 7]])
prediction = model.predict(new_student)
print("预测结果:", prediction[0])
# 打印树的规则,方便理解模型是怎么做判断的
rules = export_text(
model,
feature_names=["学习时间", "睡眠时间"],
)
print(rules)
可能输出类似:
|--- 学习时间 <= 3.50
| |--- class: 不通过
|--- 学习时间 > 3.50
| |--- class: 通过
这也是决策树的一个重要优点:模型结果相对容易解释。
12. 决策树的优缺点
优点
- 可解释性强,预测过程像一组规则。
- 可以处理非线性关系。
- 对特征缩放不敏感,通常不需要标准化。
- 既能做分类,也能做回归。
- 可以处理数值特征和类别特征。
缺点
- 容易过拟合。
- 对数据扰动比较敏感,数据稍微变化,树结构可能明显改变。
- 单棵树的泛化能力有限。
- 贪心划分不一定能找到全局最优树。
这些缺点也解释了为什么后续会出现随机森林、梯度提升树、XGBoost、LightGBM 等集成模型。它们本质上都是在决策树基础上进行改进。
13. 实践建议
使用决策树时,可以重点关注以下几点:
- 先限制
max_depth,不要让树长得太深。 - 使用交叉验证选择
max_depth、min_samples_leaf等参数。 - 如果数据噪声较多,适当增大
min_samples_leaf。 - 如果单棵树效果不稳定,可以尝试随机森林。
- 如果追求更强的预测性能,可以学习梯度提升树相关算法。
14. 总结
决策树的核心思想很简单:不断提出问题,把数据切分得越来越纯。
对于分类任务,ID3 使用信息增益,C4.5 使用信息增益率,CART 分类树常用基尼指数;对于回归任务,CART 回归树通常使用平方误差或均方误差来选择划分。
理解决策树时,可以抓住三句话:
- 节点是在问问题。
- 分支是问题的不同答案。
- 叶子节点是最终预测结果。
决策树本身简单直观,但它也是很多强大模型的基础。理解决策树之后,再学习随机森林、GBDT、XGBoost 等算法时,会更容易抓住它们的核心思想。
