
从图像分类到自监督学习我是如何用老熟人交叉熵损失理解InfoNCE的第一次在SimCLR的代码里看到cross_entropy函数时我盯着屏幕愣了三分钟——这明明是个对比学习任务怎么用上了分类任务的损失函数作为常年和ResNet、CrossEntropyLoss打交道的图像分类工程师这种认知冲突让我决定彻底弄明白InfoNCE和交叉熵的隐秘关系。本文将带你用PyTorch代码搭建一座桥梁揭示这两个看似不相关的损失函数如何通过矩阵变换实现神奇的统一。1. 重温交叉熵分类工程师的肌肉记忆在ImageNet上训练分类模型时我们早已对这段代码形成了肌肉记忆# 经典分类任务流程 logits model(images) # [batch_size, num_classes] loss F.cross_entropy(logits, labels) # labels是0-n的类别索引交叉熵的本质是衡量模型预测概率分布与真实分布的差异。以三分类任务为例当label2时理想的logits和计算过程如下logits [1.2, 0.3, 3.0] softmax [0.10, 0.07, 0.83] # 第三个类别概率最高 loss -log(0.83) ≈ 0.19关键点在于logits矩阵batch中每个样本对各类别的预测分数labels向量每个样本对应的真实类别索引计算过程对每个样本取其对应类别概率的负对数这个模式如此深入人心以至于当我们看到对比学习的相似度矩阵被直接喂给cross_entropy时会产生强烈的认知失调。2. 对比学习的障眼法相似度矩阵变形记让我们构造一个最简化的对比学习场景。假设batch_size3经过编码器后得到特征向量q torch.randn(3, 128) # 查询特征 k torch.randn(3, 128) # 关键特征含正样本在SimCLR中第i个样本的正样本就是k[i]其余都是负样本。计算相似度矩阵sim q k.T # 3x3矩阵 假设得到 tensor([[1.0, 0.3, 0.5], [0.2, 0.9, 0.1], [0.4, 0.3, 0.8]]) 现在施展魔法时刻——把这个矩阵看作分类任务的logits对角线元素q[i]与k[i]的相似度正样本对非对角线元素q[i]与k j 的相似度负样本对隐含的labelstorch.arange(3)即让每个q[i]选择第i个位置作为正确类别labels torch.arange(3) # [0, 1, 2] loss F.cross_entropy(sim / temperature, labels)这相当于让模型学习最大化对角线元素正样本相似度最小化非对角线元素负样本相似度3. 温度系数的调音台作用注意到上述代码中的temperature参数了吗这个在原始交叉熵中不存在的超参数正是对比学习的精髓所在temperature 0.07 scaled_sim sim / temperature 当temperature0.07时 原始相似度0.9 → 缩放后12.86 原始相似度0.1 → 缩放后1.43 差异被显著放大 温度系数控制着低温度如0.05加剧样本间差异更容易区分难负样本高温度如0.2软化分布防止模型过早收敛到次优解实验表明temperature0.07时在CIFAR-10上能达到约75%的线性评估准确率而错误配置如0.01或0.5会导致准确率下降5-10个百分点。4. 从理论到实践完整对比学习训练片段结合上述理解我们来看一个完整的训练循环实现class ContrastiveLearner(nn.Module): def __init__(self, backbone): super().__init__() self.encoder backbone self.projection nn.Sequential( nn.Linear(2048, 512), nn.ReLU(), nn.Linear(512, 128) ) def forward(self, x1, x2): # 两个增强视图的特征 h1 self.projection(self.encoder(x1)) h2 self.projection(self.encoder(x2)) # 归一化处理 h1 F.normalize(h1, dim1) h2 F.normalize(h2, dim1) # 计算相似度矩阵 logits h1 h2.T / 0.07 labels torch.arange(len(x1)).to(logits.device) # 对称损失计算 loss (F.cross_entropy(logits, labels) F.cross_entropy(logits.T, labels)) / 2 return loss关键实现细节特征归一化确保相似度在[-1,1]范围内对称损失同时计算h1→h2和h2→h1的对比损失投影头将骨干网络特征映射到更适合对比学习的空间5. 扩展思考对比学习与交叉熵的哲学差异虽然代码实现上使用了相同的函数但两种损失的思想内核存在本质区别维度交叉熵损失InfoNCE损失监督信号来源人工标注的硬标签数据自身的相似性关系概率解释类别预测分布样本间关联程度的度量负样本处理隐式通过softmax分母实现显式构造大量负样本对优化目标预测与标签一致特征空间的几何结构学习这种差异也体现在模型行为上。训练ResNet-50分类模型时最后一层权重向量的余弦相似度通常呈现均匀分布而对比学习模型的特征空间则会自动形成清晰的聚类结构——这正是自监督学习的魅力所在。