在这个数据驱动的时代,我们通常认为构建一个高性能的人工智能模型需要成千上万甚至上亿的标记数据。然而,这种“大数据依赖症”在现实世界面前往往碰壁。想象一下,在医疗诊断中,某种罕见病全球只有寥寥几个确诊病例;或者在个性化手机解锁中,我们不可能要求用户录入几百次人脸。这就像让一个只看过一次熊猫的孩子,下次在动物园能一眼认出它一样。这就是我们今天要探讨的核心技术——少样本学习(Few-Shot Learning,简称 FSL)。在本文中,我们将深入探讨 FSL 的工作原理,并通过基于 PyTorch 的实际代码,教你如何利用预训练模型实现一个高效的少样本分类器。
为什么我们需要少样本学习?
传统的监督学习算法就像一个需要死记硬背的学生,只有通过大量题海战术(海量数据)才能在考试(测试)中取得好成绩。一旦遇到没见过的新题型(新类别),它们往往束手无策。相比之下,少样本学习旨在模拟人类的认知过程:利用已有的先验知识,通过极少量的样本迅速掌握新概念。
这种方法不仅让 AI 更加高效,也极大地降低了数据标注的成本。在数据采集极其昂贵或数据极度稀缺的领域,例如罕见疾病检测、野生动物保护或实时个性化推荐系统中,FSL 展现出了无可替代的价值。
核心优势一览
- 降低数据门槛:不再依赖海量数据,仅需少量样本即可完成模型训练或适配。
- 快速迭代:使 AI 能够像人类一样,通过几个示例就迅速识别新物体或新面孔。
- 节约成本:大幅减少了对大规模人工标注数据集的需求,节省了宝贵的人力和时间资源。
少样本学习是如何工作的?
在深入研究代码之前,我们需要先理解 FSL 背后的核心机制。通常,我们将每个学习任务划分为两个关键部分:支持集 和 查询集。这两个集合在帮助模型从少量标记示例中学习并泛化到新的未见样本方面发挥着至关重要的作用。
1. 支持集 (Support Set)
支持集就是我们的“教科书”。它包含了一小部分带有标记的示例,通常表示为:
$$S = \{(x1, y1), (x2, y2), \dots, (xk, yk)\}$$
其中,$xi$ 代表数据点(比如一张图片),$yi$ 是该数据点对应的标签。在 FSL 中,$k$ 的值通常非常小,例如 1-shot(每个类只有1张图)或 5-shot(每个类只有5张图)。
2. 查询集 (Query Set)
查询集就像是“考试题”。它包含未标记的示例,表示为:
$$Q = \{x‘1, x‘2, \dots, x‘_m\}$$
我们的目标是让模型基于从支持集上学到的知识,给这些查询样本打上正确的标签。
3. 基于度量的学习
目前的少样本学习方法大多基于度量学习。其核心思想非常直观:如果我们无法通过少量样本训练一个复杂的分类器,那我们就直接比较“查询图片”和“支持集图片”谁长得更像。
模型将把查询集中的样本与支持集中的样本进行比较,计算它们之间的相似度。常见的度量标准包括:
- 余弦相似度:关注特征向量方向的夹角。
- 欧几里得距离:关注特征向量在空间中的直线距离。
模型会根据这些度量标准,找到与查询示例最相似的支持示例,并将其标签分配给查询样本。
动手实践:构建你的第一个 FSL 模型
让我们通过一个具体的例子来看看这一切是如何运作的。我们将使用经典的 CIFAR-10 数据集,并使用强大的 ResNet50 预训练模型作为特征提取器。我们的目标是从 CIFAR-10 中选取几个类别,每个类别只取极少量图片作为支持集,看看模型能否识别出同类的其他图片。
步骤 1:环境准备与工具导入
首先,我们需要搭建好“武器库”。我们将使用 timm(PyTorch Image Models)库来加载预训练模型,这是一个非常高效的工具。
# 安装必要的库
!pip install -q timm torch torchvision
import random
import torch
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import CIFAR10
import timm
# 为了实验的可复现性,我们设置随机种子
# 这确保了我们每次运行代码得到的结果是一致的
random.seed(42)
torch.manual_seed(42)
# 检查是否有可用的 GPU,这会大大加速我们的计算
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
步骤 2:加载特征提取器
在少样本学习中,我们通常不会从头开始训练模型(因为数据太少)。相反,我们会利用在大规模数据集(如 ImageNet)上预训练好的模型。这些模型已经学会了如何提取图像的底层和高层特征(边缘、纹理、形状等),我们可以直接复用这些能力。
我们将使用 ResNet50,但要去掉它的最后一层分类头,只保留特征提取部分。
# 加载 ResNet50 模型
# num_classes=0 表示我们去掉最后的全连接层,只保留特征提取器
# pretrained=True 表示使用在 ImageNet 上预训练好的权重
model = timm.create_model("resnet50", pretrained=True, num_classes=0)
# 将模型移动到 GPU 并设置为评估模式
# .eval() 会关闭 Dropout 等训练时特有的操作,保证输出稳定
model = model.eval().to(device)
步骤 3:数据预处理流水线
预训练模型通常对输入图像的大小和统计分布有特定要求。ResNet50 需要 $224 \times 224$ 的图像输入,并且需要使用 ImageNet 的均值和标准差进行归一化。
# 定义图像转换管道
# 1. Resize: 将较小的 CIFAR 图片放大到 224x224
# 2. CenterCrop: 确保我们获得正方形裁剪
# 3. ToTensor: 将图片转换为 PyTorch 张量
# 4. Normalize: 标准化数据,使其符合预训练模型的输入分布
transform = transforms.Compose([
transforms.Resize(224),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406], # ImageNet 的均值
std=[0.229, 0.224, 0.225] # ImageNet 的标准差
),
])
步骤 4:构建支持集与查询集
这是实验的关键部分。我们需要从 CIFAR-10 数据集中手动构造一个“少样本”场景。假设我们选择了“猫”、“狗”和“飞机”三个类别,每个类别我们只取 5 张图片作为支持集(这就是 5-shot 学习)。然后,我们再取另外一些图片作为查询集来测试准确率。
# 加载完整的 CIFAR-10 数据集
# train=False 通常意味着测试集,但这里我们只用来取数据,不涉及真实训练过程
full_dataset = CIFAR10(root="./data", train=False, download=True, transform=transform)
# 定义我们要演示的类别索引
# CIFAR-10 标签: 0:飞机, 1:汽车, 2:鸟, 3:猫, 4:鹿, 5:狗, 6:青蛙, 7:马, 8:船, 9:卡车
selected_classes = [3, 5, 0] # 比如我们要做:猫、狗、飞机 的分类
num_shots = 5 # 每个类取 5 张图作为支持集
num_queries_per_class = 10 # 每个类取 10 张图作为查询集
support_images = []
support_labels = []
query_images = []
query_labels = []
# 遍历数据集,筛选出我们需要的类别
for img, label in full_dataset:
if label in selected_classes:
# 为了代码简洁,这里我们做一个简单的逻辑:
# 先收集所有该类别的图片,然后手动切片
# 在实际工程中,通常会有更严谨的数据加载器
pass
# 为了演示代码的清晰性,我们手动构建索引
# 注意:下面的代码逻辑是将所有目标类的图片找出来,然后切片
class_imgs = {c: [] for c in selected_classes}
for img, label in full_dataset:
if label in selected_classes:
class_imgs[label].append(img)
# 构建最终的 Support Set 和 Query Set
new_label_map = {old: i for i, old in enumerate(selected_classes)} # 重新映射标签为 0, 1, 2
for old_label, imgs in class_imgs.items():
# 前5张作为支持集,后面的前10张作为查询集
support_images.extend(imgs[:num_shots])
query_images.extend(imgs[num_shots : num_shots + num_queries_per_class])
support_labels.extend([new_label_map[old_label]] * num_shots)
query_labels.extend([new_label_map[old_label]] * num_queries_per_class)
# 转换为 Tensor 格式以便输入模型
def collate_fn(images):
# 将一个列表的图片堆叠成一个 batch
return torch.stack(images, dim=0)
support_tensor = collate_fn(support_images).to(device)
query_tensor = collate_fn(query_images).to(device)
print(f"Support set shape: {support_tensor.shape}") # 应该是 (15, 3, 224, 224) -> 3类*5张
print(f"Query set shape: {query_tensor.shape}") # 应该是 (30, 3, 224, 224) -> 3类*10张
步骤 5:计算特征与相似度
现在,我们将图片输入模型,提取高维特征向量。然后,计算查询集特征与支持集特征之间的余弦相似度。
# 使用 torch.no_grad() 因为我们在做推理,不需要计算梯度
# 这能节省大量显存
with torch.no_grad():
# 1. 提取特征
# 输出形状: (Batch_Size, Feature_Dim)
# 对于 ResNet50,Feature_Dim 通常是 2048
support_features = model(support_tensor)
query_features = model(query_tensor)
# 2. 归一化特征向量
# 计算余弦相似度前,必须先将向量归一化(单位向量)
# 这样相似度就变成了向量点积
support_features = F.normalize(support_features, p=2, dim=1)
query_features = F.normalize(query_features, p=2, dim=1)
# 3. 计算相似度矩阵
# query_features: (30, 2048)
# support_features.T: (2048, 15)
# similarities: (30, 15) -> 每个查询样本对所有支持样本的相似度
similarities = torch.mm(query_features, support_features.T)
print("相似度矩阵形状:", similarities.shape)
步骤 6:预测与评估
最后一步,我们根据相似度矩阵找出每个查询样本最像哪一张支持集图片。
# 1. 找到每个查询样本在支持集中最相似的样本索引
# best_match_indices 形状: (30,),存储的是在 support_tensor 中的索引
best_match_indices = torch.argmax(similarities, dim=1)
# 2. 根据索引映射回标签
# 我们需要根据索引去 support_labels 中找对应的标签
# 为了方便,先把 support_labels 转成 tensor
support_labels_tensor = torch.tensor(support_labels).to(device)
predicted_labels = support_labels_tensor[best_match_indices]
# 3. 计算准确率
true_labels_tensor = torch.tensor(query_labels).to(device)
accuracy = (predicted_labels == true_labels_tensor).float().mean().item()
print(f"Few-Shot 分类准确率: {accuracy * 100:.2f}%")
深入解析:如何进一步提升性能?
通过上面的代码,我们实现了一个最基础的基准线。但在实际应用中,你可能会遇到准确率波动或不够理想的情况。这里有一些经验丰富的开发者常用的优化策略:
1. 数据增强是关键
在少样本场景下,数据极其珍贵。如果我们每个类只有 1 张或 5 张图片,直接使用可能会过拟合。我们可以对支持集图片进行增强(如随机裁剪、旋转、颜色抖动),这在某种程度上相当于增加了支持集的数量。
2. 优化距离度量
我们在这里使用了简单的余弦相似度。但在复杂的任务中,简单的欧氏距离或余弦相似度可能不足以区分细粒度的差异(例如区分两种非常相似的鸟)。更高级的算法会训练一个元学习器来学习如何度量距离。
3. 微调特征提取器
虽然 ResNet50 在 ImageNet 上很强,但它未必见过 CIFAR-10 的这种特定画风。在某些情况下,我们可以对特征提取器进行微调,让它更好地适应目标域的数据分布。
常见错误与解决方案
在实践过程中,初学者常会掉进一些坑里,这里为你列举几个最典型的:
- 忘记归一化:这是最常见的一个错误。如果你没有使用
F.normalize将特征向量转换为模长为1的向量,直接计算点积将不再是余弦相似度,而是向量内积。这会导致模长较大的特征向量占据主导地位,严重影响分类效果。
解决方案*:如上文代码所示,在计算 mm 之前务必归一化。
- 预处理不匹配:预训练模型对输入非常敏感。如果你只是 Resize 了图片但忘记了 Normalize(或者均值填错了),模型的特征提取能力会大打折扣,导致准确率断崖式下跌。
解决方案*:严格按照 timm 模型对应的配置(通常是 ImageNet 的均值和标准差)进行预处理。
- 内存溢出:如果你尝试一次性将整个支持集和查询集放入内存,可能会遇到 OOM(Out of Memory)错误,特别是图片分辨率很大时。
解决方案*:使用 batch_size 分批处理特征提取过程。
结语
少样本学习是通往更通用人工智能的重要一步。通过本文,我们从基本概念出发,理解了支持集与查询集的协同工作方式,并利用 PyTorch 和预训练的 ResNet50 亲手构建了一个基础的少样本分类系统。虽然我们只使用了简单的余弦相似度,但这也展示了迁移学习在小样本场景下的巨大潜力。
接下来,建议你尝试更换不同的预训练模型(如 EfficientNet),或者尝试实现 Prototypical Networks(原型网络),通过计算类别中心来进行分类,这通常是比单纯计算 1-to-1 相似度更稳定的方法。希望你能在这个过程中,找到属于自己的灵感!