使用PyTorch实现SimCLR对比学习

SimCLR (Simple Framework for Contrastive Learning of Visual Representations,即视觉表示对比学习的简单框架) 是一种自监督学习方法,它能够在没有标签数据的情况下学习强大的图像表示。它通过在潜在空间中利用对比损失来最大化同一图像的不同增强视图之间的一致性,从而实现这一目标。通过最大化同一图像的不同增强视图之间的相似度,同时最小化与其他图像的相似度,SimCLR 使模型能够学习到强大的视觉表示。在 PyTorch 中实现 SimCLR 不仅允许我们灵活地进行实验,还能让我们仅使用未标记的数据就在图像任务上获得出色的性能。

SimCLR 的核心思想

  • 数据增强: 我们对每个输入图像进行两次随机增强,以创建两个相关的视图(正样本对)。常见的增强方法包括随机裁剪、翻转、颜色抖动和高斯模糊。
  • 编码器网络: 我们使用深度神经网络(通常是 ResNet-18/50)将每个增强后的图像编码为特征向量。我们将最终的分类层移除,并替换为一个投影头。
  • 投影头: 一个 MLP(多层感知机)将编码器的输出映射到一个低维嵌入空间,在这里应用对比损失。

PyTorch 中的 SimCLR:主要组成部分

1. 数据增强: 让我们定义一组强数据增强策略,为每张图像生成两个不同的视图。

2. 编码器与投影头

  • 使用一个骨干网络(例如 ResNet-18/50),去掉最后的分类层。
  • 添加一个投影头(通常是一个 2 层的 MLP)将特征映射到嵌入空间。
  • 对比损失 (NT-Xent):归一化的温度缩放交叉熵损失,它会鼓励正样本对的嵌入向量相似,而不同图像(负样本)的嵌入向量不相似。

3. 对比损失的实现: 我们需要一个自定义损失函数 (NT-Xent) 来计算批次中每个正样本对的对比损失。
4. 训练循环

  • 对于每个批次,我们为每张图像生成两个增强视图。
  • 将这两个视图都通过编码器和投影头。
  • 计算对比损失并更新模型。

PyTorch 实现

1. 安装依赖库

首先,让我们安装运行模型所需的 PyTorch 相关包。

Python


CODEBLOCK_da230aea

输出结果

!安装过程安装依赖库

环境设置:导入库

让我们导入标准库,包括模型、数据集、转换方法以及辅助工具(例如用于显示进度条的 tqdm)。

Python


CODEBLOCK_06f801ba

2. 数据增强

数据增强旨在为同一图像创建两个不同的视图,以便模型能学习到具有不变性的特征。

Python


CODEBLOCK_d3939833

3. 双视图数据集

对于每一张图像,我们返回两个增强版本:xi 和 xj,它们将作为对比学习中的正样本对。

Python


CODEBLOCK_1ba768a8

4. SimCLR 模型 = 编码器 + 投影头

  • encoder (编码器):一个不含最终分类头的 ResNet18。
  • projection_head (投影头):将特征映射到一个更小的对比空间(128 维)。
  • 输出 z 将用于计算对比损失。

Python


class SimCLRModel(nn.Module):

def init(self, projection_dim=128):

super().init()

base_model = models.resnet18(weights=None)

numftrs = basemodel.fc.in_features

base_model.fc = nn.Identity()

self.encoder = base_model

self.projection_head = nn.Sequential

声明:本站所有文章,如无特殊说明或标注,均为本站原创发布。任何个人或组织,在未征得本站同意时,禁止复制、盗用、采集、发布本站内容到任何网站、书籍等各类媒体平台。如若本站内容侵犯了原著者的合法权益,可联系我们进行处理。如需转载,请注明文章出处豆丁博客和来源网址。https://shluqu.cn/23840.html
点赞
0.00 平均评分 (0% 分数) - 0