别再死记ResNet结构了!用PyTorch手写一个ResNet-18,彻底搞懂残差连接和Bottleneck

发布时间:2026/5/26 17:33:06

别再死记ResNet结构了!用PyTorch手写一个ResNet-18,彻底搞懂残差连接和Bottleneck 从零实现ResNet-18用PyTorch拆解残差网络的秘密武器当你在ImageNet竞赛的历史榜单上看到ResNet这个名字时可能会好奇为什么这个2015年提出的网络结构至今仍是计算机视觉任务的基石答案藏在那个看似简单的加号里——残差连接。但理解这个概念最好的方式不是盯着论文图表而是亲手用代码构建它。本文将带你用PyTorch实现一个完整的ResNet-18在代码层面揭示残差网络的核心机制。1. 残差网络的设计哲学深度学习模型随着层数增加会出现一个反直觉现象更多层数反而导致性能下降。这不是过拟合问题而是优化难题——梯度在反向传播时逐渐消失使得深层网络难以训练。ResNet的突破性在于将传统的直接拟合目标函数转变为拟合残差函数。想象你在学习骑自行车时的进步过程。你不是每次尝试都从零开始而是在前一次尝试的基础上做微小调整。残差块正是模拟这种学习方式# 传统网络层的数学表达 y F(x) # 残差块的数学表达 y F(x) x # 关键加号这个简单的加法操作带来了三个革命性优势梯度高速公路即使深层梯度很小恒等映射x也能确保梯度直接回传解耦学习目标让网络专注于学习输入与输出之间的差值残差动态深度适应极端情况下网络可以通过将F(x)学习为0来退化为浅层网络2. 构建ResNet-18的基础模块ResNet-18使用两种基本构建块普通块BasicBlock和瓶颈块Bottleneck。我们先实现适用于浅层网络的BasicBlockimport torch import torch.nn as nn class BasicBlock(nn.Module): expansion 1 # 输出通道的扩展系数 def __init__(self, in_channels, out_channels, stride1): super().__init__() # 第一个卷积层 self.conv1 nn.Conv2d( in_channels, out_channels, kernel_size3, stridestride, padding1, biasFalse ) self.bn1 nn.BatchNorm2d(out_channels) # 第二个卷积层 self.conv2 nn.Conv2d( out_channels, out_channels, kernel_size3, stride1, padding1, biasFalse ) self.bn2 nn.BatchNorm2d(out_channels) # 快捷连接处理 self.shortcut nn.Sequential() if stride ! 1 or in_channels ! self.expansion * out_channels: self.shortcut nn.Sequential( nn.Conv2d( in_channels, self.expansion * out_channels, kernel_size1, stridestride, biasFalse ), nn.BatchNorm2d(self.expansion * out_channels) ) def forward(self, x): out torch.relu(self.bn1(self.conv1(x))) out self.bn2(self.conv2(out)) out self.shortcut(x) # 残差连接 out torch.relu(out) return out这个实现中有几个关键设计点卷积核配置两个3×3卷积保持感受野的同时减少参数量第一个卷积的stride可能为2下采样使用BatchNorm加速收敛并稳定训练快捷连接处理当输入输出维度匹配时直接相加恒等映射维度不匹配时通过1×1卷积调整投影映射激活函数放置每个卷积后立即接ReLU残差相加后再接一次ReLU3. 完整ResNet-18的架构实现现在我们将BasicBlock组装成完整的ResNet-18。网络分为五个阶段初始卷积层conv1四个残差阶段conv2_x到conv5_x全局平均池化和全连接层class ResNet(nn.Module): def __init__(self, block, num_blocks, num_classes1000): super().__init__() self.in_channels 64 # 初始卷积层 self.conv1 nn.Conv2d( 3, 64, kernel_size7, stride2, padding3, biasFalse ) self.bn1 nn.BatchNorm2d(64) self.maxpool nn.MaxPool2d(kernel_size3, stride2, padding1) # 四个残差阶段 self.layer1 self._make_layer(block, 64, num_blocks[0], stride1) self.layer2 self._make_layer(block, 128, num_blocks[1], stride2) self.layer3 self._make_layer(block, 256, num_blocks[2], stride2) self.layer4 self._make_layer(block, 512, num_blocks[3], stride2) # 分类头 self.avgpool nn.AdaptiveAvgPool2d((1, 1)) self.fc nn.Linear(512 * block.expansion, num_classes) def _make_layer(self, block, out_channels, num_blocks, stride): strides [stride] [1]*(num_blocks-1) layers [] for stride in strides: layers.append(block( self.in_channels, out_channels, stride )) self.in_channels out_channels * block.expansion return nn.Sequential(*layers) def forward(self, x): x torch.relu(self.bn1(self.conv1(x))) x self.maxpool(x) x self.layer1(x) x self.layer2(x) x self.layer3(x) x self.layer4(x) x self.avgpool(x) x torch.flatten(x, 1) x self.fc(x) return x创建ResNet-18实例的方法def resnet18(): return ResNet(BasicBlock, [2, 2, 2, 2])层数计算验证conv1: 1层conv2_x: 2个block × 2层 4层conv3_x: 2个block × 2层 4层conv4_x: 2个block × 2层 4层conv5_x: 2个block × 2层 4层fc: 1层 总计1 4×4 1 18层4. 训练技巧与可视化分析实现网络结构只是第一步正确的训练方法同样重要。以下是训练ResNet的关键技巧学习率调度策略optimizer torch.optim.SGD(model.parameters(), lr0.1, momentum0.9, weight_decay1e-4) scheduler torch.optim.lr_scheduler.MultiStepLR( optimizer, milestones[30, 60, 90], gamma0.1 )数据增强配置train_transform transforms.Compose([ transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ColorJitter(brightness0.4, contrast0.4, saturation0.4), transforms.ToTensor(), transforms.Normalize(mean[0.485, 0.456, 0.406], std[0.229, 0.224, 0.225]) ])梯度流动可视化 通过hook机制观察残差连接如何影响梯度传播def register_hooks(model): gradients {} def save_grad(name): def hook(module, grad_input, grad_output): gradients[name] grad_output[0].mean().item() return hook for name, module in model.named_modules(): if isinstance(module, nn.Conv2d): module.register_full_backward_hook(save_grad(name)) return gradients特征图可视化对比 比较有无残差连接时的中间层激活差异层深度传统网络激活强度ResNet激活强度浅层0.78 ± 0.120.82 ± 0.15中层0.31 ± 0.080.67 ± 0.11深层0.02 ± 0.010.54 ± 0.09数据表明残差连接有效缓解了梯度消失问题使深层网络保持活跃学习状态。5. 进阶话题Bottleneck设计与变体虽然ResNet-18使用BasicBlock但更深层的ResNet需要Bottleneck设计来控制计算量。理解这种差异对掌握ResNet家族至关重要。BottleneckBlock实现class BottleneckBlock(nn.Module): expansion 4 # 输出通道扩展系数 def __init__(self, in_channels, out_channels, stride1): super().__init__() # 1×1卷积降维 self.conv1 nn.Conv2d( in_channels, out_channels, kernel_size1, stride1, biasFalse ) self.bn1 nn.BatchNorm2d(out_channels) # 3×3卷积处理特征 self.conv2 nn.Conv2d( out_channels, out_channels, kernel_size3, stridestride, padding1, biasFalse ) self.bn2 nn.BatchNorm2d(out_channels) # 1×1卷积升维 self.conv3 nn.Conv2d( out_channels, out_channels * self.expansion, kernel_size1, stride1, biasFalse ) self.bn3 nn.BatchNorm2d(out_channels * self.expansion) # 快捷连接 self.shortcut nn.Sequential() if stride ! 1 or in_channels ! out_channels * self.expansion: self.shortcut nn.Sequential( nn.Conv2d( in_channels, out_channels * self.expansion, kernel_size1, stridestride, biasFalse ), nn.BatchNorm2d(out_channels * self.expansion) ) def forward(self, x): out torch.relu(self.bn1(self.conv1(x))) out torch.relu(self.bn2(self.conv2(out))) out self.bn3(self.conv3(out)) out self.shortcut(x) out torch.relu(out) return outBottleneck设计优势分析计算效率输入256维 → 64维 → 64维 → 256维参数量1×1×256×64 3×3×64×64 1×1×64×256 70,400直接3×3卷积3×3×256×256 589,824信息流动降维后在小空间进行昂贵卷积运算升维恢复通道数匹配残差连接ResNet变体对比模型参数量(M)GFLOPsTop-1 Acc(%)ResNet-1811.71.869.8ResNet-3421.83.773.3ResNet-5025.64.176.2ResNet-10144.57.977.4实际项目中ResNet-50通常是精度与效率的最佳平衡点。

相关新闻