MLP-UNet:基于纯MLP架构的肾小球语义分割模型实践

发布时间:2026/5/26 18:07:17

MLP-UNet:基于纯MLP架构的肾小球语义分割模型实践 1. 项目概述为什么肾小球分割是数字病理学的“硬骨头”在肾脏疾病的诊断中肾活检组织的病理学分析是金标准。病理医生需要在高倍显微镜下从一张张染色的组织切片中手动寻找、计数并评估数以万计的肾小球——这些直径仅约200微米的微小结构是肾脏的“过滤单元”其形态、数量和损伤程度是判断肾小球肾炎、糖尿病肾病等疾病的关键。这个过程不仅极度耗时一个病例可能需要数小时而且高度依赖医生的经验存在显著的主观差异和观察者间不一致性。数字病理学的出现带来了转机。通过全玻片扫描仪可以将整张病理切片数字化为一张高达数十亿像素的Whole Slide Image。然而将医生从繁重的“数小球”工作中解放出来实现自动化分析首要且最核心的挑战就是精准的肾小球语义分割。这不仅仅是画个圈那么简单。肾小球在PAS过碘酸雪夫染色下虽然轮廓相对清晰但依然面临诸多挑战与周围肾小管、间质组织的对比度差异、因切片角度造成的形态变异、部分硬化或病变肾小球边界的模糊、以及图像中存在的染色伪影、褶皱、杂质等干扰。传统的图像处理方法如基于边缘检测或手工特征如HOG的方法在这些复杂场景下鲁棒性很差。深度学习尤其是基于卷积神经网络CNN的U-Net及其变体在过去几年已成为医学图像分割的事实标准。它们通过端到端的学习能够捕捉到更深层次、更抽象的特征。但CNN也有其局限卷积核的局部感受野限制了其捕获长距离依赖关系的能力而这对于理解一个结构在整张图像中的上下文位置有时很重要。随后视觉TransformerViT及其混合架构如TransUNet通过自注意力机制解决了这个问题取得了更优的性能但其二次方的计算复杂度和对大规模预训练的依赖又给实际部署带来了新的负担。正是在这样的背景下我们开始思考有没有一种架构既能保持甚至超越现有模型的性能又更轻量、更简单、无需预训练于是我们将目光投向了深度学习中最古老、最基础的组件——多层感知机。近期MLP-Mixer等纯MLP架构在图像分类任务上展现出的竞争力令人惊讶。它摒弃了卷积和注意力仅通过全连接层在空间和通道维度上进行信息混合就达到了与CNN、Transformer媲美的效果。这启发我们能否将这种简洁而强大的思想迁移到医学图像分割任务中MLP-UNet便是我们对这个问题的回答。我们的核心目标是设计一个专为肾小球分割优化的模型它在精度上不输于先进的TransUNet但在模型复杂度和训练成本上更具优势让高性能的AI辅助诊断工具能更便捷地集成到病理科的实际工作流中。2. 核心架构解析MLP-UNet是如何工作的MLP-UNet的整体设计哲学是“站在巨人的肩膀上创新”。我们借鉴了TransUNet成功的编码器-解码器框架但对其核心的编码器模块进行了彻底的“换芯”手术。整个模型管道可以清晰地分为三个部分CNN嵌入模块、MLP编码器模块和级联上采样解码器模块。下面我们来拆解每一个部分的设计动机与实现细节。2.1 整体框架与设计动机我们的通用模型框架如图2所示它是一个清晰的“嵌入-编码-解码”三级流水线。输入是一张256x256的RGB肾组织图像块及其对应的二值分割掩膜前景为肾小球背景为非肾小球。为什么是三级结构而不是直接端到端局部特征提取嵌入模块最底层的、像素级的纹理、边缘、颜色特征对于分割边界至关重要。CNN在提取这种局部特征方面有着与生俱来的优势且计算高效。因此我们保留了一个轻量级的CNN前端基于ResNet V2的预激活残差块作为嵌入模块将原始图像下采样并转化为一系列高维特征图。这相当于为模型提供了一个丰富的“视觉词典”。全局上下文建模编码器模块这是MLP-UNet的创新核心。经过嵌入模块得到的特征图被展平并送入编码器模块。这里我们完全摒弃了卷积和自注意力而是使用纯MLP层MLP-Mixer或ResMLP来混合这些特征。MLP层通过全连接操作能够让任何一个空间位置或称“令牌”的特征与所有其他位置的特征进行交互从而高效地建立全局上下文关系。这对于判断一个像素是否属于肾小球至关重要因为模型需要结合整个图像块的上下文信息而不仅仅是周围几像素。高分辨率重建解码器模块编码器输出的特征经过了高度抽象但空间分辨率较低。为了得到像素级的分割图我们需要将其上采样回原始分辨率。这里我们采用了TransUNet中的级联上采样器。它通过跳跃连接将编码器每一层的输出与嵌入模块中对应分辨率的CNN特征图进行拼接然后逐步进行3x3卷积、ReLU激活和双线性上采样。这种方式能有效融合低层的高分辨率细节和高层的语义信息避免在上采样过程中丢失局部信息从而生成边界清晰的分割结果。这个设计的精妙之处在于它让每个模块各司其职CNN负责“看细节”MLP负责“想全局”上采样器负责“画出来”。相比于纯CNN的U-Net我们引入了强大的全局建模能力相比于Transformer-based的模型我们用计算更高效的MLP替代了自注意力。2.2 嵌入模块从像素到特征的桥梁嵌入模块基于经典的ResNet V2架构我们使用了其前几个阶段通常到stage 3或4具体取决于下采样倍数。我们选择ResNet V2而非V1是因为其“预激活”结构BN-ReLU-Conv的顺序被证明具有更优的训练动态和性能。注意在医学图像任务中直接使用在ImageNet上预训练的CNN权重作为嵌入模块的初始化通常能带来显著的性能提升和更快的收敛速度这是一种非常有效的迁移学习策略。即使后续的MLP编码器是从头训练一个良好的特征提取器也能为整个模型奠定坚实的基础。该模块的工作流程如下输入图像256x256x3经过一系列残差下采样块空间尺寸逐步减小例如到32x32通道数逐步增加例如到512。在每一个下采样阶段结束后我们不仅将特征图传递给下一层还将其保存下来作为后续解码器跳跃连接的来源。最终从嵌入模块输出的特征图会被展平。假设特征图尺寸为(H, W, C)我们将其重塑为(N, C)的形状其中N H * W代表“令牌”或“补丁”的数量C是特征维度。这个(N, C)的矩阵就是送入MLP编码器的“句子”。2.3 编码器模块MLP的两种“混合”艺术这是MLP-UNet的灵魂。我们探索了两种基于MLP的编码器MLP-Mixer和ResMLP。它们的共同点是都完全由全连接层和激活函数构成没有卷积也没有注意力。2.3.1 MLP-Mixer编码器空间与通道的交替舞蹈MLP-Mixer的核心思想是分离空间混合与通道混合。它认为卷积和注意力机制隐式地同时处理了空间和通道信息而我们可以将其显式地分解为两个独立的步骤。令牌混合MLP作用于空间维度N个令牌。它让不同空间位置的特征进行交流。例如图像左上角的一个特征可以与右下角的特征直接交互从而让模型知道“这里有一个肾小球那么它周围的组织应该是什么样子”。这替代了卷积的局部感受野和Transformer的全局注意力。通道混合MLP作用于通道维度C个特征通道。它负责融合不同通道的信息学习如何组合从低级边缘到高级语义的各种特征。一个Mixer块由两个这样的MLP子层组成每个子层前后都有层归一化LayerNorm并配有残差连接。多个这样的块堆叠起来就构成了MLP-Mixer编码器。它的计算复杂度与令牌数量N呈线性关系远低于Transformer的O(N^2)。2.3.2 ResMLP编码器更进一步的简化与稳定ResMLP可以看作是MLP-Mixer的一个进化版它做了两个关键简化用仿射变换替代层归一化LayerNorm在训练时需要计算批统计量均值和方差这在批量较小时不稳定。ResMLP使用可学习的仿射变换Aff(x) Diag(α)x β来对每个特征进行缩放和偏移。α和β是可训练参数。这样做的好处是仿射变换不依赖于批数据在推理时没有任何计算开销训练也更稳定。更简洁的块结构一个ResMLP块同样包含跨补丁空间和跨通道两个子层但结构更加紧凑。在我们的实验中ResMLP-UNet以更少的参数量约6200万取得了与Mixer-UNet约7900万参数和TransUNet约7600万参数相近的性能这证明了其设计的高效性。2.4 解码器模块从抽象回到具体解码器接收来自编码器输出的、已经过全局上下文建模的特征。由于编码器保持了输入令牌的数量N我们首先将其重塑回空间格式(H, W, C)。随后级联上采样器开始工作特征融合将当前层的特征与来自嵌入模块对应层的、通过跳跃连接传来的CNN特征进行通道维度上的拼接Concatenation。CNN特征提供了高分辨率的细节信息而编码器特征提供了高级的语义指导。卷积与上采样对拼接后的特征进行3x3卷积和ReLU激活然后进行2倍双线性上采样。迭代重复步骤1和2直到特征图的空间尺寸恢复到与输入图像相同256x256。输出最后通过一个1x1卷积层将通道数映射为分类数对于二分类分割就是1个通道并通过Sigmoid激活函数输出每个像素属于肾小球的概率图。这个过程的直观理解是编码器告诉解码器“哪里大概有一个肾小球”而跳跃连接提供的细节告诉解码器“这个肾小球的边界具体在哪里拐弯”。3. 从数据到模型完整的实现流程与实操要点有了理论架构下一步就是将其转化为可运行的代码并在真实数据上进行训练和评估。这一部分我将结合我们使用HuBMAP肾脏数据集的具体经验详细拆解每一个步骤并分享其中的关键决策和避坑指南。3.1 数据准备与预处理质量决定上限我们使用的数据来自人类生物分子图谱计划HuBMAP的肾脏数据集包含20张PAS染色的全切片图像WSI并提供了肾小球的精细标注。原始WSI的尺寸巨大通常超过100,000 x 100,000像素无法直接送入网络。因此预处理是第一步也是影响模型性能的关键。我们的预处理流水线如下分块使用滑动窗口将每张WSI切割成1024x1024像素的小块tiles步长可以设置为512以增加数据量。这一步会产生数万个小图像块。筛选并非所有图像块都包含有价值的信息。很多块可能只是空白背景或极少染色的组织。我们采用了一种基于颜色饱和度的自动筛选方法将图像从RGB转换到HSV颜色空间。计算每个图像块的平均饱和度S通道。设定一个经验阈值我们经过实验确定为40。饱和度低于此阈值的块被视为“信息贫乏”而被丢弃。实操心得这个阈值需要根据你的具体染色方法和扫描仪进行调整。可以随机抽样几百个块人工判断是否包含组织然后绘制饱和度直方图观察有组织块和无组织块的分布从而确定一个合理的分界点。盲目设置阈值会导致丢弃有用数据或引入过多噪声。下采样为了控制计算成本并加速实验迭代我们将筛选后的1024x1024块下采样至256x256。这必然会丢失一些细节但对于初步验证架构是可行的。在最终追求最高精度时应考虑使用512x512或保留1024x1024的分辨率。数据增强医学图像数据通常稀缺增强是防止过拟合、提升模型泛化能力的利器。我们采用了在线增强on-the-fly augmentation即在每个训练周期epoch中随机对图像施加变换。我们的增强策略包括确定性增强概率1.0水平/垂直翻转、90度随机旋转。这些变换不改变组织的生物学形态能直接扩充数据。随机性增强概率0.5随机亮度/对比度调整、轻微的弹性形变、饱和度微调。这些模拟了染色差异、切片变形等真实情况。# 示例使用Albumentations库定义增强管道 import albumentations as A train_transform A.Compose([ A.HorizontalFlip(p0.5), A.VerticalFlip(p0.5), A.RandomRotate90(p0.5), A.ShiftScaleRotate(shift_limit0.0625, scale_limit0.1, rotate_limit15, p0.5), A.RandomBrightnessContrast(brightness_limit0.1, contrast_limit0.1, p0.3), A.HueSaturationValue(hue_shift_limit10, sat_shift_limit20, val_shift_limit10, p0.3), A.ElasticTransform(alpha1, sigma50, alpha_affine50, p0.3), ])3.2 模型构建与训练策略我们使用PyTorch框架实现了MLP-UNet。以下是几个关键实现细节1. 嵌入模块实现我们并没有使用完整的ResNet而是截取其前四个阶段stem, stage1, stage2, stage3并去掉了最后的全局池化层和全连接层。这样可以在多个尺度上获取特征图用于跳跃连接。2. MLP编码器实现以ResMLP块为例其PyTorch核心代码如下import torch.nn as nn class ResMLPBlock(nn.Module): def __init__(self, num_patches, hidden_dim, mlp_ratio4): super().__init__() # 跨补丁空间子层 self.cross_patch nn.Sequential( Affine(hidden_dim), # 仿射变换替代LayerNorm nn.Linear(num_patches, num_patches), # 空间混合MLP ) # 跨通道特征子层 self.cross_channel nn.Sequential( Affine(hidden_dim), nn.Linear(hidden_dim, hidden_dim * mlp_ratio), nn.GELU(), nn.Linear(hidden_dim * mlp_ratio, hidden_dim) ) def forward(self, x): # x shape: (batch_size, num_patches, hidden_dim) # 跨补丁混合: 对特征维度做全连接 residual x x self.cross_patch(x.transpose(1, 2)).transpose(1, 2) x x residual # 残差连接 # 跨通道混合 residual x x self.cross_channel(x) x x residual # 残差连接 return x class Affine(nn.Module): 仿射变换层独立于批统计量 def __init__(self, dim): super().__init__() self.alpha nn.Parameter(torch.ones(1, 1, dim)) self.beta nn.Parameter(torch.zeros(1, 1, dim)) def forward(self, x): return self.alpha * x self.beta3. 损失函数与评估指标对于医学图像分割尤其是前景肾小球与背景严重不平衡的场景Dice Loss是比标准交叉熵损失更好的选择。它直接优化Dice系数迫使模型关注前景区域的重叠度。def dice_loss(pred, target, smooth1e-5): pred pred.contiguous().view(-1) target target.contiguous().view(-1) intersection (pred * target).sum() dice (2. * intersection smooth) / (pred.sum() target.sum() smooth) return 1 - dice我们同时使用Dice系数作为评估指标它直观地反映了预测分割区域与真实标注区域的重合程度。4. 训练配置优化器Adam初始学习率设为1e-4。这是一个比较稳妥的起点。学习率调度使用ReduceLROnPlateau当验证集Dice系数在连续5个epoch内不再提升时将学习率乘以0.5。批量大小在单张11GB显存的RTX 2080 Ti上我们设置为16。如果显存不足可以减小批量大小但可能需要适当调整学习率。训练轮数50个epoch。我们观察到MLP-UNet大约在10-15个epoch后就能达到不错的性能后续是缓慢提升。3.3 五折交叉验证与结果分析为了客观评估模型性能并减少数据划分的偶然性我们采用了五折交叉验证。将15张训练用WSI对应的所有图像块随机分成5份轮流使用其中4份训练1份验证重复5次。结果解读基于我们的实验数据模型参数量 (百万)平均Dice系数 (%)最佳折Dice (%)最差折Dice (%)需预训练U-Net (DenseNet201)~2079.3092.5852.90是TransUNet (ViT)~7690.5893.0585.71是Mixer-UNet~7989.9093.0383.83否ResMLP-UNet (我们的)~6289.9693.0683.29否关键发现编码器模块的巨大价值对比U-Net可视为我们框架去掉编码器模块和MLP-UNet/TransUNet引入编码器无论是Transformer还是MLP带来了超过10%的平均Dice系数提升。这证明了在分割任务中显式地建模长距离全局上下文至关重要。MLP-UNet的竞争力ResMLP-UNet在无需任何ImageNet预训练的情况下取得了与预训练的TransUNet仅差0.6%的平均性能同时参数量减少了约20%。这是一个非常有力的结果意味着我们可以从一个更简单、更轻量的模型开始训练节省大量的预训练计算资源和时间。鲁棒性分析观察第5折最差的一折的结果。U-Net的性能出现了灾难性下降Dice仅52.9%而TransUNet和MLP-UNet则保持了相对稳定Dice 83%。这进一步说明具备强大全局建模能力的模型对于数据分布的变化该折训练数据较少验证数据较多具有更好的鲁棒性。训练动态从训练曲线看MLP-UNet在训练初期前10个epoch的Dice系数上升速度与TransUNet相当甚至更快。这表明MLP架构同样具备强大的拟合能力。TransUNet由于经过了预训练其性能饱和得更早约30个epoch而MLP-UNet则在50个epoch内持续缓慢提升。4. 常见问题、排查技巧与未来方向在实际复现和应用MLP-UNet的过程中你可能会遇到一些典型问题。以下是我在实验过程中总结的一些排查思路和技巧。4.1 训练不稳定或Dice系数不上升问题现象损失值Loss震荡剧烈或Dice系数始终在很低水平徘徊。排查步骤检查数据与标注这是最常见的问题根源。首先可视化一批训练数据及其对应的标注掩膜确保图像读取、增强和标注对齐Alignment是正确的。一个常见的错误是图像和掩膜在应用增强时没有使用相同的随机种子。检查数据分布计算一下你的训练集中前景像素肾小球占总像素的比例。如果比例极低如1%即使是随机预测Dice系数也可能很低。此时需要确认Dice Loss是否正常工作或者考虑结合Focal Loss等来处理极端类别不平衡。降低学习率尝试将初始学习率从1e-4降低到5e-5或1e-5。MLP层对学习率可能比卷积层更敏感。检查梯度在训练初期打印出模型各层的梯度范数。如果出现梯度爆炸值极大或梯度消失值接近0需要考虑调整权重初始化如使用Xavier或Kaiming初始化或在MLP块中添加更严格的归一化/仿射变换。简化模型先尝试一个非常浅的MLP-UNet例如只有2个ResMLP块看它能否在训练集上过拟合。如果连训练集都学不好那可能是模型结构或代码有根本性错误。4.2 模型过拟合问题现象训练集Dice系数很高但验证集Dice系数很低且差距随着训练持续拉大。解决方案增强数据增强增加更多样化的随机增强如高斯噪声、模糊、网格畸变等。医学图像中染色差异、切片折叠是常见噪声模拟这些情况有助于提升泛化能力。添加正则化在MLP层或卷积层后加入Dropout。对于MLP-UNet可以在通道混合MLP的两个全连接层之间加入Dropout。使用权重衰减在优化器Adam中设置一个较小的权重衰减如1e-4这等同于L2正则化可以防止权重过大。早停密切监控验证集Dice系数当其在连续10-15个epoch内不再提升时停止训练并回滚到验证集性能最佳的模型权重。4.3 预测结果边界粗糙或存在小碎片问题现象模型预测的分割掩膜边界不光滑呈锯齿状或者在背景中出现零星的小块错误预测。后处理技巧概率阈值调优模型输出的是每个像素的概率图。默认阈值是0.5但你可以根据验证集调整这个阈值以在精确率和召回率之间取得最佳平衡。使用ROC曲线或PR曲线来辅助选择。连通域分析肾小球通常是闭合的、具有一定面积的连通区域。可以使用OpenCV的findContours或connectedComponentsWithStats函数过滤掉面积过小比如小于50像素的预测区域这能有效去除大部分噪声碎片。形态学操作对二值化后的预测掩膜使用闭运算先膨胀后腐蚀可以填充小的孔洞使边界更光滑使用开运算先腐蚀后膨胀可以去除小的孤立点。核的大小需要根据图像分辨率谨慎选择例如3x3或5x5。4.4 未来改进与探索方向MLP-UNet为我们打开了一扇新的大门但仍有广阔的优化空间更高分辨率的输入我们的实验基于256x256的输入。将输入分辨率提升到512x512甚至1024x1024有望捕获更精细的肾小球毛细血管袢结构从而进一步提升分割精度尤其是边界准确性。但这会显著增加计算开销需要更强大的硬件或模型并行策略。探索更先进的MLP变体除了MLP-Mixer和ResMLP社区已经涌现出如ConvMLP结合了卷积的局部性和MLP的全局性、gMLP门控MLP等新架构。将它们作为编码器进行实验可能会发现性能更优或效率更高的组合。损失函数创新Dice Loss虽然有效但有时会导致训练不稳定。可以尝试结合边界损失如Boundary Loss、Hausdorff距离损失等直接优化分割边界的准确性这对于医学图像的精细分割尤为重要。应用于其他组织与染色目前工作集中在PAS染色的肾小球。验证MLP-UNet在HE染色、Masson染色等其他染色方法下的泛化能力以及在其他器官如肝脏、肺的显微结构分割任务上的表现是证明其通用性的关键。集成到临床工作流最终目标是辅助医生。开发一个完整的软件管道能够接收原始的WSI文件自动进行分块、推理、后处理并将分割结果如肾小球数量、面积、位置以可视化报告的形式叠加回原图提供给病理医生进行审核和确认这是从研究走向应用的必要一步。MLP-UNet的成功表明在医学图像分割这个对精度和效率都有严苛要求的领域回归基础、探索更简洁的架构设计是一条充满潜力的道路。它用更少的参数、更简单的操作达到了与复杂模型相媲美的效果为在计算资源有限的边缘设备或医院本地服务器上部署高性能AI辅助诊断工具提供了新的可能性。

相关新闻