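# Building DDPM from Scratch: Hands-On Image Denoising with Python and PyTorch

Diffusion models are quickly becoming the mainstream method for generating high-quality images in computer vision. This article walks through a complete, from-scratch implementation of a basic Denoising Diffusion Probabilistic Model (DDPM) in PyTorch. Instead of working through the full mathematical derivation, it builds intuition for how the model works through code.

## 1. Diffusion Model Basics

A diffusion model destroys an image by adding noise step by step, then learns the reverse process that removes the noise. Think of ink slowly dripping into a glass of clear water. The forward process of a diffusion model is like this contamination, and the reverse process is the seemingly magical act of purifying the water again.

Compared with traditional GANs or VAEs, DDPM has several distinctive advantages:

- **Training stability**: it does not use adversarial training, so it avoids the mode-collapse problem.
- **Generation quality**: the step-by-step refinement tends to produce more natural high-frequency detail.
- **Theoretical elegance**: it is grounded in the statistical physics of non-equilibrium thermodynamics.

At the implementation level, DDPM consists of two key stages:

- **Forward diffusion process** (fixed Markov chain): Gaussian noise is added to the data step by step.
- **Reverse denoising process** (learned transitions): a neural network is trained to remove the noise step by step.

```python
# Basic setup
import torch
import torch.nn as nn
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```

## 2. Implementing the Forward Diffusion Process

The forward process is defined as a Markov chain that gradually turns the data into an isotropic Gaussian distribution. The key design choice is the noise schedule, which controls how much noise is added at each timestep.

### 2.1 Noise Schedule Design

We use a linear noise schedule, a sequence that increases linearly from β₁ = 1e-4 to β_T = 0.02:

```python
def linear_beta_schedule(timesteps, start=1e-4, end=0.02):
    return torch.linspace(start, end, timesteps)

T = 1000  # total number of timesteps
# Keep the schedule on the same device as the model and the sampled timesteps
betas = linear_beta_schedule(T).to(device)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)  # cumulative product of α, i.e. ᾱ_t
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
```

### 2.2 Single-Step Diffusion

Given an original image x₀ and a timestep t, we compute the noisy image x_t. The closed form is x_t = √ᾱ_t · x₀ + √(1 − ᾱ_t) · ε with ε ~ N(0, I), where ᾱ_t is the cumulative product of the α values. Because of this closed form, x_t can be sampled in one step for any t, without running through the intermediate steps:

```python
def q_sample(x_start, t, noise=None):
    if noise is None:
        noise = torch.randn_like(x_start)
    # Gather the per-sample coefficients and reshape them to broadcast over (C, H, W)
    sqrt_alpha_cumprod_t = sqrt_alphas_cumprod[t].reshape(-1, 1, 1, 1)
    sqrt_one_minus_alpha_cumprod_t = sqrt_one_minus_alphas_cumprod[t].reshape(-1, 1, 1, 1)
    return sqrt_alpha_cumprod_t * x_start + sqrt_one_minus_alpha_cumprod_t * noise
```

This function plots the noisy image at several timesteps:

```python
def plot_diffusion_process(image, num_steps=5):
    # image: a single (1, 28, 28) tensor, e.g. one MNIST image
    image = image.to(device)
    plt.figure(figsize=(15, 3))
    plt.subplot(1, num_steps + 1, 1)
    plt.imshow(image.squeeze().cpu().numpy(), cmap="gray")
    plt.title("Original")
    plt.axis("off")
    for i in range(1, num_steps + 1):
        # Evenly spaced timesteps ending at T - 1
        t = torch.tensor([i * (T // num_steps) - 1], device=device)
        noisy_image = q_sample(image, t)
        plt.subplot(1, num_steps + 1, i + 1)
        plt.imshow(noisy_image.squeeze().cpu().numpy(), cmap="gray")
        plt.title(f"Step {t.item() + 1}")
        plt.axis("off")
    plt.show()
```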
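The sketch below is a quick check on the linear schedule. It recomputes ᾱ_t on the CPU and applies the same closed form as `q_sample` to a batch of constant images. The constant batch is an illustrative assumption, not data from the article. If the schedule works as intended, ᾱ_T should be close to zero, so x_T is almost pure standard Gaussian noise with mean near 0 and standard deviation near 1.

```python
import torch

# Assumed to match the article's schedule: T = 1000, betas from 1e-4 to 0.02
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

def q_sample(x_start, t, noise=None):
    # Closed form: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps
    if noise is None:
        noise = torch.randn_like(x_start)
    a = alphas_cumprod[t].sqrt().reshape(-1, 1, 1, 1)
    b = (1.0 - alphas_cumprod[t]).sqrt().reshape(-1, 1, 1, 1)
    return a * x_start + b * noise

torch.manual_seed(0)
x0 = torch.ones(512, 1, 28, 28)  # hypothetical batch of constant "images"
for step in (0, 250, 500, 999):
    t = torch.full((x0.size(0),), step, dtype=torch.long)
    xt = q_sample(x0, t)
    print(f"t={step:4d}  abar={alphas_cumprod[step].item():.5f}  "
          f"mean={xt.mean().item():.3f}  std={xt.std().item():.3f}")
```

With β rising from 1e-4 to 0.02 over 1000 steps, ᾱ_T comes out at roughly 4e-5. That is why generation can start directly from `torch.randn`.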
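## 3. Building the Reverse Denoising Model

The core of the reverse process is a trained noise-prediction network. We use a modified U-Net architecture with a downsampling path, an upsampling path with skip connections, and timestep embeddings.

### 3.1 Timestep Embedding

The discrete timestep is converted into a continuous vector using sinusoidal embeddings:

```python
class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        device = t.device
        half_dim = self.dim // 2
        # Geometrically spaced frequencies from 1 down to 1/10000
        embeddings = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = t[:, None] * embeddings[None, :]
        # Concatenate sin and cos features -> shape (batch, dim)
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings
```

### 3.2 Basic Residual Block

Each block applies two 3×3 convolutions and adds the timestep embedding as a per-channel bias. It also has a residual path, which uses a 1×1 convolution when the channel count changes:

```python
class Block(nn.Module):
    def __init__(self, in_ch, out_ch, time_emb_dim):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.act = nn.SiLU()
        self.bn = nn.BatchNorm2d(out_ch)
        # 1x1 projection so the residual path matches out_ch
        self.res_conv = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x, t):
        h = self.bn(self.act(self.conv1(x)))
        # Project the time embedding to out_ch and add it to every spatial position
        time_emb = self.act(self.time_mlp(t))
        h = h + time_emb.reshape(-1, h.shape[1], 1, 1)
        h = self.act(self.conv2(h))
        return h + self.res_conv(x)
```

### 3.3 Full U-Net Implementation

For 28×28 MNIST inputs the network uses two downsampling levels (28 → 14 → 7). A third pooling step would shrink 7 to 3, and a stride-2 transposed convolution turns 3 into 6, which no longer matches the 7×7 skip tensor. Skip tensors are concatenated along the channel dimension, so each upsampling block takes twice as many input channels:

```python
class UNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=1, dim=32, dim_mults=(1, 2, 4)):
        super().__init__()
        # Timestep embedding: sinusoidal features followed by a small MLP
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(dim),
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim),
        )
        # Lift the input image to `dim` channels
        self.init_conv = nn.Conv2d(in_channels, dim, 3, padding=1)

        dims = [dim * m for m in dim_mults]      # e.g. [32, 64, 128]
        in_out = list(zip(dims[:-1], dims[1:]))  # e.g. [(32, 64), (64, 128)]

        # Downsampling path (2x average pooling happens in forward)
        self.downs = nn.ModuleList([Block(d_in, d_out, dim) for d_in, d_out in in_out])
        # Bottleneck
        self.mid = Block(dims[-1], dims[-1], dim)
        # Upsampling path: transposed conv (2x), then a Block on [upsampled, skip]
        self.ups = nn.ModuleList([])
        for d_in, d_out in reversed(in_out):
            self.ups.append(nn.ConvTranspose2d(d_out, d_out, 4, 2, 1))
            self.ups.append(Block(d_out * 2, d_in, dim))

        self.final = nn.Conv2d(dim, out_channels, 1)

    def forward(self, x, t):
        t = self.time_mlp(t)
        x = self.init_conv(x)
        hs = []
        # Downsampling: save each level's output as a skip connection
        for block in self.downs:
            x = block(x, t)
            hs.append(x)
            x = nn.functional.avg_pool2d(x, 2)
        # Bottleneck
        x = self.mid(x, t)
        # Upsampling: upsample, concatenate the matching skip, then refine
        for i in range(0, len(self.ups), 2):
            x = self.ups[i](x)
            skip = hs.pop()
            x = torch.cat([x, skip], dim=1)
            x = self.ups[i + 1](x, t)
        return self.final(x)
```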
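With skip connections, every upsampled feature map must have exactly the spatial size of the tensor saved at the matching downsampling level. The sketch below traces those sizes for a 28×28 input. It uses the same operations as the U-Net: `avg_pool2d(x, 2)` going down and `ConvTranspose2d` with kernel 4, stride 2, padding 1 going up. The helper names `pooled_sizes` and `upsampled_size` are hypothetical and exist only for this illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def pooled_sizes(size, levels):
    """Spatial sizes after each avg_pool2d(x, 2), starting from a size x size map."""
    x = torch.zeros(1, 1, size, size)
    sizes = [size]
    for _ in range(levels):
        x = F.avg_pool2d(x, 2)
        sizes.append(x.shape[-1])
    return sizes

# Same upsampling configuration as in the U-Net: kernel 4, stride 2, padding 1
upsample = nn.ConvTranspose2d(1, 1, kernel_size=4, stride=2, padding=1)

def upsampled_size(size):
    return upsample(torch.zeros(1, 1, size, size)).shape[-1]

for levels in (2, 3, 4):
    sizes = pooled_sizes(28, levels)
    # Each upsampled map must match the skip tensor saved one level higher
    ok = all(upsampled_size(sizes[i + 1]) == sizes[i] for i in range(levels))
    print(levels, sizes, "skip shapes match" if ok else "skip shapes mismatch")
```

Only two levels line up for 28×28 inputs. With three or four levels, a 3×3 or 1×1 map is upsampled to 6×6 or 2×2, and concatenation with the 7×7 or 3×3 skip tensor fails.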
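## 4. Implementing the Training Loop

DDPM is trained to minimize the L2 distance (mean squared error) between the predicted noise and the true noise.

### 4.1 Loss Function

```python
def p_losses(denoise_model, x_start, t, noise=None):
    if noise is None:
        noise = torch.randn_like(x_start)
    # Noise the clean images to timestep t, then ask the model to recover the noise
    x_noisy = q_sample(x_start, t, noise)
    predicted_noise = denoise_model(x_noisy, t)
    return torch.mean((noise - predicted_noise) ** 2)
```

### 4.2 Training Loop

```python
def train(model, dataloader, epochs=100, lr=1e-3):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for batch, _ in dataloader:
            batch = batch.to(device)
            # Sample a random timestep for each image
            t = torch.randint(0, T, (batch.size(0),), device=device)
            # Compute the loss
            loss = p_losses(model, batch, t)
            # Backpropagate and update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch + 1} | Loss: {total_loss / len(dataloader):.4f}")
    return model
```

## 5. Sampling New Images

After training, the model generates new images from random noise by denoising it step by step.

### 5.1 Single Sampling Step

Each step first computes the model mean μ = (1/√α_t) · (x_t − β_t / √(1 − ᾱ_t) · ε_θ(x_t, t)). It then adds Gaussian noise scaled by the posterior variance β̃_t = (1 − ᾱ_{t−1}) / (1 − ᾱ_t) · β_t. The final step (t = 0) skips the noise and returns the mean:

```python
@torch.no_grad()
def p_sample(model, x, t, t_index):
    betas_t = betas[t].reshape(-1, 1, 1, 1)
    sqrt_one_minus_alphas_cumprod_t = sqrt_one_minus_alphas_cumprod[t].reshape(-1, 1, 1, 1)
    sqrt_recip_alphas_t = torch.sqrt(1.0 / alphas[t]).reshape(-1, 1, 1, 1)

    # Predict the noise
    pred_noise = model(x, t)
    # Compute the mean
    model_mean = sqrt_recip_alphas_t * (x - betas_t * pred_noise / sqrt_one_minus_alphas_cumprod_t)

    if t_index == 0:
        return model_mean
    else:
        posterior_variance_t = (1 - alphas_cumprod[t - 1]) / (1 - alphas_cumprod[t]) * betas[t]
        noise = torch.randn_like(x)
        return model_mean + torch.sqrt(posterior_variance_t).reshape(-1, 1, 1, 1) * noise
```

### 5.2 Full Sampling Loop

```python
@torch.no_grad()
def p_sample_loop(model, shape):
    # BatchNorm must use its running statistics during sampling
    model.eval()
    # Start from pure random noise
    img = torch.randn(shape, device=device)
    imgs = []
    for i in reversed(range(0, T)):
        t = torch.full((shape[0],), i, device=device, dtype=torch.long)
        img = p_sample(model, img, t, i)
        # Keep a snapshot after the first step and every T // 10 steps
        if i % (T // 10) == 0 or i == T - 1:
            imgs.append(img.cpu())
    return imgs
```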
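One way to check the mean formula in `p_sample` is to give it an "oracle" that returns exactly the noise used to build x_t from x₀. In that case it should match the closed-form mean of q(x_{t−1} | x_t, x₀) from the DDPM paper, μ̃_t = √ᾱ_{t−1} β_t / (1 − ᾱ_t) · x₀ + √α_t (1 − ᾱ_{t−1}) / (1 − ᾱ_t) · x_t. The sketch below assumes the article's linear schedule and uses float64 to keep rounding error negligible. The helpers `model_mean_from_eps` and `posterior_mean` are hypothetical names for this check only.

```python
import torch

# Assumed schedule, matching the article's linear betas (1e-4 -> 0.02, T = 1000)
T = 1000
betas = torch.linspace(1e-4, 0.02, T, dtype=torch.float64)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)

def model_mean_from_eps(x_t, eps, t):
    # Same formula as p_sample: 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - abar_t) * eps)
    return (x_t - betas[t] / torch.sqrt(1 - alphas_cumprod[t]) * eps) / torch.sqrt(alphas[t])

def posterior_mean(x_0, x_t, t):
    # Closed-form mean of q(x_{t-1} | x_t, x_0)
    coef_x0 = torch.sqrt(alphas_cumprod[t - 1]) * betas[t] / (1 - alphas_cumprod[t])
    coef_xt = torch.sqrt(alphas[t]) * (1 - alphas_cumprod[t - 1]) / (1 - alphas_cumprod[t])
    return coef_x0 * x_0 + coef_xt * x_t

torch.manual_seed(0)
x_0 = torch.randn(4, 1, 28, 28, dtype=torch.float64)
eps = torch.randn_like(x_0)
t = 500
x_t = torch.sqrt(alphas_cumprod[t]) * x_0 + torch.sqrt(1 - alphas_cumprod[t]) * eps

# With the true noise as the "prediction", both formulas should agree
diff = (model_mean_from_eps(x_t, eps, t) - posterior_mean(x_0, x_t, t)).abs().max()
print(f"max abs difference: {diff.item():.2e}")
```

The two expressions agree because ᾱ_t = α_t · ᾱ_{t−1}. This is why a network trained only to predict ε is enough to parameterize the reverse step.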
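## 6. Hands-On Demo and Results

Now we train the model on the MNIST dataset and look at what it generates.

### 6.1 Data Preparation

Pixel values are normalized from [0, 1] to [-1, 1], which puts them on a scale comparable to the standard Gaussian noise:

```python
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
dataset = datasets.MNIST("./data", train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)
```

### 6.2 Model Training

```python
model = UNet().to(device)
trained_model = train(model, dataloader, epochs=20)
```

### 6.3 Generating New Images

```python
sample_size = 16
generated_images = p_sample_loop(trained_model, (sample_size, 1, 28, 28))

# Timesteps at which p_sample_loop saved snapshots (same condition as in the loop)
snapshot_steps = [i for i in reversed(range(T)) if i % (T // 10) == 0 or i == T - 1]

# Visualize the denoising process for the first sample
plt.figure(figsize=(15, 2))
for i, (img, step) in enumerate(zip(generated_images, snapshot_steps)):
    plt.subplot(1, len(generated_images), i + 1)
    plt.imshow(img[0].squeeze().numpy(), cmap="gray")
    plt.title(f"after t={step}")
    plt.axis("off")
plt.show()
```

This complete implementation shows the core ideas of DDPM in code that actually runs. The example uses the simple MNIST dataset, but the same architecture can be extended to more complex image-generation tasks with appropriate adjustments.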