)
存粹自己捣鼓记录过程1.论文复现过程用的清华镜像总是报错最后就是在原环境下的2.入门阶段代码import torchimport torch.nn as nnimport torch.optim as optimfrom torchvision import datasets, transformsfrom torch.utils.data import DataLoaderimport timmimport randomimport numpy as npimport os# HuggingFace国内镜像加速os.environ[HF_ENDPOINT] https://hf-mirror.com# 固定随机种子random.seed(42)np.random.seed(42)torch.manual_seed(42)# 强制使用CPU规避RTX5060 CUDA兼容报错device torch.device(cpu)batch_size 16epochs 3lr 1e-4# 数据预处理train_transform transforms.Compose([transforms.Resize((224, 224)),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean[0.485, 0.456, 0.406], std[0.229, 0.224, 0.225])])val_transform transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean[0.485, 0.456, 0.406], std[0.229, 0.224, 0.225])])# 缩小数据集规模大幅降低CPU训练耗时from torchvision.datasets import FakeDatatrain_set FakeData(size1000,image_size(3, 224, 224),num_classes10,transformtrain_transform)val_set FakeData(size200,image_size(3, 224, 224),num_classes10,transformval_transform)# Windows系统num_workers设为0防止报错train_loader DataLoader(train_set, batch_sizebatch_size, shuffleTrue, num_workers0)val_loader DataLoader(val_set, batch_sizebatch_size, shuffleFalse, num_workers0)# 加载Swin-T模型model timm.create_model(swin_tiny_patch4_window7_224, pretrainedFalse, num_classes10)model model.to(device)# 损失函数与优化器criterion nn.CrossEntropyLoss()optimizer optim.Adam(model.parameters(), lrlr)# 训练一轮函数def train_epoch():model.train()total_loss 0.0for images, labels in train_loader:images, labels images.to(device), labels.to(device)optimizer.zero_grad()outputs model(images)loss criterion(outputs, labels)loss.backward()optimizer.step()total_loss loss.item()return total_loss / len(train_loader)# 验证一轮函数def val_epoch():model.eval()correct 0total 0with torch.no_grad():for images, labels in val_loader:images, labels images.to(device), labels.to(device)outputs model(images)_, preds torch.max(outputs, 1)total labels.size(0)correct (preds labels).sum().item()acc correct / totalreturn acc# 开始训练if __name__ __main__:for epoch in range(epochs):train_loss train_epoch()val_acc val_epoch()print(fEpoch [{epoch1}/{epochs}], Train Loss: {train_loss:.4f}, Val Acc: {val_acc:.4f})3.完整后续代码# 修复Windows SSL证书报错必须放在最顶部import sslssl._create_default_https_context ssl._create_unverified_context# HuggingFace国内镜像加速解决预训练权重下载卡顿import osos.environ[HF_ENDPOINT] https://hf-mirror.comimport torchimport torch.nn as nnimport torch.optim as optimfrom torchvision import datasets, transformsfrom torch.utils.data import DataLoaderimport timmimport randomimport numpy as npfrom torch.optim.lr_scheduler import CosineAnnealingLR# 固定随机种子保证实验可复现random.seed(42)np.random.seed(42)torch.manual_seed(42)# 优先使用GPU无GPU自动切换CPUdevice torch.device(cuda if torch.cuda.is_available() else cpu)print(f训练设备: {device})# 对齐Swin论文超参数batch_size 32epochs 100lr 1e-4weight_decay 0.05# 论文原版强数据增强策略train_transform transforms.Compose([transforms.Resize((256, 256)),transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(p0.5),transforms.ColorJitter(brightness0.4, contrast0.4, saturation0.4, hue0.1),transforms.ToTensor(),transforms.Normalize(mean[0.485, 0.456, 0.406], std[0.229, 0.224, 0.225])])# 验证集标准化预处理val_transform transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean[0.485, 0.456, 0.406], std[0.229, 0.224, 0.225])])# 加载CIFAR10真实数据集自动下载train_set datasets.CIFAR10(root./data, trainTrue, downloadTrue, transformtrain_transform)val_set datasets.CIFAR10(root./data, trainFalse, downloadTrue, transformval_transform)train_loader DataLoader(train_set, batch_sizebatch_size, shuffleTrue, num_workers0)val_loader DataLoader(val_set, batch_sizebatch_size, shuffleFalse, num_workers0)# 加载ImageNet预训练Swin-T模型model timm.create_model(swin_tiny_patch4_window7_224, pretrainedTrue, num_classes10)model model.to(device)# 论文采用AdamW优化器 权重衰减正则化criterion nn.CrossEntropyLoss()optimizer optim.AdamW(model.parameters(), lrlr, weight_decayweight_decay)# 余弦退火学习率调度策略scheduler CosineAnnealingLR(optimizer, T_maxepochs)# 单轮训练函数def train_epoch():model.train()total_loss 0.0for images, labels in train_loader:images, labels images.to(device), labels.to(device)optimizer.zero_grad()outputs model(images)loss criterion(outputs, labels)loss.backward()optimizer.step()total_loss loss.item()return total_loss / len(train_loader)# 单轮验证函数def val_epoch():model.eval()correct 0total 0with torch.no_grad():for images, labels in val_loader:images, labels images.to(device), labels.to(device)outputs model(images)_, preds torch.max(outputs, 1)total labels.size(0)correct (preds labels).sum().item()acc correct / totalreturn acc# 开始完整训练if __name__ __main__:for epoch in range(epochs):train_loss train_epoch()val_acc val_epoch()scheduler.step()print(fEpoch [{epoch1}/{epochs}], Train Loss: {train_loss:.4f}, Val Acc: {val_acc:.4f})4. 过程中的知识点ResNet、ViT、Swin Transformer 网络结构对比总结结合你本次 Swin 复现实验分别梳理三类经典图像分类主干网络的结构特点、核心原理与差异。4.1 ResNet卷积神经网络 CNN 代表20151. 核心结构单元残差块Residual Block短路连接Shortcut Connection残差核心公式把输入直接绕过两层卷积加到卷积输出上解决深度网络梯度消失问题让网络可以堆叠几十、上百层。基础残差块结构2 层 3×3 卷积ResNet18/34瓶颈结构1×1 降维→3×3 卷积→1×1 升维ResNet50/101减少参数量。2. 整体网络流程输入图片 → 7×7 大卷积 最大下采样做粗特征提取依次堆叠 4 组残差块每组内部重复多个残差单元每组末尾通过步长为 2 的卷积做空间下采样特征图宽高减半、通道翻倍全局平均池化 全连接层输出分类结果。3. 核心特点依靠局部感受野卷积只能捕捉相邻像素的局部依赖长距离全局依赖需要多层卷积堆叠归纳偏置强自带平移、缩放不变性小数据集上收敛快、泛化好缺点深层堆叠后全局建模能力弱超大图像、长距离依赖场景精度瓶颈明显。4. 典型代表ResNet18/34浅层、ResNet50/101深层瓶颈结构。4.2 ViTVision Transformer纯 Transformer 视觉模型20201. 核心思想把图像转换成序列用 NLP 的 Transformer 编码器做全局建模2. 完整网络结构图像分块 Patch Embedding最关键预处理将 224×224 图片均等切分为 16×16 大小的小图块Patch一共 \(14×14196\) 个图块 每个 Patch 拉平为一维向量通过线性投影映射为固定维度的特征向量得到长度为 196 的特征序列。可学习组件Class Token额外增加 1 个全局分类向量拼在序列最前端最终用该向量做分类预测位置编码 Positional Embedding给每个 Patch 加上可学习位置信息Transformer 本身没有空间位置感知能力必须显式注入位置信息。堆叠多层 Transformer Encoder 编码器每一层编码器由两部分组成多头自注意力Multi-Head Self-Attention, MHSA所有 Patch 两两计算注意力直接建模全局任意像素之间的依赖关系不受距离限制前馈网络 FFN 层归一化 LN 残差短路连接。输出阶段 取出 Class Token 特征接入多层感知器 MLP 完成分类。3. 核心优缺点优点天生全局建模能力大数据集ImageNet下精度远超 ResNet缺点没有 CNN 的局部归纳偏置小数据集极易过拟合必须依赖大规模预训练全局自注意力复杂度Patch 数量越多计算量爆炸无法处理高分辨率大图4.3 Swin Transformer你本次复现模型ViT 改进版窗口化 Transformer1. 解决 ViT 全局注意力算力爆炸的痛点采用窗口自注意力 W-MSA 移位窗口 SW-MSA2. 分层四阶段金字塔结构对标 ResNet 下采样策略第 1 阶段图片切分为 4×4 的 Patch窗口内做窗口自注意力每阶段结束使用Patch Merging类似卷积下采样特征图宽高减半、通道翻倍构建多尺度特征金字塔交替使用窗口多头注意力 W-MSA只在固定窗口内算注意力算力移位窗口注意力 SW-MSA跨窗口交互弥补窗口隔离导致的全局信息缺失。3. 结构优势兼具 CNN 多尺度金字塔结构 Transformer 全局建模能力线性算力复杂度支持高分辨率图像分类、检测、分割下游任务通用性极强你实验中使用的 Swin-Tiny4 阶段 多层窗口 Transformer基于 ImageNet 预训练微调后在 CIFAR10 可以轻松达到 92% 准确率。5. 实验角度总结你最开始用随机初始化 Swin 在 Fake 数据集只有 32% 准确率印证了ViT 类模型缺少 CNN 归纳偏置小数据集必须依赖预训练切换 CIFAR10ImageNet 预训练 论文标准增强 / AdamW 优化器后Swin 可以充分发挥全局建模优势精度显著超越同等参数量 ResNetResNet 依靠卷积局部先验小数据集可以从零训练收敛ViT/Swin 必须依托大规模数据集预训练才能发挥结构优势。