稀疏分类交叉熵是机器学习和深度学习中多类分类问题常用的损失函数,特别适用于处理大量类别的情况。它与分类交叉熵非常相似,但有一个关键区别:真实类别标签以整数(类别索引)的形式提供,而不是独热编码向量。
它专门用于目标标签以整数类别索引(例如 0, 1, 2, …)而非独热编码向量提供的情况。“稀疏”一词指的是这种紧凑的标签表示形式,它避免了将标签转换为冗长的独热编码数组所带来的内存和计算开销。
稀疏分类交叉熵的工作原理
- 对于每个输入样本,模型预测所有类别的概率分布(通常在最后一层使用 softmax 激活函数)。
- 每个样本的真实标签以指定正确类别索引的整数形式给出。
- 稀疏分类交叉熵计算分配给真实类别的预测概率的负对数似然。它仅考虑与实际类别(由整数标签标记)相关的预测概率,而忽略所有其他类别。
- 总体损失通常是批次中所有样本的平均值。
该函数可以定义为:
> L(y, \hat{y}) = -\sum{i=1}^C yi \log(\hat{y}_i)
其中:
- y 是独热编码的真实标签(一个向量),
- \hat y 是所有 C 个类别的预测概率分布。
稀疏分类交叉熵通过直接使用真实类别的整数索引对此进行了修改。每个样本的损失为:
> L(y, \hat{y}) = -\log\left(\hat{y}_y\right)
其中:
- y 是正确类别的整数索引,
- \hat y_y 是模型输出的真实类别的预测概率。
稀疏分类交叉熵的实现
我们将一步步介绍如何在 Python 中实现它:
第一步:导入库
在这里,我们将加载 scikit learn 和 tensorflow 来进行实现。
Python
CODEBLOCK_2c13a683
第二步:加载和预处理数据
**load_iris()**: 加载带有 4 个特征(如花瓣长度、宽度等)的花卉数据。**StandardScaler**:标准化特征值,使其均值为 0,标准差为 1。**train_test_split**:80% 的数据用于训练,20% 用于测试。
Python
CODEBLOCK_43fe6dba
第三步:构建神经网络模型
构建一个神经网络模型,其中包含:
- 第一层:16 个神经元,ReLU 激活函数,输入形状为 4(特征数量)。
- 输出层:3 个神经元(每个花卉类别一个)。
Python
CODEBLOCK_758239ef
第四步:编译和训练模型
- Adam 优化器用于神经网络。
- 这里 keras.losses.SparseCategoricalCrossentropy 调用了稀疏分类交叉熵。
- 在训练数据上训练 20 个轮次。
- 同时跟踪验证(测试)数据上的准确率。
Python
CODEBLOCK_129b1d1d
!trainingTraining
第五步:进行预测
**predict()**: 从最后一层输出原始分数。**softmax**:将 logit 转换为类别概率。**argmax**:选择概率最高的类别。
Python
CODEBLOCK_ad6e3946
输出:
![predictions](https://media.geeksforgeeks.org/wp-content/uploads/2025