编辑
2025-05-09
深度学习
00

目录

1.为什么叫InfoNCE损失
2. 对比学习原理
2.1 数学公式
3. 代码实现
4. 数据形式
4.1 图文对齐数据
5. 应用领域

1.为什么叫InfoNCE损失

Information :表示该损失函数与信息论有关,用于最大化两个相关样本之间的互信息(mutual information)。

Noise-Contrastive Estimation (NCE) :最初由 Gutmann 和 Hyvärinen 提出的一种方法,用于通过区分真实数据和噪声样本来估计概率分布。InfoNCE 是 NCE 的一个变体,专门用于对比学习。

2. 对比学习原理

对比学习的核心思想是:

  • 最大化正样本对(相关数据)之间的相似度
  • 最小化负样本对(不相关数据)之间的相似度

2.1 数学公式

以图文对齐为例,给定图像编码器 fvf_v 和文本编码器 ftf_t,对于图文对 (v,t)(v,t):

  1. 特征提取: hv=fv(v),ht=ft(t)h_v = f_v(v), h_t = f_t(t)

  2. InfoNCE损失函数: L=logexp(sim(hv,ht)/τ)texp(sim(hv,ht)/τ)\mathcal{L} = -\log \frac{\exp(sim(h_v, h_t)/\tau)}{\sum_{t'} \exp(sim(h_v, h_{t'})/\tau)}

其中:

  • sim()sim() 是余弦相似度函数
  • τ\tau 是温度参数
  • tt' 表示batch中的负样本

3. 代码实现

让我为你实现一个基于PyTorch的多模态对比学习框架:

python
import torch import torch.nn as nn import torch.nn.functional as F class ImageEncoder(nn.Module): def __init__(self, output_dim): super().__init__() # 使用预训练的ResNet作为骨干网络 self.backbone = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True) # 替换最后的全连接层 self.backbone.fc = nn.Linear(2048, output_dim) def forward(self, x): return F.normalize(self.backbone(x), dim=-1) class TextEncoder(nn.Module): def __init__(self, vocab_size, embed_dim, output_dim): super().__init__() self.embedding = nn.Embedding(vocab_size, embed_dim) self.transformer = nn.TransformerEncoder( nn.TransformerEncoderLayer(d_model=embed_dim, nhead=8), num_layers=6 ) self.fc = nn.Linear(embed_dim, output_dim) def forward(self, x): x = self.embedding(x) x = self.transformer(x) # 取[CLS]token的输出 x = x.mean(dim=1) x = self.fc(x) return F.normalize(x, dim=-1) class MultimodalContrastive(nn.Module): def __init__(self, image_encoder, text_encoder, temperature=0.07): super().__init__() self.image_encoder = image_encoder self.text_encoder = text_encoder self.temperature = temperature def forward(self, images, texts): # 获取图像和文本的特征表示 image_features = self.image_encoder(images) text_features = self.text_encoder(texts) # 计算相似度矩阵 logits = torch.matmul(image_features, text_features.t()) / self.temperature # 创建标签(对角线为正样本) labels = torch.arange(len(images)).to(images.device) # 计算双向的对比损失 loss_i2t = F.cross_entropy(logits, labels) loss_t2i = F.cross_entropy(logits.t(), labels) return (loss_i2t + loss_t2i) / 2 def train_step(model, optimizer, images, texts): optimizer.zero_grad() loss = model(images, texts) loss.backward() optimizer.step() return loss.item() # 使用示例 def main(): # 超参数设置 output_dim = 512 vocab_size = 30000 embed_dim = 768 # 初始化模型 image_encoder = ImageEncoder(output_dim) text_encoder = TextEncoder(vocab_size, embed_dim, output_dim) model = MultimodalContrastive(image_encoder, text_encoder) # 优化器设置 optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) # 训练循环 # for epoch in range(num_epochs): # for images, texts in dataloader: # loss = train_step(model, optimizer, images, texts) if __name__ == "__main__": main()

4. 数据形式

多模态对齐任务的数据通常包含配对的跨模态数据,以下是常见的数据形式:

4.1 图文对齐数据

python
{ 'image': torch.Tensor, # 形状: [batch_size, 3, height, width] 'text': torch.Tensor, # 形状: [batch_size, seq_length] 'label': torch.Tensor # 可选的监督信息 }

5. 应用领域

应用领域输入模态组合主要用途代表模型/方法
图文检索图像 + 文本跨模态检索CLIP, ALIGN
自监督视觉学习图像 + 增强图像表示学习SimCLR, MoCo
视频理解视频 + 文本描述生成、检索HERO, VideoBERT
推荐系统商品图像/文本 + 用户行为多模态推荐MMRec, Graph Neural Networks
医疗图像分析医学图像 + 报告辅助诊断MedCLIP, CheXNeXt
语音处理音频 + 文本跨模态理解Wav2Vec 2.0 + CLIP
零样本学习图像 + 类别名称开集识别ZeroCLIP
机器人导航图像 + 指令文本视觉语言导航ALFRED, Room-to-Room
如果对你有用的话,可以打赏哦
打赏
ali pay
wechat pay

本文作者:Dong

本文链接:

版权声明:本博客所有文章除特别声明外,均采用 CC BY-NC。本作品采用《知识共享署名-非商业性使用 4.0 国际许可协议》进行许可。您可以在非商业用途下自由转载和修改,但必须注明出处并提供原作者链接。 许可协议。转载请注明出处!