别再死磕公式了!用Python从零实现一个简易DDPM图像生成器(附完整代码)

发布时间:2026/5/30 1:18:38

在生成式AI领域,扩散模型(Diffusion Models)正迅速成为继GAN之后的新宠。本文将带你用Python从零构建一个基于MNIST数据集的简易DDPM(Denoising Diffusion Probabilistic Models)图像生成器:无需深入复杂的数学推导,通过代码实现直观理解其核心机制。

1. 环境准备与数据加载

首先确保安装必要的库:

```bash
pip install torch torchvision matplotlib
```

我们使用PyTorch框架和MNIST手写数字数据集:

```python
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Data preprocessing: ToTensor gives [0, 1], Normalize maps it to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Load the MNIST training set
train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
```

2. 前向加噪过程实现

DDPM的核心是通过逐步加噪来破坏原始图像。我们定义加噪调度和加噪函数:

```python
def linear_beta_schedule(timesteps):
    """Return a linear beta schedule with ``timesteps`` entries.

    The endpoints are scaled by ``1000 / timesteps`` so that for 1000 steps
    the schedule runs from 1e-4 to 0.02.
    """
    scale = 1000 / timesteps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    return torch.linspace(beta_start, beta_end, timesteps)


timesteps = 1000
betas = linear_beta_schedule(timesteps)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)


def extract(values, t, x_shape):
    """Pick ``values[t]`` per sample, shaped (B, 1, 1, ...) to broadcast over ``x_shape``.

    ``t`` may be a Python int or a tensor of shape (B,) on any device; the
    lookup is done on the CPU where the schedule tables live.
    """
    t = torch.as_tensor(t).reshape(-1).long().cpu()
    out = values[t]
    return out.reshape(-1, *([1] * (len(x_shape) - 1)))


def forward_diffusion_sample(x_0, t, device="cpu"):
    """Sample x_t ~ q(x_t | x_0) in closed form.

    Args:
        x_0: batch of clean images, shape (B, C, H, W).
        t: timestep(s) — an int or a tensor of shape (B,).
        device: device on which the result is produced.

    Returns:
        (x_t, noise): the noisy images and the Gaussian noise that was added.
    """
    x_0 = x_0.to(device)
    noise = torch.randn_like(x_0)
    # Per-sample coefficients must be (B, 1, 1, 1) and on x_0's device;
    # indexing the CPU tables with a (B,) CUDA tensor would fail.
    sqrt_alpha_cumprod_t = extract(sqrt_alphas_cumprod, t, x_0.shape).to(device)
    sqrt_one_minus_alpha_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x_0.shape).to(device)
    return sqrt_alpha_cumprod_t * x_0 + sqrt_one_minus_alpha_cumprod_t * noise, noise
```

可视化加噪过程:

```python
def plot_noise_process():
    """Show one training image at several forward-diffusion timesteps."""
    sample_img, _ = next(iter(train_loader))
    sample_img = sample_img[0].unsqueeze(0)
    plt.figure(figsize=(15, 5))
    for i, t in enumerate([0, 100, 300, 600, 999]):
        noisy_img, _ = forward_diffusion_sample(sample_img, t)
        plt.subplot(1, 5, i + 1)
        plt.imshow(noisy_img.squeeze().numpy(), cmap="gray")
        plt.title(f"t={t}")
        plt.axis("off")
    plt.show()


plot_noise_process()
```

3.
构建U-Net去噪网络

我们实现一个简化版的U-Net来预测噪声。时间步t是整数,需要先经过正弦位置编码变成向量,才能送入全连接层;另外,28×28的MNIST图像经过三次下采样会变成7→3,上采样后无法与跳跃连接对齐,因此在前向传播中先把输入填充到8的倍数,最后再裁剪回原尺寸:

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class SinusoidalPositionEmbeddings(nn.Module):
    """Map integer timesteps of shape (B,) to sinusoidal features of shape (B, dim)."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        half_dim = self.dim // 2
        scale = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, device=time.device) * -scale)
        args = time.float()[:, None] * freqs[None, :]
        return torch.cat((args.sin(), args.cos()), dim=-1)


class Block(nn.Module):
    """One U-Net stage: conv -> ReLU -> BatchNorm, add the time embedding, then resample.

    Down blocks halve H and W with a strided conv. Up blocks (``up=True``)
    receive the previous output concatenated with a skip connection, hence
    ``2 * in_ch`` input channels, and double H and W with a transposed conv.
    """

    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        if up:
            self.conv = nn.Conv2d(2 * in_ch, out_ch, 3, padding=1)
            self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
        else:
            self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
        self.bnorm = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(self, x, t):
        h = self.bnorm(self.relu(self.conv(x)))
        time_emb = self.relu(self.time_mlp(t))
        # (B, out_ch) -> (B, out_ch, 1, 1) so it broadcasts over H and W
        time_emb = time_emb.view(time_emb.shape[0], time_emb.shape[1], 1, 1)
        h = h + time_emb
        return self.transform(h)


class SimpleUnet(nn.Module):
    """A small U-Net that predicts the noise contained in x_t at timestep t."""

    def __init__(self):
        super().__init__()
        image_channels = 1
        down_channels = (64, 128, 256, 512)
        up_channels = (512, 256, 128, 64)
        out_dim = 1
        time_emb_dim = 32

        # Timestep embedding: integer t -> sinusoidal features -> MLP.
        # A bare nn.Linear cannot consume raw integer timesteps of shape (B,).
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU(),
        )

        # Downsampling path
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)
        self.downs = nn.ModuleList([
            Block(down_channels[i], down_channels[i + 1], time_emb_dim)
            for i in range(len(down_channels) - 1)
        ])
        # Upsampling path
        self.ups = nn.ModuleList([
            Block(up_channels[i], up_channels[i + 1], time_emb_dim, up=True)
            for i in range(len(up_channels) - 1)
        ])
        self.output = nn.Conv2d(up_channels[-1], out_dim, 1)

    def forward(self, x, timestep):
        # Each down block halves H and W, so they must be divisible by
        # 2 ** len(self.downs) for the skip connections to line up
        # (28x28 would give 28 -> 14 -> 7 -> 3, then 6 vs 7 on the way up).
        # Pad right/bottom here and crop the prediction back at the end.
        height, width = x.shape[-2:]
        factor = 2 ** len(self.downs)
        x = F.pad(x, (0, (-width) % factor, 0, (-height) % factor))

        t = self.time_mlp(timestep)
        x = self.conv0(x)
        residual_inputs = []
        for down in self.downs:
            x = down(x, t)
            residual_inputs.append(x)
        for up in self.ups:
            residual_x = residual_inputs.pop()
            x = torch.cat((x, residual_x), dim=1)
            x = up(x, t)
        return self.output(x)[..., :height, :width]
```

4.
训练流程实现

定义损失函数和训练循环:

```python
def get_loss(model, x_0, t):
    """Return the L1 distance between the true noise and the model's prediction at timesteps t."""
    x_noisy, noise = forward_diffusion_sample(x_0, t, device)
    noise_pred = model(x_noisy, t)
    return torch.nn.functional.l1_loss(noise, noise_pred)


device = "cuda" if torch.cuda.is_available() else "cpu"
model = SimpleUnet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)


def train_epoch(loader):
    """Train the global ``model`` for one pass over ``loader``; return the mean batch loss."""
    model.train()
    total_loss = 0.0
    for batch, _ in loader:
        optimizer.zero_grad()
        batch = batch.to(device)
        # Draw an independent random timestep for every image in the batch
        t = torch.randint(0, timesteps, (batch.shape[0],), device=device).long()
        loss = get_loss(model, batch, t)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)


# Training loop
for epoch in range(20):
    loss = train_epoch(train_loader)
    print(f"Epoch {epoch} | Loss: {loss:.4f}")
```

5. 采样生成新图像

实现反向去噪采样过程。注意:DDPM的单步更新公式只描述从x_t到x_{t-1}的一步,因此采样必须从t = T-1开始逐步走完全部时间步,而不能只在999、750、500、250、0这几个时间步上各做一次去噪;下面的实现完整迭代所有时间步,只在指定的时间步上绘图:

```python
@torch.no_grad()
def sample_plot_image(plot_steps=(999, 750, 500, 250, 0)):
    """Generate one image by running the full reverse diffusion chain from pure noise.

    Every timestep from ``timesteps - 1`` down to 0 applies the DDPM update
    x_{t-1} = (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t) + sqrt(beta_t) * z,
    with z = 0 on the last step. The current image is plotted whenever the
    timestep is listed in ``plot_steps``. Returns the final image tensor.
    """
    model.eval()
    img_size = 28
    img = torch.randn((1, 1, img_size, img_size), device=device)
    plot_steps = list(plot_steps)
    plt.figure(figsize=(15, 5))
    for step in reversed(range(timesteps)):
        t = torch.tensor([step], device=device)
        pred_noise = model(img, t)
        beta_t = betas[step].to(device)
        alpha_t = alphas[step].to(device)
        alpha_t_cumprod = alphas_cumprod[step].to(device)
        img = (img - (beta_t / torch.sqrt(1 - alpha_t_cumprod)) * pred_noise) / torch.sqrt(alpha_t)
        if step > 0:
            # No noise is added on the final step (t = 0)
            img = img + torch.sqrt(beta_t) * torch.randn_like(img)
        if step in plot_steps:
            plt.subplot(1, len(plot_steps), plot_steps.index(step) + 1)
            plt.imshow(img.squeeze().clamp(-1, 1).cpu().numpy(), cmap="gray")
            plt.title(f"t={step}")
            plt.axis("off")
    plt.show()
    return img


sample_plot_image()
```

6.
关键技巧与优化建议

在实际应用中,以下技巧可以提升DDPM的性能:

- 学习率调度:使用余弦退火等学习率调度策略;
- EMA模型:维护模型参数的指数移动平均(EMA)版本,用于最终推理;
- 混合精度训练:使用AMP加速训练过程;
- 更复杂的网络结构:尝试ResNet或Transformer作为骨干网络(backbone)。

```python
# Example: exponential moving average (EMA) of model weights
class EMA:
    """Keep an exponential moving average of a model's parameters.

    Call ``update_model_average(ema_model, model)`` after each optimizer step,
    where ``ema_model`` is a separate copy of the network (e.g. made with
    ``copy.deepcopy(model)``); use ``ema_model`` for sampling.
    """

    def __init__(self, beta):
        self.beta = beta  # decay factor; closer to 1 means a slower-moving average
        self.step = 0     # number of updates applied so far

    @torch.no_grad()
    def update_model_average(self, ema_model, current_model):
        for current_params, ema_params in zip(current_model.parameters(), ema_model.parameters()):
            ema_params.copy_(self.update_average(ema_params, current_params))
        # Buffers such as BatchNorm running statistics are not parameters;
        # copy them so the EMA model does not keep stale statistics.
        for current_buf, ema_buf in zip(current_model.buffers(), ema_model.buffers()):
            ema_buf.copy_(current_buf)
        self.step += 1

    def update_average(self, old, new):
        return old * self.beta + (1 - self.beta) * new
```

7. 完整代码整合

将所有组件整合为可运行的完整实现(按第1~6节的顺序依次拼接上述代码片段即可):

```python
# [Combine all of the code snippets above here]
# This includes data loading, model definition, the training loop and sampling
# Make sure all required imports and helper functions are included
```

运行这个完整脚本,你将看到从随机噪声逐步生成手写数字的过程。虽然这是一个简化实现,但它完整展示了DDPM的核心思想:通过逐步去噪,从随机噪声中生成有意义的数据。

相关新闻