Information :表示该损失函数与信息论有关,用于最大化两个相关样本之间的互信息(mutual information)。
Noise-Contrastive Estimation (NCE) :最初由 Gutmann 和 Hyvärinen 提出的一种方法,用于通过区分真实数据和噪声样本来估计概率分布。InfoNCE 是 NCE 的一个变体,专门用于对比学习。
对比学习的核心思想是:
以图文对齐为例,给定图像编码器 和文本编码器 ,对于图文对 :
特征提取:
InfoNCE损失函数:
其中:
让我为你实现一个基于PyTorch的多模态对比学习框架:
pythonimport torch
import torch.nn as nn
import torch.nn.functional as F
class ImageEncoder(nn.Module):
def __init__(self, output_dim):
super().__init__()
# 使用预训练的ResNet作为骨干网络
self.backbone = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
# 替换最后的全连接层
self.backbone.fc = nn.Linear(2048, output_dim)
def forward(self, x):
return F.normalize(self.backbone(x), dim=-1)
class TextEncoder(nn.Module):
def __init__(self, vocab_size, embed_dim, output_dim):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embed_dim)
self.transformer = nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model=embed_dim, nhead=8),
num_layers=6
)
self.fc = nn.Linear(embed_dim, output_dim)
def forward(self, x):
x = self.embedding(x)
x = self.transformer(x)
# 取[CLS]token的输出
x = x.mean(dim=1)
x = self.fc(x)
return F.normalize(x, dim=-1)
class MultimodalContrastive(nn.Module):
def __init__(self, image_encoder, text_encoder, temperature=0.07):
super().__init__()
self.image_encoder = image_encoder
self.text_encoder = text_encoder
self.temperature = temperature
def forward(self, images, texts):
# 获取图像和文本的特征表示
image_features = self.image_encoder(images)
text_features = self.text_encoder(texts)
# 计算相似度矩阵
logits = torch.matmul(image_features, text_features.t()) / self.temperature
# 创建标签(对角线为正样本)
labels = torch.arange(len(images)).to(images.device)
# 计算双向的对比损失
loss_i2t = F.cross_entropy(logits, labels)
loss_t2i = F.cross_entropy(logits.t(), labels)
return (loss_i2t + loss_t2i) / 2
def train_step(model, optimizer, images, texts):
optimizer.zero_grad()
loss = model(images, texts)
loss.backward()
optimizer.step()
return loss.item()
# 使用示例
def main():
# 超参数设置
output_dim = 512
vocab_size = 30000
embed_dim = 768
# 初始化模型
image_encoder = ImageEncoder(output_dim)
text_encoder = TextEncoder(vocab_size, embed_dim, output_dim)
model = MultimodalContrastive(image_encoder, text_encoder)
# 优化器设置
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# 训练循环
# for epoch in range(num_epochs):
# for images, texts in dataloader:
# loss = train_step(model, optimizer, images, texts)
if __name__ == "__main__":
main()
多模态对齐任务的数据通常包含配对的跨模态数据,以下是常见的数据形式:
python{
'image': torch.Tensor, # 形状: [batch_size, 3, height, width]
'text': torch.Tensor, # 形状: [batch_size, seq_length]
'label': torch.Tensor # 可选的监督信息
}
应用领域 | 输入模态组合 | 主要用途 | 代表模型/方法 |
---|---|---|---|
图文检索 | 图像 + 文本 | 跨模态检索 | CLIP, ALIGN |
自监督视觉学习 | 图像 + 增强图像 | 表示学习 | SimCLR, MoCo |
视频理解 | 视频 + 文本 | 描述生成、检索 | HERO, VideoBERT |
推荐系统 | 商品图像/文本 + 用户行为 | 多模态推荐 | MMRec, Graph Neural Networks |
医疗图像分析 | 医学图像 + 报告 | 辅助诊断 | MedCLIP, CheXNeXt |
语音处理 | 音频 + 文本 | 跨模态理解 | Wav2Vec 2.0 + CLIP |
零样本学习 | 图像 + 类别名称 | 开集识别 | ZeroCLIP |
机器人导航 | 图像 + 指令文本 | 视觉语言导航 | ALFRED, Room-to-Room |
本文作者:Dong
本文链接:
版权声明:本博客所有文章除特别声明外,均采用 CC BY-NC。本作品采用《知识共享署名-非商业性使用 4.0 国际许可协议》进行许可。您可以在非商业用途下自由转载和修改,但必须注明出处并提供原作者链接。 许可协议。转载请注明出处!