
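# 049. Replacing the Last Two C3k2 Layers of the Backbone with Swin Transformer Blocks: A Hierarchical Design with Window Attention

## 1. It started with a strange mAP plateau

Last month I was tuning an industrial defect detection model: YOLOv11s on a PCB dataset. After 50 epochs, mAP@0.5 was stuck at 78.3% and would not budge. When I looked at the feature-map visualizations in TensorBoard, the outputs of the last two Backbone layers were almost completely "smeared" across the spatial dimensions. Small defects, such as 0.5 mm scratches, were drowned out by background noise.

My first reaction was to add an attention mechanism, but after trying SE and CBAM the gain was under 0.5 points. Later, while rereading the Swin Transformer paper, it hit me: YOLOv11's C3k2 module is essentially built from dense residual connections. Its ability to model local detail is sufficient, but it lacks global, cross-window interaction. By the last two Backbone layers the feature maps are already down to 20x20 and 10x10, and at that resolution window attention is a better deal than global self-attention: it is cheaper and still preserves spatial structure.

So I replaced the last two C3k2 blocks with Swin Transformer Blocks. mAP jumped to 81.7%, a gain of 3.4 points. Don't celebrate too early, though: on the first run mAP actually dropped by 0.8%, which I later traced to incorrect padding in the window partitioning. Below I walk through the pitfalls I hit and the final solution.

## 2. Core design of the Swin Block: don't treat window attention as a black box

The biggest difference between a Swin Transformer Block and C3k2 is that C3k2 does "channel mixing + residual", while the Swin Block does "spatial partitioning + shifted windows". When replacing the last two layers, pay special attention to two points:

1. **The window size must divide the feature-map size.** The last two Backbone layers are 20x20 and 10x10. With a 4x4 window, 20/4 = 5 divides evenly, but 10/4 = 2.5 does not. I didn't handle this at first, and `window_partition` failed outright because the `view` into windows raised an error.
2. **The cyclic shift of the shifted windows.** The official Swin implementation uses `torch.roll`. In my YOLO inference pipeline, however, once the model went through `torch.jit.script`, the roll op got optimized away and the results were wrong. A workaround is to express the shift with slicing and `torch.cat` instead of `torch.roll`.

## 3. Code: surgery from C3k2 to Swin Block

### 3.1 First, define the core component of the Swin Transformer Block

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class WindowAttention(nn.Module):
    def __init__(self, dim, window_size, num_heads, qkv_bias=True):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # (Wh, Ww)
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        # Lesson learned: keep the qkv bias, otherwise small models converge slowly
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        # Relative position bias table: a plain nn.Parameter, not an nn.ParameterList
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))

        # Relative position index for every token pair inside a window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift offsets to start at 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

    def forward(self, x, mask=None):
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # indexing, as in the official Swin code (TorchScript-friendly)

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)

        # Add the relative position bias
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)

        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        return x
```

### 3.2 Window partition and merge (the most bug-prone part)

```python
def window_partition(x, window_size):
    # x: (B, H, W, C)
    B, H, W, C = x.shape
    # Don't view directly: H and W may not be divisible by window_size, so pad first
    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        # F.pad pads from the last dim backwards: (C_front, C_back, W_left, W_right, H_top, H_bottom)
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    H_pad, W_pad = x.shape[1], x.shape[2]
    x = x.view(B, H_pad // window_size, window_size, W_pad // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size * window_size, C)
    return windows, (H_pad, W_pad)


def window_reverse(windows, window_size, H, W, pad_h, pad_w):
    # H and W are the padded sizes; pad_h / pad_w rows and columns are cropped off at the end
    B = windows.shape[0] // ((H // window_size) * (W // window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    # Remove the padding
    if pad_h > 0 or pad_w > 0:
        x = x[:, :H - pad_h, :W - pad_w, :].contiguous()
    return x
```

### 3.3 The Swin Transformer Block itself

```python
class SwinTransformerBlock(nn.Module):
    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        # Lesson learned: shift_size must stay below window_size. If the map is no larger
        # than one window, don't shift at all and shrink the window to the map size.
        if min(self.input_resolution) <= self.window_size:
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"

        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, window_size=(self.window_size, self.window_size), num_heads=num_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

        # Attention mask, computed on the *padded* resolution so it matches the padded,
        # shifted map in forward(). Assumes the runtime feature map equals input_resolution.
        if self.shift_size > 0:
            H, W = self.input_resolution
            Hp = -(-H // self.window_size) * self.window_size  # round up to a multiple of window_size
            Wp = -(-W // self.window_size) * self.window_size
            img_mask = torch.zeros((1, Hp, Wp, 1))
            h_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1
            mask_windows, _ = window_partition(img_mask, self.window_size)
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None
        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        # x: (B, C, H, W)
        B, C, H, W = x.shape
        shortcut = x.flatten(2).transpose(1, 2)  # (B, H*W, C)
        x = self.norm1(shortcut).view(B, H, W, C)

        # Pad to a multiple of window_size BEFORE shifting, so the shift and the mask see the same grid
        pad_h = (self.window_size - H % self.window_size) % self.window_size
        pad_w = (self.window_size - W % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
        H_pad, W_pad = H + pad_h, W + pad_w

        # Cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # Partition windows (already padded, so window_partition adds nothing here)
        windows, _ = window_partition(shifted_x, self.window_size)
        # Window attention
        attn_windows = self.attn(windows, mask=self.attn_mask)
        # Merge windows without cropping: the reverse shift must run on the padded grid
        shifted_x = window_reverse(attn_windows, self.window_size, H_pad, W_pad, pad_h=0, pad_w=0)

        # Reverse cyclic shift, then drop the padding
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x[:, :H, :W, :].reshape(B, H * W, C)

        # Pre-norm residuals, as in Swin: x = x + attn(norm1(x)); x = x + mlp(norm2(x))
        x = shortcut + x
        x = x + self.mlp(self.norm2(x))
        return x.transpose(1, 2).reshape(B, C, H, W)
```

### 3.4 Replacing the last two layers of the YOLOv11 Backbone

The `C3k2` class lives in `ultralytics/nn/modules/block.py`. In Ultralytics the backbone stages themselves are declared in the model YAML (for example `ultralytics/cfg/models/11/yolo11.yaml`) and assembled by `parse_model` in `ultralytics/nn/tasks.py`. For clarity, assume the original backbone ends like this:

```text
Stage3: C3k2(256, 512, 3, True)  # output 20x20
Stage4: C3k2(512, 512, 3, True)  # output 10x10
```

Replace it with:

```python
# In the backbone's __init__
self.stage3 = SwinTransformerBlock(
    dim=512,
    input_resolution=(20, 20),
    num_heads=8,
    window_size=4,  # 20 / 4 = 5, divides evenly
    shift_size=2,   # half the window size
)
self.stage4 = SwinTransformerBlock(
    dim=512,
    input_resolution=(10, 10),
    num_heads=8,
    window_size=4,  # 10 is not divisible by 4; the block pads 10x10 to 12x12 internally
    shift_size=2,
)
```

Key change: the input and output channels have to be aligned. The original Stage3 C3k2 takes 256 channels in and outputs 512, while the Swin Block requires equal input and output channels. So add a 1x1 convolution before Stage3 to lift the channel count:

```python
self.stage3_conv = nn.Conv2d(256, 512, 1)
# in forward: x = self.stage3(self.stage3_conv(x))
```

## 4. Ablation: effect of window size and shift strategy

Setup: PCB defect dataset (5,000 training images, 1,000 test images, 8 defect classes), YOLOv11s as the baseline with the last two layers replaced by Swin Blocks, 100 training epochs, 640x640 input.

| Configuration | mAP@0.5 | mAP@0.5:0.95 | Params | Inference time (ms) |
|---|---|---|---|---|
| Original C3k2 | 78.3% | 52.1% | 9.2M | 2.1 |
| Swin Block (window=4, shift=0) | 79.8% | 53.6% | 10.1M | 2.8 |
| Swin Block (window=4, shift=2) | 81.7% | 55.4% | 10.1M | 3.0 |
| Swin Block (window=7, shift=3) | 80.2% | 53.9% | 10.1M | 3.5 |
| Swin Block (window=8, shift=4) | 79.5% | 52.8% | 10.1M | 3.8 |

Conclusions:

- Window 4 with shift 2 works best. On the 20x20 map, 4x4 windows tile it into exactly 5x5 windows, and the half-window shift lets information cross window boundaries, so every spatial position gets covered.
- Larger windows (7 or 8) actually do worse. On such small maps each window holds too few real pixels: the 10x10 map has to be padded to 14x14 (window 7) or 16x16 (window 8), so a large share of every window is padding, and attention spread over those near-identical padded tokens drifts toward a plain average.
- Inference time rises by about 1 ms (2.1 ms to 3.0 ms), but mAP@0.5 goes up 3.4 points, which is a very good trade-off.

## 5. Lessons learned: three easily overlooked details

1. **LayerNorm placement.** The Swin Block puts LayerNorm before the attention and before the MLP (pre-norm), while C3k2 uses BatchNorm. If training becomes unstable after the replacement, check whether the BN layers' `running_mean` is frozen. I once had `model.train()` not set correctly, so the BN statistics never updated and mAP dropped by 5 points.
2. **Gradient checkpointing.** The Swin Block uses roughly 30% more GPU memory than C3k2. If batch size 8 runs out of memory, you can wrap the Swin Block's forward in `torch.utils.checkpoint.checkpoint`. One caveat from my runs: checkpointing did not play well with the `torch.roll` shift, so I had to pull the shift out of the checkpointed part.
3. **Mixed-precision training.** The softmax inside the Swin Block can overflow in fp16, so compute `attn` in fp32 inside `WindowAttention.forward` and cast it back. One line does it: `attn = attn.float().softmax(dim=-1).to(v.dtype)`. Casting back to `v.dtype` instead of hard-coding `.half()` keeps the same code working in fp32 and bf16.

A final honest word: replacing the last two Backbone layers with Swin Blocks is not a silver bullet. If the objects in your dataset are all large (pedestrian detection, for example), the gain from window attention may be under 1 point. But if you work on small-object detection (remote sensing, industrial defects), this change is worth a try: in my three projects it consistently improved mAP.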
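Section 1 argues that at the last Backbone stages window attention is cheaper than global self-attention. A quick way to check that claim is to plug the stage sizes into the complexity formulas from the Swin paper, Ω(MSA) = 4hwC² + 2(hw)²C and Ω(W-MSA) = 4hwC² + 2M²hwC (softmax ignored, M = window size). The sketch below assumes C = 512 and M = 4 as in the stage config in section 3.4; the function names are illustrative.

```python
def msa_flops(h: int, w: int, c: int) -> int:
    """Global multi-head self-attention cost (Swin paper): 4hwC^2 + 2(hw)^2 C."""
    return 4 * h * w * c ** 2 + 2 * (h * w) ** 2 * c


def wmsa_flops(h: int, w: int, c: int, m: int) -> int:
    """Window attention cost (Swin paper): 4hwC^2 + 2 M^2 hwC.

    h and w are assumed to be already padded to multiples of the window size m.
    """
    return 4 * h * w * c ** 2 + 2 * m ** 2 * h * w * c


# Stage3 in this post: 20x20 map, C=512, 4x4 windows (no padding needed)
print(msa_flops(20, 20, 512), wmsa_flops(20, 20, 512, 4))  # 583270400 425984000

# Stage4: the 10x10 map is padded to 12x12 before window attention
print(msa_flops(10, 10, 512), wmsa_flops(12, 12, 512, 4))
```

At 20x20 the window version saves the 2(hw)²C term almost entirely. For the 10x10 stage, the padded 12x12 grid makes the formula come out higher than global attention on the unpadded map, because the 4hwC² projection term dominates at this size, so the compute saving is really a 20x20-stage effect.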
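Sections 2 and 4 both hinge on how well the window size fits the feature map. This small helper, a hypothetical utility that is not part of the model, reproduces the padding arithmetic used by `window_partition` (pad = (ws - H % ws) % ws) and reports, for each window size in the ablation table, the padded size, the number of windows and the share of padded tokens on a square map.

```python
def window_stats(size: int, window: int):
    """Padded size, number of windows and padding fraction for a square size x size map."""
    padded = -(-size // window) * window          # ceil(size / window) * window, in integers
    num_windows = (padded // window) ** 2
    pad_fraction = 1 - (size * size) / (padded * padded)
    return padded, num_windows, pad_fraction


for size in (20, 10):
    for window in (4, 7, 8):
        padded, n, frac = window_stats(size, window)
        print(f"{size}x{size}, window {window}: pad to {padded}x{padded}, {n} windows, {frac:.0%} padding")
```

For the 10x10 stage the padding share climbs from about 31% (window 4) to 49% (window 7) and 61% (window 8), with only 2x2 windows left for the larger sizes, which fits the drop seen for windows 7 and 8 in the ablation.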
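Section 2 mentions that the `torch.roll` shift misbehaved in the author's `torch.jit.script` pipeline. One roll-free formulation, sketched here under the assumption that slicing plus `torch.cat` is acceptable in your export path (not verified against that pipeline), rebuilds the cyclic shift from two slices per dimension. `cyclic_shift` is a hypothetical helper name; it could stand in for both `torch.roll` calls in `SwinTransformerBlock.forward` (with `-shift_size` and `+shift_size`).

```python
import torch


def cyclic_shift(x: torch.Tensor, shift: int, dims=(1, 2)) -> torch.Tensor:
    """Roll x by `shift` along each dim in `dims` using slicing + torch.cat instead of torch.roll.

    Intended to match torch.roll(x, shifts=(shift, shift), dims=dims) for any integer shift.
    """
    for d in dims:
        n = x.size(d)
        s = shift % n  # normalise negative or oversized shifts into [0, n)
        if s == 0:
            continue
        # torch.roll moves element i to (i + s) % n: the last s entries wrap to the front
        x = torch.cat((x.narrow(d, n - s, s), x.narrow(d, 0, n - s)), dim=d)
    return x


x = torch.randn(2, 10, 10, 8)  # (B, H, W, C), same layout as inside the block
assert torch.equal(cyclic_shift(x, -2), torch.roll(x, shifts=(-2, -2), dims=(1, 2)))
```

The trade-off is an extra copy per dimension; in exchange the shift is plain indexing, which is easy to inspect in an exported graph.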
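The relative position index in section 3.1 is the least obvious part of `WindowAttention`. Pulling the same construction out into a standalone function makes its properties easy to check: every index falls inside the (2Wh-1)(2Ww-1) bias table, all offsets are used, the diagonal (zero offset) maps to the table's centre, and swapping a token pair mirrors the index. The function name is illustrative.

```python
import torch


def relative_position_index(wh: int, ww: int) -> torch.Tensor:
    """Same construction as WindowAttention.__init__, pulled out for inspection."""
    coords = torch.stack(torch.meshgrid(torch.arange(wh), torch.arange(ww), indexing="ij"))
    flat = torch.flatten(coords, 1)                                    # 2, Wh*Ww
    rel = (flat[:, :, None] - flat[:, None, :]).permute(1, 2, 0).contiguous()
    rel[:, :, 0] += wh - 1                                             # row offset now in [0, 2Wh-2]
    rel[:, :, 1] += ww - 1                                             # col offset now in [0, 2Ww-2]
    rel[:, :, 0] *= 2 * ww - 1                                         # row-major flattening of (dr, dc)
    return rel.sum(-1)                                                 # Wh*Ww, Wh*Ww


idx = relative_position_index(4, 4)  # the 4x4 window used in this post
print(idx.shape, idx.min().item(), idx.max().item(), idx.unique().numel())
```

For a 4x4 window the index spans 0 to 48, i.e. exactly the 7 x 7 = 49 rows of `relative_position_bias_table`.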
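For the gradient-checkpointing tip in section 5, one way to apply it without touching the block's code is a small wrapper module. `CheckpointWrapper` is a hypothetical helper, not an Ultralytics or PyTorch class; it only uses `torch.utils.checkpoint.checkpoint` with `use_reentrant=False`, which recomputes the wrapped module's activations during backward instead of storing them. In the backbone it would wrap a stage, e.g. `CheckpointWrapper(SwinTransformerBlock(...))`.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointWrapper(nn.Module):
    """Hypothetical helper: run `module` under activation checkpointing while training.

    Intermediate activations inside `module` are not kept; they are recomputed in the
    backward pass, trading extra compute for lower memory. In eval mode it is a plain call.
    """

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training:
            return checkpoint(self.module, x, use_reentrant=False)
        return self.module(x)


# Stand-in for a Swin stage: any module with a single tensor input works
block = CheckpointWrapper(nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16)))
out = block(torch.randn(4, 16))
out.sum().backward()
```

The non-reentrant variant is used here because, unlike the reentrant one, it still produces parameter gradients when the input tensor itself does not require grad.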
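The mixed-precision tip in section 5 boils down to "softmax in fp32, then back to the working dtype". A minimal sketch of that as a reusable function (the name `softmax_fp32` is illustrative) is shown below; inside `WindowAttention.forward` it would replace `attn = attn.softmax(dim=-1)`. Note that under `torch.autocast` on CUDA, softmax is already among the ops autocast runs in float32, so the explicit cast matters mainly when the model has been converted with `.half()` manually.

```python
import torch


def softmax_fp32(attn: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Compute softmax in float32 and cast back to the input dtype (fp16, bf16 or fp32)."""
    return attn.float().softmax(dim=dim).to(attn.dtype)


# Large fp16 attention scores, as they might come out of q @ k^T
scores = (torch.randn(1, 8, 16, 16) * 20).half()
probs = softmax_fp32(scores)
print(probs.dtype)  # torch.float16
```

Keying the cast on the input dtype rather than a hard-coded `.half()` is the design choice that lets the same attention code run unchanged in fp32, fp16 and bf16.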