度量学习在病理图像分类中的应用:构建可解释的AI诊断模型

发布时间:2026/5/26 14:25:33

度量学习在病理图像分类中的应用:构建可解释的AI诊断模型 1. 项目概述当病理图像分类遇上度量学习在病理科医生的日常工作中诊断一张组织切片Whole Slide Image, WSI是一项极其耗时且需要高度专业知识的任务。一张高分辨率的WSI可能包含数十亿像素直接使用传统的深度学习模型进行端到端分类不仅计算资源消耗巨大更关键的是模型往往像一个“黑箱”医生无法理解其做出“恶性”或“良性”判断的依据是什么。这直接阻碍了人工智能在临床辅助诊断中的深度应用和信任建立。近年来度量学习Metric Learning为我们打开了一扇新窗。它的核心思想并非直接学习一个“是或否”的分类边界而是学习一个“好”的特征空间。在这个空间里所有“乳腺导管癌”的图像特征向量彼此靠近所有“良性腺瘤”的图像特征向量也彼此靠近但这两类特征向量簇之间则保持足够的距离。孪生网络Siamese Network和三元组网络Triplet Network是实现这一目标的两种经典架构。它们通过对比学习Contrastive Learning的方式驱动网络学习到这种具有高度判别性的嵌入Embedding。本项目的实践正是将这一套思路应用于病理图像分类。我们构建了一个融合了深度特征提取、度量学习和可解释分类器的完整框架。简单来说我们先用一个预训练好的ResNet152网络作为强大的“特征提取器”将病理图像块转换为2048维的高维特征。然后我们并非直接在这些特征上接一个复杂的多层感知机MLP进行分类而是插入一个经过度量学习训练的线性嵌入层将特征压缩到一个更具判别性的512维空间中。最后在这个清晰、可分的新空间里我们使用支持向量机SVM或k近邻k-NN这类原理简单、决策透明的“白盒”分类器来完成最终判定。这套方案的价值是双重的第一是性能我们在多个公开病理数据集上达到了与最先进SOTA复杂模型相媲美甚至更优的分类精度第二是可解释性我们可以将512维的嵌入通过主成分分析PCA降维到2D或3D进行可视化医生能直观地看到待诊断的图像块落在特征空间的哪个位置周围是哪些已知的样本从而理解模型决策的“几何依据”。此外我们还为分类结果输出了一个置信度分数告诉医生这个判断有多大的把握为临床决策提供了宝贵的参考维度。2. 核心思路拆解为何是“特征嵌入简单分类器”2.1 传统深度分类模型的局限在深入我们的方案之前有必要先理解主流方案的痛点。一个典型的深度卷积神经网络CNN用于图像分类其架构通常是“卷积特征提取器 全连接分类器”的端到端模式。这种模式存在几个问题特征与分类耦合过紧网络学习到的特征表示是为最终的softmax分类头“量身定制”的这些特征可能为了优化交叉熵损失而丢失了样本间的相对关系信息导致特征空间的结构不清晰。可解释性差最终的全连接层是一个复杂的非线性变换矩阵其决策过程难以追溯。我们只知道输入图片A被分到了类别1但不知道它“像”训练集中的哪些样本或者它距离决策边界有多远。对数据划分敏感特别是在医学图像领域如果训练集和测试集包含了同一患者的不同图像块即基于图像块的分割而非基于患者的分割模型可能会学到患者特有的生物标记如染色差异、组织切片制备痕迹而非疾病本身的特征导致在未见过的患者数据上泛化能力骤降。2.2 度量学习的嵌入空间从“分类”到“度量”度量学习的核心目标是学习一个映射函数 f(·)将原始数据 x 映射到嵌入空间 z f(x)使得在这个新空间中相似样本的欧氏距离或其他距离度量小不相似样本的距离大。孪生网络它接受一对图像x_i, x_j作为输入共享同一个权重网络。通过一个对比损失函数Contrastive Loss进行训练。损失函数鼓励网络输出使得同类样本对的嵌入距离变小异类样本对的嵌入距离变大直到超过一个预设的边界值margin τ。其数学形式简洁有力对于一对样本损失为(1 - z) * D^2 z * max(0, τ - D)^2其中z为相似性标签0为同类1为异类D为嵌入距离。三元组网络它接受一个三元组锚点样本 x_a, 正样本 x_p, 负样本 x_n作为输入。通过三元组损失Triplet Loss进行训练目标是让锚点到正样本的距离比锚点到负样本的距离至少小一个边界值 m。损失函数为max(0, D_ap - D_an m)。这里的关键在于“三元组挖掘”策略我们采用了半困难负样本挖掘选择那些负样本比正样本离锚点远但距离差仍在边界值m以内的三元组。这种策略能提供稳定且信息量大的梯度避免使用过于简单已满足约束或过于困难导致训练不稳定的样本。实操心得边界值Margin的选择边界值τ或m是一个关键超参数。设置太小模型可能无法充分拉开不同类别的样本设置太大可能导致训练困难甚至不收敛。在我们的初步实验中见图5我们发现对于病理图像τ1.0是一个稳健的起点。这个值在BreaKHis和Kather数据集上都能取得稳定且良好的性能避免了因边界值过小导致的欠拟合或过大导致的过拟合风险。2.3 简单分类器的回归信任源于透明在得到了一个“聚类友好”的嵌入空间后复杂的非线性分类器如深度MLP就不再是必需品。相反我们可以回归到SVM或k-NN这类模型支持向量机在嵌入空间中寻找一个最优的线性超平面来分割不同类别。其决策函数f(x) sign(w·x b)具有清晰的几何解释。样本点到超平面的距离可以直接转化为置信度分数。例如对于一个二分类问题我们可以使用CS(x) 1 / (1 exp(-d(x, π)))其中d是样本到超平面的符号距离。距离越远置信度越接近1或0距离越近靠近决策边界置信度越接近0.5表示模型“不确定”。k近邻直接根据待测样本在嵌入空间中的k个最近邻的类别进行投票。其置信度可以简单地定义为CS(x) (同类近邻数) / k。例如如果k5有4个邻居是“恶性”那么模型预测“恶性”的置信度就是0.8。这两种方法的共同点是决策过程透明。对于SVM医生可以理解“因为该样本落在超平面的‘恶性’一侧”对于k-NN医生甚至可以查看那k个最相似的训练样本进行类比诊断这本身就是一种强大的可解释性工具。3. 系统架构与实现细节3.1 整体流程与模块设计我们的系统是一个清晰的三阶段管道如下图所示意[输入病理图像块] - [特征提取器 (ResNet152)] - [度量学习嵌入层 (512维)] - [可解释分类器 (SVM/k-NN)] - [输出类别标签 置信度] (冻结权重预训练) (唯一训练部分对比损失) (在嵌入空间上训练)第一阶段深度特征提取骨干网络采用在ImageNet上预训练的ResNet152。选择ResNet系列是因为其残差结构能有效缓解深层网络的梯度消失问题且152层的深度足以从病理图像中提取丰富的中高层特征如纹理、结构模式。为什么冻结权重我们冻结了ResNet152的所有层不进行微调。这基于一个关键假设在ImageNet上学习到的通用视觉特征边缘、纹理、形状对于病理图像分析仍然是高度有效的。这样做带来了两大好处1) 极大减少训练参数量和计算成本2) 避免在数据量有限的医学图像上对复杂主干网络进行微调可能导致的过拟合。实验也证明仅训练最后的嵌入层就足以取得优异性能。第二阶段度量学习嵌入嵌入层一个简单的全连接线性层输入为ResNet152输出的2048维特征输出为512维的嵌入向量。随后对嵌入向量进行L2归一化将其约束在一个超球面上这有利于基于距离的度量学习。训练配置优化器RMSProp学习率设为1e-5权重衰减为1e-4。较小的学习率适合微调最后一层。批次大小32。训练轮数20个epoch。实验表明度量学习收敛较快。损失函数对于孪生网络使用对比损失对于三元组网络使用带半困难挖掘的三元组损失。关键技巧在构建训练对孪生或三元组三元组时必须在批次内进行在线挖掘动态地选择最有信息量的样本对这比使用固定的预计算对效果更好。第三阶段可解释分类与可视化分类器训练在训练集生成的嵌入上独立训练一个线性SVM或k-NNk5分类器。这个过程极快。可视化使用主成分分析将512维的嵌入降维至2维绘制散点图。这张图是向医生解释模型工作的“仪表盘”可以清晰展示不同类别的聚类情况以及新样本在空间中的落点。3.2 数据集处理与实验设置我们在四个具有代表性的公开病理数据集上验证了框架的有效性涵盖了二分类、多分类和分级任务数据集任务描述类别图像数量关键特点BreaKHis乳腺肿瘤良恶性判别2 (良性/恶性)7909提供患者ID支持基于患者的数据划分能评估真实泛化能力。Kather结直肠癌组织类型识别8 (肿瘤、间质等)5000八类组织平衡但无患者信息评估多分类性能。PathoIDCG乳腺癌组织学分级3 (Grade 1/2/3)3644分级任务分辨率高。Agios Pavlos乳腺癌组织学分级3 (Grade 1/2/3)300数据量小但提供患者信息可进行严格的基于患者的分割。注意事项基于患者 vs 基于图像的分割这是医学图像分析中一个至关重要但常被忽视的细节。基于图像的分割是随机打乱所有图像块划分训练/测试集这可能导致同一患者的图像同时出现在训练集和测试集。模型可能学会识别“这个患者的染色特点”而非“癌症的形态学特点”导致虚高的准确率。基于患者的分割则确保训练集和测试集的患者完全互斥这更符合临床实际诊断新患者但通常会导致模型性能下降。我们的实验明确对比了这两种设置结果见表1显示在BreaKHis上基于患者分割的准确率比基于图像分割低约11个百分点。这警示我们在评估医学AI模型时必须关注其数据划分策略。4. 实验结果深度分析4.1 性能对比不逊于复杂模型我们在各数据集上进行了五折或十折交叉验证并与现有文献中的先进方法进行了对比。下表汇总了我们的核心结果使用三元组网络线性SVM数据集任务我们的方法 (PLA/Accuracy)对比的SOTA方法 (PLA/Accuracy)关键结论BreaKHis(患者分割)良恶性二分类87.6% (PLA)Bayramoglu et al.: 82.0%Sun et al.: ~85%Gupta et al.: 87.5%我们的方法达到了顶尖水平且模型结构更简单透明。Kather组织八分类95.1% (Accuracy)Kather et al. (手工特征): 87.4%Ohata et al. (DenseNetSVM): 92.8%Zeid et al. (Vision Transformer): 94.7%显著优于传统方法与基于Transformer的复杂模型性能相当甚至略优。PathoIDCG肿瘤三级分级96.7% (Accuracy)Yan et al. (NGNet): 93.4%Yan et al. (DANet): 91.6%Senousy et al. (3E-Net): 99.5%*性能优异。*注99.5%的方法使用了复杂集成和不确定性丢弃策略。Agios Pavlos(图像分割)肿瘤三级分级93.3% (Accuracy)Nanni et al. (ResNet50集成): 94.3%Dimitropoulos et al. (手工特征SVM): 95.8%与复杂集成模型性能接近。Agios Pavlos(患者分割)肿瘤三级分级66.7% (PLA)文献中通常未报告此严格设置下的结果。凸显了在小样本、患者独立设置下的真实挑战也是未来改进方向。消融实验的强力证明为了验证嵌入层的核心作用我们进行了关键的消融实验直接使用ResNet152提取的2048维特征接上相同的线性SVM进行分类。结果发现在Kather数据集上准确率从95.1%暴跌至80%以下在PathoIDCG和Agios Pavlos数据集上也有超过30个百分点的巨大差距。这 unequivocally 证明度量学习嵌入层是性能提升的关键它重塑了特征空间使其对简单分类器变得“友好”。4.2 可视化分析看见模型的“思考”过程可视化是解释性的核心。下图展示了在BreaKHis数据集上使用三元组网络得到的嵌入空间2D投影基于患者分割。此处为文字描述实际博文应嵌入生成的散点图 左侧训练集良性蓝色和恶性红色样本形成了两个相对紧凑、分离的簇。 右侧测试集新患者的样本被映射到该空间。大部分样本仍能落入对应的簇内但可以看到有部分红色恶性样本点落在了蓝色良性簇的边缘甚至内部分类错误点。同时两个簇的边界区域存在一些交错这直观地解释了为什么基于患者分割的准确率会下降——新患者的特征分布与训练集存在差异。通过这样的可视化病理学家可以评估模型可靠性如果某个类别的训练样本在嵌入空间中本身就分散、与其他类别重叠那么模型对该类别的判断天生就不可靠。理解误判当一个新样本被误分类时医生可以查看它在空间中的位置。如果它落在两个簇的边界模糊区域那么这个误判是“情有可原”的如果它深陷对方簇的内部则可能提示模型存在更严重的问题或者该样本本身是罕见/困难的病例。发现数据问题可视化可能揭示数据标注的不一致例如某个明显离群的样本可能需要重新复核其标签。4.3 置信度分数的价值置信度分数为模型的输出增加了宝贵的“不确定性量化”维度。我们分析了正确分类样本和错误分类样本的置信度分布见图16, 17。理想情况正确分类的样本真阳性TP、真阴性TN应具有高置信度接近1而错误分类的样本假阳性FP、假阴性FN应具有低置信度接近0.5。我们的发现在使用度量学习嵌入后SVM分类器的置信度分布基本符合这一规律。TP和TN样本的置信度密度峰值集中在0.8-1.0区间而FP和FN样本的置信度则广泛分布在0.5-1.0之间且密度较低。临床意义医生可以设定一个置信度阈值例如0.8。对于置信度高于此阈值的预测可以高度信任对于置信度低于此阈值的预测系统可以将其标记为“需要医生重点复核”。如图4所示随着置信度阈值的提高被系统“接受”的预测子集的准确率、灵敏度、特异度都会同步提升但代价是会有更多样本被“拒绝”需要人工处理。这为实现人机协同诊断提供了可操作的路径。实操心得置信度阈值的设定阈值的选择是一个权衡。高阈值带来高可靠预测但低召回率很多病例需要人工看。低阈值则相反。在实际部署中建议与临床医生共同确定这个阈值可以根据临床风险如漏诊恶性病变的风险成本来调整。例如在初筛场景可以设定较低阈值以提高灵敏度在辅助确诊场景则可设定较高阈值以保证特异度。5. 关键问题与调优经验5.1 孪生网络 vs. 三元组网络如何选择在我们的实验中两者性能相近但各有特点特性孪生网络 (Siamese)三元组网络 (Triplet)训练样本构造样本对 (正对/负对)样本三元组 (锚点, 正例, 负例)训练效率相对更高。只需构造正负对数据准备相对简单。相对较低。需要在线挖掘有效的三元组计算开销稍大。嵌入空间特性嵌入簇通常更紧凑类内方差小在PCA可视化中占据区域更集中。嵌入簇可能相对分散一些但类间边界有时更清晰。与分类器搭配与线性SVM搭配时略占优势平均高0.5-1%。与k-NN搭配时效果同样好。建议场景追求训练速度和稳定性且后续使用SVM分类时。当特别关注类间边界的最大化时。个人经验对于大多数病理图像分类任务从孪生网络开始是一个更稳妥高效的选择。它的训练更稳定超参数主要是边界值τ更容易调节。如果初步结果中类间分离度不够再尝试三元组网络。5.2 嵌入维度与分类器选择嵌入维度我们固定使用512维。这是一个经验性的平衡点维度太低如64可能信息损失严重无法充分分离复杂类别维度太高如2048则失去了降维和可视化的意义且可能引入噪声。在实践中可以在256到1024之间进行微调。分类器选择线性SVM是我们的默认推荐。它在嵌入空间线性可分性好的情况下非常强大且能直接给出基于距离的置信度。训练和预测速度极快。k-NN优势在于其极致的可解释性直接展示近邻。但当数据集很大时预测阶段的计算开销较高需计算与所有训练样本的距离。适合中小规模数据集或对解释性要求极高的场景。其他分类器我们也尝试了高斯核SVM、随机森林和MLP。如图8、9所示在嵌入空间已经良好聚类的前提下这些更复杂的模型并未带来显著提升反而增加了模型复杂度和训练时间。“简单的模型在好的特征上表现卓越”在这里得到了完美体现。5.3 处理类别不平衡与少样本问题病理数据集常面临类别不平衡问题如良性样本远多于恶性。我们的框架对此有天然鲁棒性在度量学习阶段在构建训练对或三元组时可以采用加权采样Weighted Sampling策略确保每个批次中少数类样本被更多地采样到从而让网络更好地学习其表征。在分类器训练阶段对于SVM可以设置class_weightbalanced参数自动调整惩罚项。对于k-NN可以结合距离加权投票让更近的邻居拥有更大的投票权重。对于Agios Pavlos这类小样本数据集仅300张图像21个患者基于患者的分割导致性能显著下降PLA 66.7%。这提示我们数据增强至关重要在训练度量学习网络时必须对图像进行强力的、病理学意义合理的数据增强如旋转、翻转、颜色抖动、弹性形变以模拟组织切片可能出现的变异。考虑跨数据集预训练可以先在BreaKHis等大型数据集上训练嵌入网络然后在Agios Pavlos上进行微调Fine-tuning利用迁移学习缓解数据不足问题。采用更保守的评估小数据集上应使用留一患者出Leave-One-Patient-Out交叉验证并更关注置信度分数对低置信度预测保持警惕。6. 部署考量与未来展望将这套系统部署到实际病理科工作流中还需要考虑以下几点计算效率推理过程非常高效。特征提取ResNet152前向传播和嵌入映射是固定计算量。分类SVM点积或k-NN距离计算几乎可以忽略不计。瓶颈可能在于将整张WSI切割成数百上千个图像块并进行批量处理。可以使用GPU并行处理来加速。交互式可视化工具开发一个界面允许病理医生上传或选择WSI中的某个区域系统实时显示该区域图像块在2D嵌入空间中的位置、其k个最近邻的训练样本可点击查看大图以及分类置信度。这是建立医生对AI信任的关键。持续学习与更新当系统遇到分类置信度低或医生纠正的误判案例时这些带有新标签的样本可以加入训练集对嵌入层和分类器进行增量更新使系统不断进化。未来可以探索的方向图神经网络整合当前方法处理的是独立的图像块。未来可以将一个WSI中的所有图像块构建成一个图节点是图像块嵌入边表示空间邻接或特征相似性然后使用图神经网络GNN来聚合整张切片的全局信息实现WSI级别的诊断。多模态融合结合病理图像与患者的基因组学、临床数据等多模态信息在嵌入空间中进行融合有望实现更精准的预后预测。无监督/自监督度量学习探索利用大量未标注的病理图像通过自监督学习如SimCLR, BYOL预训练一个通用的病理图像嵌入模型再在下游特定任务上用少量标注数据微调解决标注数据稀缺的根本难题。这套基于度量学习和可解释分类器的框架其价值在于它在追求高性能的同时没有牺牲可解释性。它不是一个无法理解的“黑箱”而是一个可以与医生对话、提供决策依据的“透明工具箱”。在AI日益深入医疗核心的今天这种透明和可信赖或许比单纯的几个百分点精度提升更为重要。

相关新闻