从零开始:用PyTorch和CIFAR-10复现DDPM扩散模型(附完整代码与调参心得)

发布时间:2026/5/23 6:22:46

从零开始:用PyTorch和CIFAR-10复现DDPM扩散模型(附完整代码与调参心得) 从零实现DDPM用PyTorch在CIFAR-10上构建扩散模型的实战指南当我在实验室第一次看到扩散模型生成的猫咪图像时那种从噪声中逐渐浮现出清晰轮廓的过程就像观看一场数字魔术。不同于GAN的对抗训练或VAE的隐变量压缩扩散模型通过模拟物理中的扩散现象以更稳定的方式实现了高质量图像生成。本文将带你用PyTorch从零开始实现DDPMDenoising Diffusion Probabilistic Models并在CIFAR-10数据集上完成整个训练流程。1. 环境准备与数据加载在开始之前确保你的Python环境已安装PyTorch 1.12和Torchvision。对于GPU加速建议使用CUDA 11.3及以上版本pip install torch torchvision matplotlib numpyCIFAR-10数据集包含60,000张32x32彩色图像分为10个类别。我们使用Torchvision提供的接口加载数据并进行标准化处理transform transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) train_set datasets.CIFAR10(root./data, trainTrue, downloadTrue, transformtransform) train_loader DataLoader(train_set, batch_size128, shuffleTrue, num_workers4)注意将像素值归一化到[-1,1]范围是扩散模型的常见做法这与后续的噪声添加过程密切相关。2. 扩散过程的核心组件2.1 噪声调度器设计噪声调度器控制着前向扩散过程中噪声的添加节奏。DDPM采用线性beta调度这是最基础的策略class LinearScheduler: def __init__(self, timesteps1000, beta_start1e-4, beta_end0.02): self.timesteps timesteps self.betas torch.linspace(beta_start, beta_end, timesteps) self.alphas 1. - self.betas self.alpha_cumprod torch.cumprod(self.alphas, dim0) def add_noise(self, x, t, noise): sqrt_alpha torch.sqrt(self.alpha_cumprod[t]) sqrt_one_minus_alpha torch.sqrt(1. - self.alpha_cumprod[t]) return sqrt_alpha * x sqrt_one_minus_alpha * noise实际应用中你可以尝试不同的调度策略调度类型公式特点线性βₜ β₀ t(β₁-β₀)/T简单直接默认选择余弦见论文Improved DDPM高频噪声添加更平缓平方βₜ (β₀ t(β₁-β₀)/T)²初期变化慢后期变化快2.2 UNet架构实现我们的UNet需要处理时间步信息并预测噪声。关键组件包括时间步嵌入将离散时间步转换为连续向量残差块保持梯度流动的基础单元注意力机制在高分辨率层提升生成质量class TimeEmbedding(nn.Module): def __init__(self, dim): super().__init__() self.dim dim half_dim dim // 2 emb math.log(10000) / (half_dim - 1) emb torch.exp(torch.arange(half_dim, dtypetorch.float) * -emb) self.register_buffer(emb, emb) def forward(self, t): emb t.float() * self.emb emb torch.cat((emb.sin(), emb.cos()), dim-1) return emb class ResBlock(nn.Module): def __init__(self, in_c, out_c, t_dim): super().__init__() self.conv1 nn.Conv2d(in_c, out_c, 3, padding1) self.conv2 nn.Conv2d(out_c, out_c, 3, padding1) self.time_proj nn.Linear(t_dim, out_c) self.shortcut nn.Conv2d(in_c, out_c, 1) if in_c ! out_c else nn.Identity() def forward(self, x, t): h self.conv1(x) h self.time_proj(t)[:, :, None, None] h F.silu(h) h self.conv2(h) return h self.shortcut(x)3. 训练流程与技巧3.1 损失函数实现DDPM使用简化的均方误差损失直接预测噪声def p_losses(model, x0, t, noiseNone): if noise is None: noise torch.randn_like(x0) xt scheduler.add_noise(x0, t, noise) predicted_noise model(xt, t) return F.mse_loss(predicted_noise, noise)3.2 训练循环优化在实际训练中我发现以下几个技巧能显著提升效果学习率预热前500步从1e-6线性增加到目标学习率梯度裁剪限制梯度范数在1.0以内防止爆炸EMA模型使用指数移动平均保存更稳定的权重optimizer AdamW(model.parameters(), lr2e-4) scheduler get_cosine_schedule_with_warmup(optimizer, 500, 20000) for epoch in range(epochs): for x, _ in train_loader: x x.to(device) t torch.randint(0, timesteps, (x.shape[0],), devicedevice) optimizer.zero_grad() loss p_losses(model, x, t) loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() scheduler.step() # Update EMA ema.update(model.parameters())4. 采样与结果分析4.1 渐进式生成过程采样时我们从纯噪声开始逐步去噪torch.no_grad() def sample(model, shape): x torch.randn(shape, devicedevice) for t in reversed(range(timesteps)): t_tensor torch.full((shape[0],), t, devicedevice) pred_noise model(x, t_tensor) x scheduler.step(x, t_tensor, pred_noise) return x观察不同时间步的生成效果初期t800图像呈现模糊的色块中期400t≤800开始出现基本形状和颜色分布后期t≤400细节逐渐清晰生成质量显著提升4.2 常见问题排查在CIFAR-10上训练DDPM时你可能遇到以下问题生成图像模糊尝试增加模型容量或延长训练时间颜色偏差检查数据标准化是否一致训练不稳定降低学习率或使用梯度裁剪我在实际训练中发现使用以下参数组合效果较好{ timesteps: 1000, batch_size: 128, lr: 2e-4, ema_rate: 0.9999, unet_channels: [64, 128, 256], attention_resolutions: [16] }经过约50小时的训练单卡RTX 3090模型能够生成具有清晰特征的CIFAR-10类图像。虽然32x32的分辨率限制了细节表现但生成的物体轮廓和颜色分布已经相当准确。

相关新闻