LLM 蒸馏(LLM Distillation)是知识蒸馏(KD)的一种专门形式,它旨在将大规模 LLM 压缩成更小、更快、更高效的模型,同时尽可能保留其大部分性能。这使得轻量级模型能够近似具备巨大 LLM 的能力,从而可以在更广泛的应用和设备上进行部署。让我们来看看 LLM 蒸馏中的几个关键术语:
- 知识转移(Knowledge Transfer):将学习到的知识从大型教师模型转移到较小的学生模型。
- 教师模型(Teacher Model):一个大型的预训练 LLM,在蒸馏过程中指导学生模型。
- 学生模型(Student Model):一个更小、更高效的模型,经过训练以模仿教师模型的输出。
- 软标签(Soft Labels):使用来自教师的概率分布,而不是硬分类标签,以传递更丰富的信息。
- KL 散度(KL Divergence):一种损失函数,用于衡量教师和学生输出分布之间的差异。
- 推理效率(Inference Efficiency):经过蒸馏的模型需要更少的计算量,能够以更低的延迟实现更快的预测。
- 特征匹配(Feature Matching):除了输出逻辑之外,还对齐教师和学生之间的内部表示。
蒸馏技术
为了将知识从教师模型转移到学生模型,我们会使用各种蒸馏技术。这些方法确保了学生模型不仅能高效学习,还能保留教师模型的基本知识和能力。以下是在 LLM 蒸馏中使用的一些最突出的技术。
1. 知识蒸馏(Knowledge Distillation)
知识蒸馏(KD)是 LLM 蒸馏中最广泛使用的技术之一。在这种方法中,学生模型是使用教师模型的输出概率(被称为软目标 Soft targets)以及真实标签(被称为硬目标 Hard targets)来训练的。软目标提供了教师预测的更丰富视角,允许学生捕捉编码在教师模型中的微妙模式和复杂知识。这使学生能够更好地理解教师的决策过程,在保留基本知识的同时提高准确性和可靠性。
- 软目标提供的是可能输出的概率分布,而不是单一的正确答案。
- 帮助学生模型捕捉复杂的模式和细微的知识。
- 从而带来更准确、更可靠的学生模型表现。
- 通过保留关键的教师知识,促进更平滑、更有效的训练。
!<a href="https://media.geeksforgeeks.org/wp-content/uploads/20250211234016726084/What-is-LLM-Distillation.webp">What-is-LLM-Distillation知识蒸馏框架
此外,还有其他几种技术也被用来增强 LLM 蒸馏的效果:
- 数据增强(Data Augmentation):这涉及使用教师模型生成额外的训练数据。通过扩展数据集,学生能接触到更多样化的场景,从而提高泛化能力和鲁棒性。
- 中间层蒸馏(Intermediate Layer Distillation):这种方法不仅仅关注最终的输出,而是从教师模型的中间层转移知识。从这些中间表示中学习,使学生能够捕捉更详细和结构化的信息,从而提升整体性能。
- 多教师蒸馏(Multi-Teacher Distillation):学生模型可以同时从多个教师模型学习。通过聚合来自不同教师的知识,有助于学生通过整合不同的视角,获得更全面的理解和更强的鲁棒性。
- 基于特征的蒸馏(Feature-Based Distillation):学生模仿教师的中间隐藏层表示。这是通过最小化相应内部激活值之间的差异(例如 L2 损失)来实现的。
- 提示词蒸馏(Prompt Distillation):这种技术将长而复杂的提示词压缩成简短、高效的提示词,同时保持其有效性。通过捕捉提示词的核心意图,它减少了计算量并加快了推理速度。
- 基于强化学习(RL)的蒸馏(Reinforcement Learning (RL) Based Distillation):使用教师反馈作为奖励,利用强化学习方法迭代改进学生模型的输出。
- 特定任务蒸馏(Task-Specific Distillation):在蒸馏后,学生模型会在特定的下游任务(例如情感分析、摘要生成)上进行微调,以提高在实际应用中的性能。
LLM 蒸馏是如何工作的?
让我们来看看 LLM 蒸馏的工作原理。
步骤 1:导入库
我们需要为模型导入必要的模块和库:
- torch:主要的 PyTorch库,用于张量运算和自动微分。
- torch.nn as nn:提供神经网络构建块,包括层类型和模块。
- torch.optim as optim::包含用于训练的优化算法,如 Adam。