别再死记ResNet18结构图了!用PyTorch代码逐层拆解输入输出尺寸变化

发布时间:2026/6/6 3:12:20

别再死记ResNet18结构图了!用PyTorch代码逐层拆解输入输出尺寸变化 用PyTorch代码逐层解剖ResNet18从张量维度变化理解残差网络当你第一次看到ResNet18的结构图时是否曾被那些密密麻麻的箭头和方块搞得晕头转向作为计算机视觉领域的里程碑式架构ResNet18通过残差连接解决了深度神经网络中的梯度消失问题。但理解它的最佳方式不是死记硬背结构图而是亲手用代码构建它观察每一层如何改变输入数据的维度。本文将带你用PyTorch从零开始构建ResNet18并通过打印每一层输出的张量形状height, width, channels来直观理解数据流。这种方法特别适合那些喜欢通过实践来学习的开发者——你不仅能看到理论还能立即验证每一层的效果。我们会重点关注初始卷积层和池化层如何快速降低空间维度四个残差块阶段中特征图尺寸的变化规律虚线残差连接中1x1卷积的通道调整机制全局平均池化如何替代全连接层减少参数1. 环境准备与基础概念在开始之前确保你已经安装了PyTorch。可以通过以下命令检查pip install torch torchvisionResNet18的核心创新是残差连接residual connection它允许梯度直接流过网络缓解了深度网络中的梯度消失问题。一个典型的残差块由两个3x3卷积层组成输入会与卷积后的输出相加输入 → 卷积1 → 卷积2 → → 输出 ↑____________|这种结构使得网络可以学习残差映射residual mapping而非直接映射这在理论上更容易优化。当特征图尺寸减半或通道数变化时需要使用1x1卷积调整残差连接的维度这就是结构图中的虚线连接。2. 构建ResNet18基础模块我们先定义残差块的基本结构。在PyTorch中这通常实现为一个包含两个卷积层的子模块import torch import torch.nn as nn class BasicBlock(nn.Module): expansion 1 def __init__(self, in_channels, out_channels, stride1): super().__init__() self.conv1 nn.Conv2d( in_channels, out_channels, kernel_size3, stridestride, padding1, biasFalse ) self.bn1 nn.BatchNorm2d(out_channels) self.conv2 nn.Conv2d( out_channels, out_channels, kernel_size3, stride1, padding1, biasFalse ) self.bn2 nn.BatchNorm2d(out_channels) self.shortcut nn.Sequential() if stride ! 1 or in_channels ! self.expansion * out_channels: self.shortcut nn.Sequential( nn.Conv2d( in_channels, self.expansion * out_channels, kernel_size1, stridestride, biasFalse ), nn.BatchNorm2d(self.expansion * out_channels) ) def forward(self, x): out torch.relu(self.bn1(self.conv1(x))) out self.bn2(self.conv2(out)) out self.shortcut(x) out torch.relu(out) return out关键点解析shortcut连接当stride≠1或输入输出通道数不匹配时使用1x1卷积调整维度Batch Normalization每个卷积层后都跟批量归一化加速训练ReLU激活只在残差相加后应用一次非线性激活让我们测试一个残差块的效果block BasicBlock(64, 64) x torch.randn(1, 64, 56, 56) # (batch, channels, height, width) print(f输入形状: {x.shape}) out block(x) print(f输出形状: {out.shape}) # 应保持不变3. 逐层解析完整ResNet18现在我们将残差块组合成完整的ResNet18架构。按照原始论文ResNet18包含初始卷积层 (7x7, stride2)最大池化 (3x3, stride2)4个阶段各包含2个残差块全局平均池化和全连接层以下是完整实现class ResNet18(nn.Module): def __init__(self, num_classes1000): super().__init__() self.in_channels 64 self.conv1 nn.Conv2d(3, 64, kernel_size7, stride2, padding3, biasFalse) self.bn1 nn.BatchNorm2d(64) self.maxpool nn.MaxPool2d(kernel_size3, stride2, padding1) self.layer1 self._make_layer(64, 2, stride1) self.layer2 self._make_layer(128, 2, stride2) self.layer3 self._make_layer(256, 2, stride2) self.layer4 self._make_layer(512, 2, stride2) self.avgpool nn.AdaptiveAvgPool2d((1, 1)) self.fc nn.Linear(512 * BasicBlock.expansion, num_classes) def _make_layer(self, out_channels, num_blocks, stride): strides [stride] [1] * (num_blocks - 1) layers [] for stride in strides: layers.append(BasicBlock(self.in_channels, out_channels, stride)) self.in_channels out_channels * BasicBlock.expansion return nn.Sequential(*layers) def forward(self, x): print(f\n初始输入: {x.shape}) x self.conv1(x) print(f初始卷积后: {x.shape}) x self.bn1(x) x torch.relu(x) x self.maxpool(x) print(f最大池化后: {x.shape}) x self.layer1(x) print(f阶段1后: {x.shape}) x self.layer2(x) print(f阶段2后: {x.shape}) x self.layer3(x) print(f阶段3后: {x.shape}) x self.layer4(x) print(f阶段4后: {x.shape}) x self.avgpool(x) print(f全局池化后: {x.shape}) x torch.flatten(x, 1) x self.fc(x) print(f全连接后: {x.shape}) return x4. 运行并观察维度变化让我们创建一个224x224的RGB图像输入观察各层输出model ResNet18() x torch.randn(1, 3, 224, 224) # 模拟ImageNet输入 out model(x)运行后会打印如下维度变化初始输入: torch.Size([1, 3, 224, 224]) 初始卷积后: torch.Size([1, 64, 112, 112]) 最大池化后: torch.Size([1, 64, 56, 56]) 阶段1后: torch.Size([1, 64, 56, 56]) 阶段2后: torch.Size([1, 128, 28, 28]) 阶段3后: torch.Size([1, 256, 14, 14]) 阶段4后: torch.Size([1, 512, 7, 7]) 全局池化后: torch.Size([1, 512, 1, 1]) 全连接后: torch.Size([1, 1000])关键维度变化解析层类型输入尺寸输出尺寸变化说明初始卷积[3,224,224][64,112,112]7x7卷积stride2使尺寸减半最大池化[64,112,112][64,56,56]3x3池化stride2再次减半阶段1[64,56,56][64,56,56]两个残差块保持尺寸阶段2[64,56,56][128,28,28]第一个残差块stride2降采样阶段3[128,28,28][256,14,14]同上降采样机制阶段4[256,14,14][512,7,7]最终降采样到7x7全局池化[512,7,7][512,1,1]空间维度压缩为1x15. 残差连接的特殊处理当特征图尺寸变化时如阶段2从56x56→28x28残差连接也需要相应调整。这是通过shortcut中的1x1卷积实现的# 在BasicBlock中查看shortcut block BasicBlock(64, 128, stride2) print(block.shortcut) # 输出: Sequential( # (0): Conv2d(64, 128, kernel_size(1, 1), stride(2, 2), biasFalse) # (1): BatchNorm2d(128, eps1e-05, momentum0.1, affineTrue, track_running_statsTrue) # )这种设计确保了即使维度变化残差连接也能与主路径的输出正确相加。你可以通过以下代码验证x torch.randn(1, 64, 56, 56) block BasicBlock(64, 128, stride2) print(f输入形状: {x.shape}) out block(x) print(f输出形状: {out.shape}) # torch.Size([1, 128, 28, 28])6. 实际应用技巧与常见问题在实际使用ResNet18时有几个实用技巧值得注意输入尺寸灵活性虽然原始设计针对224x224输入但通过调整可以接受其他尺寸。全局平均池化使其对输入尺寸不敏感。特征提取可以移除最后的全连接层将ResNet18作为特征提取器features nn.Sequential(*list(model.children())[:-1]) feature_vector features(x) # 得到512维特征训练技巧使用预训练权重加速收敛学习率 warmup 有助于稳定训练对BatchNorm层小心处理特别是迁移学习时常见问题排查维度不匹配错误通常发生在shortcut连接处检查stride和通道数设置梯度消失确保残差连接正常工作可以检查中间层梯度性能不佳尝试调整初始学习率或添加更多数据增强7. 可视化工具辅助理解除了代码打印外可以使用以下工具更直观地观察网络TorchSummaryfrom torchsummary import summary summary(model, (3, 224, 224))TensorBoard可视化from torch.utils.tensorboard import SummaryWriter writer SummaryWriter() writer.add_graph(model, x) writer.close()手工绘制数据流根据打印的维度变化绘制自己的简化结构图这比记忆标准图更有效。

相关新闻