DiT 技术详解:把扩散模型的 U-Net 换成 Transformer,真正改变了什么

发布时间:2026/7/1 10:47:19

DiT 技术详解:把扩散模型的 U-Net 换成 Transformer,真正改变了什么 DiT 技术详解把扩散模型的 U-Net 换成 Transformer真正改变了什么如果只用一句话解释 DiT它把 latent diffusion 里的 U-Net 去噪网络换成了一个 ViT 风格的 Transformer。输入不再是卷积网络逐层处理的 feature map而是 VAE latent 被切成的 patch token扩散 timestep、类别标签等条件也不再靠 U-Net 里的 time embedding 到处注入而是通过 adaLN-Zero 调制 Transformer block。这句话听起来很像“架构替换”但 DiT 真正有意思的地方不在“Transformer 也能生成图像”。更关键的是它把图像扩散模型带到了一个更像 LLM/ViT 的缩放问题里模型宽度、深度、token 数、forward Gflops 和 FID 之间出现了很清楚的关系。论文最硬的一句话是DiT 的 Gflops 越高FID 越低而且这个趋势可以通过加深加宽 Transformer 或减小 patch size 来获得。本文按工程视角拆 DiT。先把它放回 latent diffusion 的 pipeline再看 patchify、条件注入、adaLN-Zero、缩放规律和代码实现。你不需要先背完整 DDPM 推导只要知道扩散模型训练一个网络ϵθ(xt,t,c)\epsilon_\theta(x_t,t,c)ϵθ​(xt​,t,c)去预测噪声DiT 改的是这个网络的 backbone。DiT 在扩散 pipeline 里替换的是哪一块Stable Diffusion 这类 latent diffusion model 通常有三块VAE encoder 把图像压成 latent去噪网络在 latent 空间里跑扩散反推VAE decoder 再把 latent 解回像素图。DiT 保留了这条路线只替换中间的去噪网络。以 256×256 RGB 图像为例论文使用 Stable Diffusion 的预训练 VAEdownsample factor 是 8。所以图像x∈R256×256×3x \in \mathbb{R}^{256 \times 256 \times 3}x∈R256×256×3会被编码成z∈R32×32×4z \in \mathbb{R}^{32 \times 32 \times 4}z∈R32×32×4。DiT 不是直接处理 256×256 像素而是处理这个 32×32×4 的 latent grid。这样做很重要因为如果直接在像素空间把图像切成 token序列长度和算力会马上爆掉。可以把 DiT 的位置画成这样image x │ ▼ VAE encoder E 冻结不训练 │ ▼ latent z0: 32×32×4 │ add noise at timestep t ▼ noised latent zt │ ▼ DiT backbone: patchify → Transformer blocks → unpatchify │ ▼ predicted noise / covariance │ sampling loop ▼ latent z0_hat │ ▼ VAE decoder D 冻结不训练 │ ▼ generated image这里有一个容易被忽略的点DiT 并没有提出新的扩散目标也没有换掉 classifier-free guidance。它沿用 ADM/LDM 里很成熟的训练和采样设定包括 learned covariance、250-step DDPM sampling、FID-50K 评估等。论文的实验设计其实很克制尽量少动 diffusion recipe把变量集中到 backbone 上。这也是为什么 DiT 的结论比较干净。它是在问如果把 U-Net 这个默认选择换成标准 Transformer扩散模型还能不能按 compute scaling 的方式变好答案是能而且趋势相当稳定。patchifylatent grid 怎么变成 token 序列DiT 继承 ViT 的第一步patchify。输入 latent 的形状是I×I×CI \times I \times CI×I×Cpatch size 是p×pp \times pp×p那么 token 数是T(I/p)2 T (I / p)^2T(I/p)2每个 patch 被线性投影到 hidden dimensionddd再加上固定的二维 sine-cosine positional embedding。对于 256×256 图像latent spatial size 是I32I32I32。如果使用 DiT-XL/2patch sizep2p2p2token 数就是16×1625616 \times 16 25616×16256。如果是 DiT-XL/4token 数降到8×8648 \times 8 648×864。如果是 DiT-XL/8就只有4×4164 \times 4 164×416个 token。这给 DiT 带来一个很直接的旋钮减小 patch size 会增加 token 数也会显著增加 Transformer 的计算量。论文里说得很明确patch size 减半会让 token 数变成四倍因此 Transformer Gflops 至少变成四倍。更微妙的是减小 patch size 几乎不增加参数量因为参数主要在 Transformer block 的权重里不在 token 数里。这点和普通 CNN scaling 不太一样。你可以在参数量几乎不变的情况下通过让模型处理更多 token 来提高 forward compute。DiT 的实验显示这种 compute 增加确实能改善 FID。换句话说DiT 的质量不只由参数量决定也由“每次去噪到底看了多少 token、做了多少 attention/MLP 计算”决定。官方 PyTorch 代码里对应的是这一行xself.x_embedder(x)self.pos_embed# (N, T, D)x_embedder来自 timm 的PatchEmbed。后面所有 DiT block 都处理这个 token sequence。最后再把每个 token 解码成p×p×2Cp \times p \times 2Cp×p×2C其中2C2C2C是因为模型同时预测噪声和 diagonal covariance。DiT block标准 Transformer但条件注入不能随便做把 latent patch 变成 token 以后最自然的想法是直接套 ViT blockLayerNorm、self-attention、MLP、residual。问题在于扩散模型不是普通图像分类。去噪网络每一步都需要知道 timesteptttclass-conditional ImageNet 还需要类别标签ccc。这些条件怎么进入 Transformer block会明显影响效果。DiT 论文比较了四种做法条件注入方式做法额外计算论文里的结论In-context conditioning把 timestep 和 class embedding 当作额外 token 拼进序列很小简单但效果较差Cross-attention图像 token self-attention 后再对条件 token 做 cross-attention最高约 15% overhead计算更贵但不占优adaLN用条件向量生成 LayerNorm 的 scale/shift很小比前两者更高效adaLN-Zero在 adaLN 上加 residual gate并把 gate 初始化为 0很小最好后续实验默认使用这个结果有点反直觉。很多 text-to-image 模型里 cross-attention 是核心部件所以容易下意识觉得 cross-attention 更强。但在 DiT 的 ImageNet class-conditional 设置里条件只是 timestep 和 class label信息量很小。为这种短条件专门加 cross-attention不一定划算。adaLN 的思路更像 FiLM先把 timestep embedding 和 label embedding 相加得到条件向量ccc再由一个 MLP 生成每个 block 的调制参数。官方实现里每个 DiTBlock 的调制层输出 6 组向量shift_msa,scale_msa,gate_msa,shift_mlp,scale_mlp,gate_mlp\ self.adaLN_modulation(c).chunk(6,dim1)然后 attention branch 和 MLP branch 分别这样走xxgate_msa.unsqueeze(1)*self.attn(modulate(self.norm1(x),shift_msa,scale_msa))xxgate_mlp.unsqueeze(1)*self.mlp(modulate(self.norm2(x),shift_mlp,scale_mlp))modulate很简单defmodulate(x,shift,scale):returnx*(1scale.unsqueeze(1))shift.unsqueeze(1)也就是说条件向量不作为 token 参加 attention而是改变每个 block 里 normalization 后的表示并通过 gate 控制 residual branch 的强度。adaLN-Zero 为什么是 DiT 的关键小改动adaLN-Zero 的“Zero”不是名字装饰。它把每个 DiT block 初始成接近 identity function。具体做法是让adaLN_modulation最后一层线性层的权重和 bias 初始化为 0于是初始时 shift、scale、gate 都是 0。残差分支一开始被 gate 关掉整个 block 更像恒等映射。官方代码里可以直接看到forblockinself.blocks:nn.init.constant_(block.adaLN_modulation[-1].weight,0)nn.init.constant_(block.adaLN_modulation[-1].bias,0)nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight,0)nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias,0)nn.init.constant_(self.final_layer.linear.weight,0)nn.init.constant_(self.final_layer.linear.bias,0)这和一些 ResNet/扩散 U-Net 的初始化习惯一致残差块刚开始不要太激进先让网络从稳定的近似恒等映射开始学。扩散模型的训练本来就要处理不同 noise level 下的输入如果一开始每个 block 都强烈扰动 token训练会更难。论文的消融很有说服力。四种 block 设计都在 DiT-XL/2 上比较Gflops 大致接近in-context 119.4cross-attention 137.6adaLN 118.6adaLN-Zero 118.6。结果是 adaLN-Zero 在训练过程中 FID 最低400K steps 时几乎把 in-context 的 FID 降到一半。这个结论说明 DiT 不是“把 ViT 塞进 diffusion 就行”。条件注入和初始化是成败点。从工程实现看adaLN-Zero 还有一个优点它不像 cross-attention 那样引入额外条件序列也不强依赖复杂的注意力 mask。对于 class-conditioned 或 timestep-conditioned 模型这种调制式条件注入非常干净。DiT-S/B/L/XL 和 /2 /4 /8 到底怎么命名DiT 的模型名有两部分前面的 S/B/L/XL 表示 Transformer 主体大小后面的/2、/4、/8表示 latent patch size。论文的主配置如下模型层数 Nhidden size dheads在 I32, p4 时的 GflopsDiT-S1238461.4DiT-B12768125.6DiT-L2410241619.7DiT-XL2811521629.1DiT-XL/2就是 28 层、hidden size 1152、16 heads、patch size 2。对于 256×256 图像它处理 32×32×4 latentpatch size 2 产生 256 个 token。对于 512×512 图像latent 是 64×64×4patch size 2 产生 1024 个 token所以 512 分辨率下的 DiT-XL/2 Gflops 会明显更高。论文里最值得记的不是某个具体配置而是两条 scaling 路线固定 patch size加深加宽 TransformerS → B → L → XL。固定模型大小减小 patch size/8 → /4 → /2。两条路都会增加 Gflops也都会改善 FID。更关键的是参数量不是唯一解释变量。比如固定 DiT-XL只把 patch size 从 4 改成 2参数量几乎不变但 token 数和 Gflops 大幅增加FID 仍然显著变好。这就是 DiT 对后续生成模型架构的影响图像生成不必永远围绕 U-Net 设计。只要把输入组织成 token并找到合适的条件注入方式扩散模型也能进入 Transformer 的 scaling 逻辑。训练设置DiT 的 recipe 其实很保守DiT 的实验是在 ImageNet class-conditional generation 上做的分辨率包括 256×256 和 512×512。训练时使用 AdamWbatch size 256学习率1×10−41 \times 10^{-4}1×10−4没有 weight decay只用 horizontal flip 作为数据增强。论文还提到他们没有发现 ViT 训练里常见的 warmup 或正则化是必需的训练过程稳定没有观察到 Transformer 训练中常见的 loss spike。扩散部分基本沿用 ADM1000-step linear variance schedule预测噪声和 learned covariance采样评估用 250 DDPM steps。评估用 FID-50K并用 ADM 的 TensorFlow evaluation suite 来保证和 prior work 可比。最大模型的训练成本不低。论文报告 DiT-XL/2 在 TPU v3-256 pod 上训练速度约 5.7 iterations/secondglobal batch size 256。官方 PyTorch repo 后来提供了 DDP 训练脚本也说明用 8×A100 训练 DiT-XL/2、用 4×A100 训练 DiT-B/4可以在数十万步范围内复现 JAX 结果到合理随机波动内。如果只是跑预训练模型官方 repo 给出的最小命令很简单gitclone https://github.com/facebookresearch/DiT.gitcdDiT condaenvcreate-fenvironment.yml conda activate DiT python sample.py --image-size512--seed1训练自己的 class-conditional DiTtorchrun--nnodes1--nproc_per_nodeN train.py\--modelDiT-XL/2\--data-path /path/to/imagenet/train如果要严肃复现实验最容易踩坑的是 FID 评估而不是模型 forward。FID 对 resize、采样数量、VAE decoder、guidance scale 都敏感。官方 README 里强调PyTorch 训练结果表里的 FID 是 250 DDPM sampling steps、mseVAE decoder、无 guidancecfg-scale1条件下算的。实验结果该怎么看不是“Transformer 赢了”而是“compute scaling 很干净”DiT-XL/2 在 256×256 ImageNet 上使用 classifier-free guidance scale 1.50 时 FID-50K 达到 2.27超过论文比较中的 prior diffusion models。512×512 上DiT-XL/2-G 达到 3.04 FID也优于当时对比的 ADM、ADM-U、ADM-G 等结果。但我觉得这篇论文真正有价值的结果不是 SOTA 表格而是 Figure 6、Figure 8、Figure 9 那组 scaling 分析。它们回答了三个更底层的问题第一增加 Transformer depth/width 有用吗有。固定 patch size 时从 S/B/L 到 XLFID 在训练各阶段都改善。第二增加 token 数有用吗也有。固定模型大小时把 patch size 从 8 降到 4、再降到 2FID 同样持续改善。第三小模型多采样几步能不能补回来很难。论文比较了不同 sampling steps 下的 FID发现增加 sampling compute 不能弥补 backbone compute 不足。比如 DiT-L/2 用 1000 sampling steps 时每张图采样计算量约 80.7 TflopsDiT-XL/2 用 128 steps 只用约 15.2 Tflops但 FID-10K 仍然更好23.7 vs 25.9。这对实际训练很有启发。扩散模型的质量不是只靠“采样时多跑几步”堆出来的。backbone 本身的容量和每步 forward compute 仍然很关键。对于 DiT训练一个足够大的模型可能比在小模型上用更重的 sampler 更划算。从代码看一次 forward把官方models.py抽象一下DiT forward 主要是五步defforward(self,x,t,y):# 1. latent patches - token sequencexself.x_embedder(x)self.pos_embed# 2. timestep / label embeddingstself.t_embedder(t)yself.y_embedder(y,self.training)cty# 3. Transformer blocks with adaLN-Zeroforblockinself.blocks:xblock(x,c)# 4. decode each token to patch predictionxself.final_layer(x,c)# 5. token sequence - latent gridxself.unpatchify(x)returnx这里的x不是像素图而是 noisy latent。t是 diffusion timestep。y是类别标签训练时会按class_dropout_prob随机 drop 成 null label用来支持 classifier-free guidance。官方forward_with_cfg还有一个实现细节为了可复现默认只对前三个 channel 应用 classifier-free guidance而不是对所有 output channel。代码注释说标准做法可以改成对所有 channels 做 CFG。这种细节如果不注意复现出来的采样结果可能和 README 或论文不一致。另一个实现细节是位置编码。DiT 使用固定二维 sin-cos positional embeddingrequires_gradFalse。这和 ViT/MAE 的习惯一致也让模型结构更简单。DiT 本质上把 latent grid 当成一张“低分辨率图像”所以二维位置编码比一维 learnable embedding 更自然。DiT 和 U-Net 的差别不只是有没有卷积U-Net 的强项是局部归纳偏置和多尺度结构。高分辨率图像生成里U-Net 通过 downsample/upsample 路径在不同空间尺度处理特征skip connection 又保留细节。这个设计很适合图像。DiT 的强项是统一的 token 表示和清晰的 scaling 规则。它没有显式金字塔也没有 U-Net 的多尺度 skip。所有 token 在同一 hidden dimension 里反复经过 attention 和 MLP。局部性不是结构硬编码出来的而更多来自 latent patch、位置编码和训练数据。这不是说 DiT 在所有场景都天然优于 U-Net。原始 DiT 的 ImageNet class-conditional 设置比较干净条件也很短。如果换成 text-to-image条件变成长文本cross-attention 或更复杂的 multimodal attention 又会回来。后来的 MMDiT、PixArt、Stable Diffusion 3 等路线本质上都是在 DiT/Transformer backbone 上重新设计文本条件、训练效率和高分辨率生成。所以更准确的判断是DiT 证明了扩散模型不需要永远依赖 U-Net inductive bias。Transformer backbone 可以在 latent diffusion 里工作而且可以按 compute 规律稳定变好。但具体到 text-to-image、video、3D 或 controllable generation条件组织和训练 recipe 仍然决定上限。实践里什么时候该考虑 DiT如果你在做 image/video generation 或多模态生成模型DiT 值得考虑的场景通常有几类。第一模型规模会继续变大。Transformer 的工程生态更成熟FlashAttention、sequence parallel、tensor parallel、checkpointing、fused MLP、KV/attention 优化这些都更容易迁移到 DiT 类架构上。U-Net 当然也能优化但 Transformer scaling 的工具链更完整。第二你的输入和条件天然是 token。比如文本、动作、语音、相机轨迹、agent state、layout token、视频 patch。如果所有东西都能变成 token那么 Transformer backbone 的统一接口会很舒服。相反如果任务强依赖局部纹理和多尺度 skipU-Net 仍然可能更省算力。第三你关心 scaling law 式的实验设计。DiT 给了一个很清楚的坐标系模型大小、patch size、token 数、Gflops、训练步数、FID。你可以系统扫配置而不是只在 U-Net channel multiplier、attention resolution、resblock 数量里调参。实际落地时我会先看三个约束约束更偏 DiT更偏 U-Net数据/算力有足够训练 compute计划 scale算力有限需要强 inductive bias条件形式多模态 token、长上下文、需要统一建模条件简单局部控制为主工程目标想复用 Transformer 优化栈想用成熟 diffusion U-Net 生态DiT 不是免费午餐。attention 对 token 数敏感patch size 一小计算量马上上去。512×512 下 DiT-XL/2 处理 1024 个 latent tokensGflops 达到 524.6。更高分辨率或视频任务如果直接照搬会遇到序列长度问题。因此后续工作经常会引入更高效的 attention、factorized attention、latent compression 或分层结构。一个简化版 DiT mental model如果你想快速在脑子里跑一遍 DiT可以用下面这个 mental model1. VAE 把图像压成小 latent map。 2. diffusion 给 latent 加噪声得到 zt。 3. DiT 把 zt 切成 patch tokens。 4. timestep label 变成一个条件向量 c。 5. 每个 Transformer block 用 c 生成 adaLN 的 shift/scale/gate。 6. token 经过 self-attention 和 MLP预测噪声与方差。 7. 采样循环反复调用 DiT把纯噪声 latent 还原成可解码 latent。 8. VAE decoder 把 latent 变回图像。如果再压缩一点DiT LDM 的 latent 空间 ViT patch tokens adaLN-Zero 条件注入 compute scaling。这个公式比“U-Net 换 Transformer”更有信息量。因为只换 Transformer 不够必须同时解释 latent patch 怎么组织、条件怎么进 block、为什么 zero init 稳定、以及质量为什么跟 Gflops 强相关。局限和阅读建议DiT 原论文的实验场景是 ImageNet class-conditional generation不是今天更常见的开放词表 text-to-image。它证明了 Transformer backbone 在扩散图像生成里可行也给了清晰的缩放证据但没有解决文本对齐、复杂 prompt following、超高分辨率生成或视频长程一致性。读这篇论文时建议不要只盯着最终 FID。更值得反复看的有三处Figure 3 的 block designFigure 6/8 的 scaling 曲线以及官方models.py里的 adaLN-Zero 实现。看完这三处DiT 的核心基本就通了。后续如果继续读可以沿两条线走。一条是架构线PixArt、MMDiT、Stable Diffusion 3 这类模型如何把文本条件和 DiT backbone 结合起来。另一条是效率线FlashAttention、sequence length reduction、latent tokenization、video DiT 如何处理更长序列。DiT 本身是起点不是终点。参考资料William Peebles, Saining Xie, “Scalable Diffusion Models with Transformers”, arXiv:2212.09748 / ICCV 2023, retrieved 2026-06-30, https://arxiv.org/abs/2212.09748ICCV 2023 open access paper PDF, retrieved 2026-06-30, https://openaccess.thecvf.com/content/ICCV2023/papers/Peebles_Scalable_Diffusion_Models_with_Transformers_ICCV_2023_paper.pdfDiT official project page, retrieved 2026-06-30, https://www.wpeebles.com/DiTfacebookresearch/DiT official PyTorch implementation, retrieved 2026-06-30, https://github.com/facebookresearch/DiTfacebookresearch/DiTmodels.py, retrieved 2026-06-30, https://github.com/facebookresearch/DiT/blob/main/models.py

相关新闻