【知识点】决策树算法详解:ID3、C4.5 和 CART 的原理与实现
编辑
2024-09-01
深度学习
00

目录

决策树算法介绍
1. ID3(Iterative Dichotomiser 3)——决策树的早期经典算法
🎯 核心思想:
🔍 怎么挑“最有用的特征”?
🧠 什么是“信息增益”?
📐 公式说明(简单版):
✅ 总结一下大白话版本:
2. C4.5 —— ID3 的升级版,更聪明的决策树算法!
🎯 核心改进点:
🔍 怎么挑特征?—— 使用“信息增益率”
🧠 通俗解释一下这个公式:
📐 固有值(IV)公式也简单:
✅ 总结一下大白话版本:
3. CART (Classification and Regression Trees)—— 又能分类又能回归的全能型决策树!
🎯 核心特点:
🔍 一、特征选择标准
✅ 通俗讲解:基尼指数(Gini Index)
🧠 想象一下你手里有一堆球,红的、蓝的都有。
📐 公式:
🧩 CART 在分类时会怎么做?
✅ 通俗讲解:均方误差(MSE)
🧠 这个用于回归问题。
📐 公式:
🧩 CART 在回归时会怎么做?
✅ 总结一下大白话版本:
代码实战
总结

决策树算法介绍

决策树是一种常用的监督学习方法,广泛应用于分类和回归任务。它通过一系列的分割规则,将数据集分成更小的子集,最终形成一个树形结构。每个节点代表一个特征,每个分支代表该特征的一个可能值,每个叶节点代表一个类或回归值。

以下是三种常见的决策树算法及其特征选择标准:

1. ID3(Iterative Dichotomiser 3)——决策树的早期经典算法

🎯 核心思想

ID3 的目标是构建一棵决策树,用来做分类。它通过不断地挑选“最有用的特征”来划分数据,直到把数据分得“足够清楚”。


🔍 怎么挑“最有用的特征”?

ID3 使用的是:信息增益(Information Gain)

你可以理解为:

哪个特征能让我更“确定”地知道样本属于哪一类,我就先用它来分!


🧠 什么是“信息增益”?

举个通俗的例子:

假设你手里有一堆人,你想判断他们是“喜欢打篮球”还是“不喜欢打篮球”。现在你有几个特征可以看:比如性别、年龄、身高……

  • 如果你看“性别”,发现男女混在一起,分不太清,那这个特征就帮不了你太多。
  • 但如果你看“身高”,发现高的人几乎都爱打篮球,矮的不爱打,那这个特征就很“有用”。

所以,“信息增益”就是衡量一个特征能把混乱的数据变多“清晰”的程度。


📐 公式说明(简单版)

信息增益 = 分类前的混乱程度(熵) - 按某个特征分后的平均混乱程度

公式是这样的:

IG(T,A)=Entropy(T)vValues(A)TvTEntropy(Tv)IG(T, A) = Entropy(T) - \sum_{v \in Values(A)} \frac{|T_v|}{|T|} \cdot Entropy(T_v)

别怕,其实意思很简单:

  • Entropy(T)Entropy(T):当前数据有多乱(不确定性有多大)
  • TvT_v:按特征 A 的每个值分开后的小数据集
  • 后面那一串意思是:我用 A 把数据分几份,每份的混乱度乘上它的比例,加起来得到一个新的“平均混乱度”
  • IG 就是原来的混乱度减去这个新的混乱度 —— 差越大,说明这个特征越有用!

✅ 总结一下大白话版本:

ID3 是一种建树方法,它每次选一个最能“理清楚数据”的特征来切分数据,而它判断哪个特征最厉害的标准,叫做“信息增益”。谁的信息增益大,谁就先上!

2. C4.5 —— ID3 的升级版,更聪明的决策树算法!

🎯 核心改进点

C4.5 是在 ID3 的基础上改进而来的,主要解决了 ID3 的一个“小毛病”:

它太偏爱那些取值多的特征!

比如你有两个特征:

  • “性别”:只有两个值(男/女)
  • “身份证号”:每个样本几乎都有不同的值

ID3 可能会认为“身份证号”这个特征特别有用,因为它可以把数据分得非常细,信息增益很高。但其实这种特征根本没意义,因为每个值都唯一,没法用来泛化预测。

所以,C4.5 就是来纠正这个偏见的


🔍 怎么挑特征?—— 使用“信息增益率”

C4.5 不再直接用“信息增益”,而是用它的“亲戚”——信息增益率(Gain Ratio)

你可以理解为:

增益率 = “你的贡献有多大” ÷ “你本身有多能分”

这样就能惩罚那些“分得太碎”的特征了。


🧠 通俗解释一下这个公式

GR(A)=IG(T,A)IV(A)GR(A) = \frac{IG(T, A)}{IV(A)}
  • IG(T,A)IG(T, A):还是原来的信息增益,表示这个特征对分类的帮助有多大。
  • IV(A)IV(A):叫做“固有值”(Intrinsic Value),表示这个特征自己有多少“分裂能力”,也就是它有多少个取值、分布得多散。

举个例子:

  • 如果一个特征把数据分成100份(比如身份证号),那它的 IV 很大;
  • 如果一个特征只分成2份(比如性别),那它的 IV 就小。

所以即使它的信息增益高,如果它本身太“分裂”,增益率就会被拉低。


📐 固有值(IV)公式也简单

IV(A)=i=1nTiTlog2TiTIV(A) = - \sum_{i=1}^{n} \frac{|T_i|}{|T|} \log_2 \frac{|T_i|}{|T|}

这其实就是计算这个特征本身的“熵”。
也就是说,它衡量的是:这个特征自己能把数据分得多乱或多清晰。


✅ 总结一下大白话版本:

C4.5 是 ID3 的升级版,它不再单纯看哪个特征分得清楚,还要看这个特征是不是“动不动就把数据分得乱七八糟”。它用“信息增益率”来平衡这两方面,避免模型偏爱那些取值很多但没啥实际意义的特征。

3. CART (Classification and Regression Trees)—— 又能分类又能回归的全能型决策树!

🎯 核心特点

CART 是一个“通吃型”选手:
✅ 它可以用来做分类(Classify),
✅ 也可以用来做回归(Regression)。

也就是说,不管是判断“你是男是女”,还是预测“你能考多少分”,它都能搞定!


🔍 一、特征选择标准

CART 根据任务不同,使用不同的标准:

任务类型使用的标准
分类基尼指数(Gini)
回归均方误差(MSE)

✅ 通俗讲解:基尼指数(Gini Index)

🧠 想象一下你手里有一堆球,红的、蓝的都有。

  • 如果这一堆全是红球 → 很纯,不纯度低,Gini 就小;
  • 如果红蓝混在一起 → 不纯,Gini 就大。

所以:

基尼指数就是衡量一个数据集有多“混乱”或“不纯”。


📐 公式:

Gini(T)=1i=1nPi2Gini(T) = 1 - \sum_{i=1}^{n} P_i^2
  • PiP_i 是第 ii 类出现的概率。
  • 举个例子:如果一个数据集中有 70% 是 A 类,30% 是 B 类:
    Gini=1(0.72+0.32)=1(0.49+0.09)=0.42Gini = 1 - (0.7^2 + 0.3^2) = 1 - (0.49 + 0.09) = 0.42

🧩 CART 在分类时会怎么做?

每次选一个特征,尝试切一刀,看看能不能把数据分得更“干净”一点。哪个特征能让 Gini 最小,就选谁!


✅ 通俗讲解:均方误差(MSE)

🧠 这个用于回归问题。

比如你想预测一个人的工资是多少,模型预测了 15k,但实际是 18k,那这个差距就是误差。

MSE 就是平均来看,模型预测值和真实值之间的“平方差”。


📐 公式:

MSE=1Ni=1N(yiy^i)2MSE = \frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2
  • yiy_i 是真实的值(比如工资)
  • y^i\hat{y}_i 是模型预测的值
  • N 是样本数量

🧩 CART 在回归时会怎么做?

每次选一个特征,找一个最合适的切分点,让左右两边的 MSE 最小。也就是说,让预测尽可能接近真实值。


✅ 总结一下大白话版本:

CART 是一个既能分类又能回归的决策树算法。

  • 分类的时候看“基尼指数”——谁能让数据分得更干净,就选谁;
  • 回归的时候看“均方误差”——谁能让预测更准,就选谁。

如果你想继续了解剪枝、过拟合、随机森林这些内容,我也可以继续用这种通俗方式讲 😄

代码实战

下面是一个使用Scikit-learn库实现决策树分类器的示例代码:

python
展开代码
import numpy as np import pandas as pd from sklearn.model_selection import train_test_split from sklearn.tree import DecisionTreeClassifier from sklearn.metrics import accuracy_score from sklearn.tree import export_text # 加载数据集(这里使用鸢尾花数据集作为示例) from sklearn.datasets import load_iris data = load_iris() X = data.data y = data.target # 拆分数据集为训练集和测试集 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42) # 初始化决策树分类器(使用CART算法) clf = DecisionTreeClassifier(criterion='gini', random_state=42) # 训练模型 clf.fit(X_train, y_train) # 预测 y_pred = clf.predict(X_test) # 计算准确率 accuracy = accuracy_score(y_test, y_pred) print(f"模型准确率: {accuracy * 100:.2f}%") # 导出决策树规则 tree_rules = export_text(clf, feature_names=data.feature_names) print(tree_rules)

总结

  • ID3:使用信息增益选择特征。

  • C4.5:使用信息增益率选择特征,以避免信息增益的偏好。

  • CART:使用基尼指数(分类)或均方误差(回归)选择特征。

每种决策树算法在不同的应用场景中有不同的优势,选择哪种算法取决于具体任务和数据特征。希望本文能帮助你更好地理解决策树算法及其实现。

如果对你有用的话,可以打赏哦
打赏
ali pay
wechat pay

本文作者:Dong

本文链接:

版权声明:本博客所有文章除特别声明外,均采用 CC BY-NC。本作品采用《知识共享署名-非商业性使用 4.0 国际许可协议》进行许可。您可以在非商业用途下自由转载和修改,但必须注明出处并提供原作者链接。 许可协议。转载请注明出处!