决策树是一种常用的监督学习方法,广泛应用于分类和回归任务。它通过一系列的分割规则,将数据集分成更小的子集,最终形成一个树形结构。每个节点代表一个特征,每个分支代表该特征的一个可能值,每个叶节点代表一个类或回归值。
以下是三种常见的决策树算法及其特征选择标准:
ID3 的目标是构建一棵决策树,用来做分类。它通过不断地挑选“最有用的特征”来划分数据,直到把数据分得“足够清楚”。
ID3 使用的是:信息增益(Information Gain)。
你可以理解为:
哪个特征能让我更“确定”地知道样本属于哪一类,我就先用它来分!
举个通俗的例子:
假设你手里有一堆人,你想判断他们是“喜欢打篮球”还是“不喜欢打篮球”。现在你有几个特征可以看:比如性别、年龄、身高……
所以,“信息增益”就是衡量一个特征能把混乱的数据变多“清晰”的程度。
信息增益 = 分类前的混乱程度(熵) - 按某个特征分后的平均混乱程度
公式是这样的:
别怕,其实意思很简单:
ID3 是一种建树方法,它每次选一个最能“理清楚数据”的特征来切分数据,而它判断哪个特征最厉害的标准,叫做“信息增益”。谁的信息增益大,谁就先上!
C4.5 是在 ID3 的基础上改进而来的,主要解决了 ID3 的一个“小毛病”:
它太偏爱那些取值多的特征!
比如你有两个特征:
ID3 可能会认为“身份证号”这个特征特别有用,因为它可以把数据分得非常细,信息增益很高。但其实这种特征根本没意义,因为每个值都唯一,没法用来泛化预测。
所以,C4.5 就是来纠正这个偏见的。
C4.5 不再直接用“信息增益”,而是用它的“亲戚”——信息增益率(Gain Ratio)。
你可以理解为:
增益率 = “你的贡献有多大” ÷ “你本身有多能分”
这样就能惩罚那些“分得太碎”的特征了。
举个例子:
所以即使它的信息增益高,如果它本身太“分裂”,增益率就会被拉低。
这其实就是计算这个特征本身的“熵”。
也就是说,它衡量的是:这个特征自己能把数据分得多乱或多清晰。
C4.5 是 ID3 的升级版,它不再单纯看哪个特征分得清楚,还要看这个特征是不是“动不动就把数据分得乱七八糟”。它用“信息增益率”来平衡这两方面,避免模型偏爱那些取值很多但没啥实际意义的特征。
CART 是一个“通吃型”选手:
✅ 它可以用来做分类(Classify),
✅ 也可以用来做回归(Regression)。
也就是说,不管是判断“你是男是女”,还是预测“你能考多少分”,它都能搞定!
CART 根据任务不同,使用不同的标准:
任务类型 | 使用的标准 |
---|---|
分类 | 基尼指数(Gini) |
回归 | 均方误差(MSE) |
所以:
基尼指数就是衡量一个数据集有多“混乱”或“不纯”。
每次选一个特征,尝试切一刀,看看能不能把数据分得更“干净”一点。哪个特征能让 Gini 最小,就选谁!
比如你想预测一个人的工资是多少,模型预测了 15k,但实际是 18k,那这个差距就是误差。
MSE 就是平均来看,模型预测值和真实值之间的“平方差”。
每次选一个特征,找一个最合适的切分点,让左右两边的 MSE 最小。也就是说,让预测尽可能接近真实值。
CART 是一个既能分类又能回归的决策树算法。
- 分类的时候看“基尼指数”——谁能让数据分得更干净,就选谁;
- 回归的时候看“均方误差”——谁能让预测更准,就选谁。
如果你想继续了解剪枝、过拟合、随机森林这些内容,我也可以继续用这种通俗方式讲 😄
下面是一个使用Scikit-learn库实现决策树分类器的示例代码:
python展开代码import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score
from sklearn.tree import export_text
# 加载数据集(这里使用鸢尾花数据集作为示例)
from sklearn.datasets import load_iris
data = load_iris()
X = data.data
y = data.target
# 拆分数据集为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# 初始化决策树分类器(使用CART算法)
clf = DecisionTreeClassifier(criterion='gini', random_state=42)
# 训练模型
clf.fit(X_train, y_train)
# 预测
y_pred = clf.predict(X_test)
# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print(f"模型准确率: {accuracy * 100:.2f}%")
# 导出决策树规则
tree_rules = export_text(clf, feature_names=data.feature_names)
print(tree_rules)
ID3:使用信息增益选择特征。
C4.5:使用信息增益率选择特征,以避免信息增益的偏好。
CART:使用基尼指数(分类)或均方误差(回归)选择特征。
每种决策树算法在不同的应用场景中有不同的优势,选择哪种算法取决于具体任务和数据特征。希望本文能帮助你更好地理解决策树算法及其实现。
本文作者:Dong
本文链接:
版权声明:本博客所有文章除特别声明外,均采用 CC BY-NC。本作品采用《知识共享署名-非商业性使用 4.0 国际许可协议》进行许可。您可以在非商业用途下自由转载和修改,但必须注明出处并提供原作者链接。 许可协议。转载请注明出处!